[de07e6]: / src / Matcher / LangChain_QARetrieval.ipynb

Download this file

541 lines (540 with data), 19.5 kB

{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load API credentials from a .env file instead of hard-coding them in the notebook.\n",
    "from dotenv import load_dotenv\n",
    "import os\n",
    "# The path is relative to the notebook's working directory and points at the PARENT directory.\n",
    "load_dotenv('../.env')  # .env is expected one level above the working directory\n",
    "\n",
    "# os.getenv returns None when a variable is not set, so each of these may be None.\n",
    "openai_access_key = os.getenv('OPENAI_ACCESS_KEY')\n",
    "huggingface_token = os.getenv('HUGGINGFACEHUB_API_TOKEN')\n",
    "cohere_api_token = os.getenv('COHERE_API_KEY')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def pretty_print_docs(docs):\n",
    "    \"\"\"Print each document's text, NCT ID and criteria type, with a dashed rule between entries.\"\"\"\n",
    "    separator = f\"\\n{'-' * 100}\\n\"\n",
    "    entries = []\n",
    "    for position, doc in enumerate(docs, start=1):\n",
    "        text = doc.page_content\n",
    "        # Metadata keys that are absent are rendered as 'N/A'.\n",
    "        nct_id = doc.metadata.get('nct_id', 'N/A')\n",
    "        criteria_type = doc.metadata.get('criteria_type', 'N/A')\n",
    "        entries.append(\n",
    "            f\"Criteria {position}:\\n\\nPage Content: {text}\\nNCT ID: {nct_id}\\nCriteria Type: {criteria_type}\"\n",
    "        )\n",
    "    print(separator.join(entries))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from langchain.docstore.document import Document\n",
    "from langchain_community.embeddings.sentence_transformer import (\n",
    "    SentenceTransformerEmbeddings,\n",
    ")\n",
    "from langchain_community.vectorstores import Chroma"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build one Document per eligibility criterion and index them in the \"criteria\" collection.\n",
    "\n",
    "# Directory holding one JSON file per clinical trial\n",
    "json_directory = '../../data/trials_jsons/'\n",
    "desired_fields = [\"nct_id\", \"eligibility\"]\n",
    "\n",
    "# Documents accumulated across all trial files\n",
    "docs = []\n",
    "\n",
    "# sorted() makes the indexing order reproducible across runs and platforms\n",
    "for filename in sorted(os.listdir(json_directory)):\n",
    "    if not filename.endswith('.json'):\n",
    "        continue\n",
    "    file_path = os.path.join(json_directory, filename)\n",
    "    # Explicit encoding so parsing does not depend on the platform's default locale\n",
    "    with open(file_path, 'r', encoding='utf-8') as file:\n",
    "        json_data = json.load(file)\n",
    "    extracted_data = {field: json_data.get(field) for field in desired_fields}\n",
    "\n",
    "    # `or []` covers both a missing \"eligibility\" key and an explicit null\n",
    "    eligibility_criteria = extracted_data[\"eligibility\"] or []\n",
    "    for index, criterion in enumerate(eligibility_criteria):\n",
    "        # A criterion may carry no entities at all; treat that as an empty list\n",
    "        entities = criterion.get(\"entities_data\") or []\n",
    "        metadata = {\n",
    "            \"nct_id\": extracted_data['nct_id'],\n",
    "            \"idx\": index + 1,\n",
    "        }\n",
    "        # The criteria type is read from the first entity (assumed identical for all entities).\n",
    "        # Without a usable value the key is omitted; pretty_print_docs then shows 'N/A'.\n",
    "        if entities and entities[0].get(\"field\") is not None:\n",
    "            metadata[\"criteria_type\"] = entities[0][\"field\"]\n",
    "        # Flatten every entity attribute into a suffixed key: <key>_<1-based entity position>\n",
    "        for i, entity in enumerate(entities):\n",
    "            for key, value in entity.items():\n",
    "                if key != \"field\":\n",
    "                    metadata[f\"{key}_{i + 1}\"] = value\n",
    "\n",
    "        # One document per criterion\n",
    "        docs.append(Document(page_content=criterion[\"text\"], metadata=metadata))\n",
    "\n",
    "vectorstore = Chroma.from_documents(docs, SentenceTransformerEmbeddings(), persist_directory=\"../../data/db/\", collection_name=\"criteria\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Persist the vector store (presumably to the persist_directory given at creation), then drop the handle.\n",
    "# NOTE(review): newer Chroma releases persist automatically and may not expose persist() - confirm against the installed version.\n",
    "vectorstore.persist()\n",
    "vectorstore = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "import os\n",
    "from langchain.docstore.document import Document\n",
    "\n",
    "# Build one Document per trial: title + summary as text, selected fields as metadata.\n",
    "\n",
    "# Directory holding one JSON file per clinical trial\n",
    "json_directory = '../../data/trials_jsons/'\n",
    "desired_fields = [\"nct_id\", \"brief_title\", \"brief_summary\", \"condition\", \"gender\", \"minimum_age\", \"maximum_age\", \"phase\"]\n",
    "fields_to_concatenate = [\"brief_title\", \"brief_summary\"]\n",
    "\n",
    "# docs and ids are filled in lockstep: ids[n] is the nct_id of docs[n]\n",
    "docs = []\n",
    "ids = []\n",
    "\n",
    "# sorted() makes the document order reproducible across runs and platforms\n",
    "for filename in sorted(os.listdir(json_directory)):\n",
    "    if not filename.endswith('.json'):\n",
    "        continue\n",
    "    file_path = os.path.join(json_directory, filename)\n",
    "\n",
    "    # Explicit encoding so parsing does not depend on the platform's default locale\n",
    "    with open(file_path, 'r', encoding='utf-8') as file:\n",
    "        json_data = json.load(file)\n",
    "    # Every desired field gets a key; fields absent from the file map to None\n",
    "    extracted_data = {field: json_data.get(field) for field in desired_fields}\n",
    "    ids.append(extracted_data[\"nct_id\"])\n",
    "\n",
    "    metadata = {\n",
    "        \"id\": extracted_data[\"nct_id\"],\n",
    "        \"condition\": extracted_data[\"condition\"],\n",
    "        \"gender\": extracted_data[\"gender\"],\n",
    "        \"minimum_age\": extracted_data[\"minimum_age\"],\n",
    "        \"maximum_age\": extracted_data[\"maximum_age\"],\n",
    "        \"phase\": extracted_data[\"phase\"],\n",
    "    }\n",
    "    # Drop the fields this trial does not define\n",
    "    metadata = {k: v for k, v in metadata.items() if v is not None}\n",
    "\n",
    "    # Join only the text fields that are present, so a missing field does not put\n",
    "    # the literal string \"None\" into the text that gets embedded\n",
    "    concatenated_string = ', '.join(\n",
    "        str(extracted_data[field])\n",
    "        for field in fields_to_concatenate\n",
    "        if extracted_data[field] is not None\n",
    "    )\n",
    "    docs.append(Document(page_content=concatenated_string, metadata=metadata))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Embed the trial-level documents and store them in the \"trials\" collection under persist_directory.\n",
    "# NOTE(review): the `ids` list built while loading the trials is not passed here - confirm whether ids=ids was intended.\n",
    "vectorstore = Chroma.from_documents(docs, SentenceTransformerEmbeddings(), persist_directory=\"../../data/db\", collection_name=\"trials\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-open the persisted \"criteria\" collection and expose it as a retriever.\n",
    "db3 = Chroma(persist_directory=\"../../data/db\", embedding_function= SentenceTransformerEmbeddings(), collection_name=\"criteria\")\n",
    "# Thresholded similarity search: search_kwargs request up to k=1500 results with score_threshold=0.5.\n",
    "retriever = db3.as_retriever(\n",
    "    search_type=\"similarity_score_threshold\", \n",
    "    search_kwargs={\"score_threshold\": 0.5, \"k\":1500},\n",
    "    # NOTE(review): `filters` does not look like a documented as_retriever argument; metadata filters normally go in search_kwargs[\"filter\"] - confirm.\n",
    "    filters=None,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import List, Optional\n",
    "\n",
    "from langchain_core.pydantic_v1 import BaseModel, Field\n",
    "\n",
    "\n",
    "# Schema the LLM is asked to fill in (passed to with_structured_output and PydanticToolsParser).\n",
    "# NOTE(review): the docstring and the Field description are presumably forwarded to the model as the\n",
    "# tool description, so treat them as prompt text rather than plain documentation - confirm.\n",
    "class Search(BaseModel):\n",
    "    \"\"\"Search over a database of clinical trial eligibility criteria records\"\"\"\n",
    "\n",
    "    # Required field (`...` means no default): one string per sub-query.\n",
    "    queries: List[str] = Field(\n",
    "        ...,\n",
    "        description=\"Distinct queries to search for\",\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.output_parsers.openai_tools import PydanticToolsParser\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.runnables import RunnablePassthrough\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "# Parser for tool-call output built around the Search schema.\n",
    "# NOTE(review): not referenced by query_analyzer below, which relies on with_structured_output instead.\n",
    "output_parser = PydanticToolsParser(tools=[Search])\n",
    "\n",
    "# System prompt: tells the model to split a structured (JSON) patient query into standalone sub-queries.\n",
    "system = \"\"\"\n",
    "You are tasked with a critical role: to dissect a complex, structured query into its component sub-queries. Each component of the query is encapsulated in a JSON dictionary, representing a unique aspect of the information sought. Your objective is to meticulously parse this JSON, isolating each field as a standalone sub-query. These sub-queries are the keys to unlocking detailed, specific information pertinent to each field.\n",
    "\n",
    "As you embark on this task, remember:\n",
    "- Treat each JSON field with precision, extracting it as an individual query without altering its essence.\n",
    "- Your analysis should preserve the integrity of each sub-query, ensuring that the original context and purpose remain intact.\n",
    "- Enhance each sub-query by contextually expanding it into a complete, meaningful sentence. The aim is to transform each piece of data into a narrative that provides insight into the patient's health condition or medical history.\n",
    "- Approach this task with the understanding that the fidelity of the sub-queries to their source is paramount. Alterations or misinterpretations could lead to inaccuracies in the information retrieved.\n",
    "\n",
    "This meticulous separation of the structured query into clear, unmodified sub-queries is fundamental. It enables a tailored search for information, enhancing the relevance and accuracy of the responses generated. Your role in this process is not just to parse data, but to ensure that each piece of information extracted is a faithful reflection of the query's intent, ready to be matched with precise and relevant data points.\n",
    "\"\"\"\n",
    "# Two-message chat template; the caller's input is substituted for {question}.\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", system),\n",
    "        (\"human\", \"{question}\"),\n",
    "    ]\n",
    ")\n",
    "# The key is passed explicitly because it is stored under the OPENAI_ACCESS_KEY variable.\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0, api_key=openai_access_key)\n",
    "# Ask the model to answer in the shape of the Search schema.\n",
    "structured_llm = llm.with_structured_output(Search)\n",
    "# Pipeline: raw input -> {\"question\": input} -> prompt -> structured LLM call.\n",
    "query_analyzer = {\"question\": RunnablePassthrough()} | prompt | structured_llm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_core.runnables import chain\n",
    "from langchain.retrievers import ContextualCompressionRetriever\n",
    "from langchain.retrievers.document_compressors import CrossEncoderReranker\n",
    "from langchain_community.cross_encoders import HuggingFaceCrossEncoder\n",
    "\n",
    "# Cross-encoder reranker configured with top_n=3, layered on top of the base retriever.\n",
    "# NOTE(review): the names `model`, `compressor` and `compression_retriever` are rebound by later cells,\n",
    "# so running cells out of order changes which reranker the chains use.\n",
    "model = HuggingFaceCrossEncoder(model_name=\"BAAI/bge-reranker-base\")\n",
    "compressor = CrossEncoderReranker(model=model, top_n=3)\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "\n",
    "@chain\n",
    "async def custom_chain(question):\n",
    "    \"\"\"Split `question` into sub-queries and return the retrieved documents of all of them, in query order.\"\"\"\n",
    "    analysis = await query_analyzer.ainvoke(question)\n",
    "    gathered = []\n",
    "    # Sub-queries are retrieved one after another; the combined list is neither\n",
    "    # deduplicated nor re-ranked across sub-queries.\n",
    "    for sub_query in analysis.queries:\n",
    "        gathered += await compression_retriever.ainvoke(sub_query)\n",
    "    return gathered"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "# Load the synthetic patient record; the context manager closes the file handle once it has been parsed.\n",
    "with open(\"../../data/synthetic_patients/1234.json\", encoding=\"utf-8\") as f:\n",
    "    # The query text is the str() form of the parsed record.\n",
    "    query = str(json.load(f))\n",
    "docs_result = await custom_chain.ainvoke(query)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import openai\n",
    "\n",
    "openai.api_key = openai_access_key\n",
    "\n",
    "def rerank_answers(question, candidate_answers, model=\"gpt-3.5-turbo-instruct\"):\n",
    "    \"\"\"Score each candidate answer for relevance to `question` with an OpenAI completion model.\n",
    "\n",
    "    Args:\n",
    "        question: Text the answers are judged against.\n",
    "        candidate_answers: Iterable of answers; each one is interpolated into the prompt via its str() form.\n",
    "        model: Completion model to query (the previously used \"text-davinci-002\" has been retired by OpenAI).\n",
    "\n",
    "    Returns:\n",
    "        List of (answer, score) tuples sorted by score, highest first. A reply that\n",
    "        contains no integer is scored 0 instead of raising.\n",
    "    \"\"\"\n",
    "    import re  # local import keeps this cell self-contained\n",
    "\n",
    "    scored_answers = []\n",
    "\n",
    "    for answer in candidate_answers:\n",
    "        prompt = f\"Question: {question}\\nAnswer: {answer}\\n\\nHow relevant and correct is this answer to the question above? Rate from 1 (least relevant) to 10 (most relevant).\"\n",
    "        # openai>=1.0 interface (the version langchain_openai depends on): `completions.create` with `model=`\n",
    "        response = openai.completions.create(\n",
    "            model=model,\n",
    "            prompt=prompt,\n",
    "            max_tokens=3,  # We only need a short numeric response\n",
    "            n=1,\n",
    "        )\n",
    "        reply = response.choices[0].text\n",
    "        # The model may answer \"8/10\", \"8.\" or plain prose; take the first integer if there is one\n",
    "        match = re.search(r\"\\d+\", reply)\n",
    "        score = int(match.group()) if match else 0\n",
    "        scored_answers.append((answer, score))\n",
    "\n",
    "    # Sort answers based on the score\n",
    "    scored_answers.sort(key=lambda x: x[1], reverse=True)  # Higher scores first\n",
    "\n",
    "    return scored_answers\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Re-score the first three retrieved criteria against the full patient query.\n",
    "question = query\n",
    "candidate_answers = docs_result[0:3]\n",
    "\n",
    "scored_answers = rerank_answers(question, candidate_answers)\n",
    "# rerank_answers returns (answer, score) pairs already sorted by score, highest first.\n",
    "for answer, score in scored_answers:\n",
    "    print(f\"Score: {score}, Answer: {answer}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect what the analyzer alone produces for the patient query (synchronous call).\n",
    "query_analyzer.invoke(query) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show every retrieved criterion together with its NCT ID and criteria type.\n",
    "pretty_print_docs(docs_result)\n",
    "# docs_result[0].metadata[\"nct_id\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# get a new token: https://dashboard.cohere.ai/\n",
    "\n",
    "import getpass\n",
    "import os\n",
    "\n",
    "# Prompt only when the key is not already in the environment (e.g. loaded from .env),\n",
    "# so Restart & Run All is not blocked by an interactive prompt.\n",
    "if not os.environ.get(\"COHERE_API_KEY\"):\n",
    "    os.environ[\"COHERE_API_KEY\"] = getpass.getpass(\"Cohere API Key:\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.retrievers.contextual_compression import ContextualCompressionRetriever\n",
    "from langchain_cohere import CohereRerank\n",
    "from langchain_community.llms import Cohere\n",
    "\n",
    "# NOTE(review): this rebinds `llm`, and the Cohere LLM is not used by the reranker below - confirm it is still needed.\n",
    "llm = Cohere(temperature=0)\n",
    "# Swap the reranker: `compressor` and `compression_retriever` now refer to the Cohere-backed pipeline.\n",
    "compressor = CohereRerank()\n",
    "compression_retriever = ContextualCompressionRetriever(\n",
    "    base_compressor=compressor, base_retriever=retriever\n",
    ")\n",
    "\n",
    "# Smoke test with a query from the clinical-trial domain, so the thresholded retriever has criteria to match.\n",
    "compressed_docs = compression_retriever.invoke(\n",
    "    \"Patient has a diagnosis of COVID-19\"\n",
    ")\n",
    "pretty_print_docs(compressed_docs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "@chain\n",
    "async def custom_reranker_chain(question):\n",
    "    \"\"\"Decompose `question` with query_analyzer and return, as one flat list, the documents\n",
    "    that compression_retriever yields for each sub-query.\"\"\"\n",
    "    parsed = await query_analyzer.ainvoke(question)\n",
    "    # The await inside the comprehension still handles one sub-query at a time, in order.\n",
    "    return [\n",
    "        hit\n",
    "        for sub_query in parsed.queries\n",
    "        for hit in await compression_retriever.ainvoke(sub_query)\n",
    "    ]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# @chain turns the coroutine into a Runnable, which is executed through .ainvoke() rather than called directly.\n",
    "await custom_reranker_chain.ainvoke(\"Patient has diabetes | Patient has COVID-19 | Colorectal Cancer Patient with KRAS mutation | Patient has a diagnosis of COVID-19\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# NOTE(review): `langchain.aggregators` does not appear to be a published LangChain module and\n",
    "# SimpleAggregator is not used anywhere in this notebook - confirm, as this import may fail on a fresh kernel.\n",
    "from langchain.aggregators import SimpleAggregator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pass the key explicitly, as the query-analyzer cell does: it is stored under OPENAI_ACCESS_KEY.\n",
    "llm = ChatOpenAI(model=\"gpt-3.5-turbo-0125\", temperature=0, api_key=openai_access_key)\n",
    "\n",
    "async def refine_query_with_llm(query):\n",
    "    \"\"\"\n",
    "    Uses an LLM to expand and refine the query to make it more specific to clinical trial criteria.\n",
    "    \n",
    "    Args:\n",
    "    query (str): The initial query to be refined.\n",
    "    \n",
    "    Returns:\n",
    "    str: The refined and expanded query, or the original query when the model returns no text.\n",
    "    \"\"\"\n",
    "    # Define a prompt that instructs the LLM on how to expand the query\n",
    "    prompt = (\n",
    "        f\"Given a patient profile for a clinical trial, refine and expand the following query to be more specific and contextual:\\n\\n\"\n",
    "        f\"Query: {query}\\n\\n\"\n",
    "        \"Refined Query:\"\n",
    "    )\n",
    "    \n",
    "    # ainvoke is the awaitable entry point; the prompt is passed positionally and stop sequences go in `stop`\n",
    "    response = await llm.ainvoke(prompt, max_tokens=100, temperature=0.7, stop=[\"\\n\"])\n",
    "    # A chat model returns a message object; the generated text lives in `.content`\n",
    "    refined_query = str(response.content or \"\").strip()\n",
    "    \n",
    "    # Fallback in case the LLM does not generate a useful output\n",
    "    if not refined_query:\n",
    "        refined_query = query\n",
    "    \n",
    "    return refined_query\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Example call: refine a short clinical phrase; the result is kept in `g`.\n",
    "g = await refine_query_with_llm(\"Stage III colon adenocarcinoma\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sentence_transformers import CrossEncoder\n",
    "# NOTE(review): rebinds `model`, which an earlier cell set to the HuggingFaceCrossEncoder reranker.\n",
    "model = CrossEncoder('cross-encoder/nli-deberta-v3-large')\n",
    "# Score a single sentence pair.\n",
    "scores = model.predict([('The patient has KRAS mutation', 'The man has cancer')])\n",
    "\n",
    "#Convert scores to labels\n",
    "# argmax(axis=1) picks the highest-scoring column for each pair; that column index selects the label.\n",
    "# NOTE(review): the label order must match the model's output columns - verify against the model card.\n",
    "label_mapping = ['disagreement', 'agreement', 'neutral']\n",
    "labels = [label_mapping[score_max] for score_max in scores.argmax(axis=1)]\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/home/mabdallah/miniconda3/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "from crossencoder_reranker import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'{\"scores\": [-2.5739340782165527, 4.65949010848999]}'"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Local check of model_fn / transform_fn: load a model, then score two sentence pairs sent as a JSON request.\n",
    "# NOTE(review): both functions presumably come from the star import of crossencoder_reranker; their contract is not visible here.\n",
    "model = model_fn(\"\")\n",
    "transform_fn(model, \"{\\\"pairs\\\": [[\\\"Patient with KRAS mutation\\\", \\\"Patient without KRAS mutation\\\"], [\\\"Patient with KRAS mutation\\\", \\\"KRAS mutation positive\\\"]]}\", \"application/json\", \"application/json\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}