{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Zero-shot prediction of BRCA1 variant effects with Evo 2\n",
    "Deploy this tutorial on brev.dev: \n",
    "[![ Click here to deploy.](https://brev-assets.s3.us-west-1.amazonaws.com/nv-lb-dark.svg)](https://console.brev.dev/launchable/deploy?launchableID=env-2uGqxNLgVdl752F2qcTFwHHn4Rj)\n",
    "\n",
    "*Note - this notebook reproduces the Arc Institute’s notebook of the same title ([here](https://github.com/ArcInstitute/evo2/blob/main/notebooks/brca1/brca1_zero_shot_vep.ipynb)) using the BioNeMo 2 implementation of Evo 2.*\n",
    "\n",
    "Evo 2 is a foundation AI model trained on 9.3 trillion DNA base pairs spanning all domains of life, and it can predict variant effects without task-specific training.\n",
    "\n",
    "Although Evo 2 was never explicitly trained on *BRCA1* variants, this notebook shows that it can score them zero-shot.\n",
    "\n",
    "The human *BRCA1* gene encodes a protein that repairs damaged DNA ([Moynahan et al., 1999](https://www.cell.com/molecular-cell/fulltext/S1097-2765%2800%2980202-6)). Certain variants of this gene have been associated with an increased risk of breast and ovarian cancers ([Miki et al., 1994](https://www.science.org/doi/10.1126/science.7545954?url_ver=Z39.88-2003&rfr_id=ori:rid:crossref.org&rfr_dat=cr_pub%20%200pubmed)). Using Evo 2, we can predict whether a particular single nucleotide variant (SNV) of the *BRCA1* gene is likely to harm the protein's function, and thus potentially increase the cancer risk of a patient who carries it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install biopython openpyxl\n",
    "import os\n",
    "\n",
    "\n",
    "# Runs a subset of the model layers to test that the notebook runs in CI, but the output will be incorrect.\n",
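    "# Environment variables are strings, so parse the value explicitly (e.g. FAST_CI_MODE=0 leaves it disabled).\n",
    "FAST_CI_MODE: bool = os.environ.get(\"FAST_CI_MODE\", \"false\").strip().lower() in (\"1\", \"true\", \"yes\")"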
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import glob\n",
    "import gzip\n",
    "import json\n",
    "import math\n",
    "import os\n",
    "from pathlib import Path\n",
    "\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import seaborn as sns\n",
    "import torch\n",
    "from Bio import SeqIO\n",
    "from sklearn.metrics import auc, roc_auc_score, roc_curve"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We start by loading a dataset from [Findlay et al. (2018)](https://www.nature.com/articles/s41586-018-0461-z), which contains experimentally measured function scores for 3,893 *BRCA1* SNVs. These scores reflect the extent to which each variant disrupts the protein's function, with lower scores indicating greater disruption. Based on their function scores, the SNVs are classified into three categories: `LOF` (loss-of-function), `INT` (intermediate), and `FUNC` (functional).\n",
    "\n",
    "To keep the notebook streamlined, much of the preprocessing logic lives in helper functions defined in collapsed (hidden) cells. The full notebook can be viewed [here](https://github.com/NVIDIA/bionemo-framework/blob/main/docs/docs/user-guide/examples/bionemo-evo2/evo2_zeroshot_brca.ipynb)."
   ]
  },
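  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As an illustration of how the function scores relate to the classes, the sketch below groups the five sampled rows displayed later in this notebook (after `FUNC` and `INT` have been merged into `FUNC/INT`) by class and compares their mean scores. Five rows are far too few to draw conclusions from; this only shows the expected ordering, with `LOF` variants scoring lower.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# The five sampled rows shown in this notebook's `brca1_df.head(5)` output.\n",
    "rows = pd.DataFrame(\n",
    "    {\n",
    "        \"score\": [0.159762, -2.065569, -0.847753, -2.053739, -1.671525],\n",
    "        \"class\": [\"FUNC/INT\", \"LOF\", \"FUNC/INT\", \"LOF\", \"LOF\"],\n",
    "    }\n",
    ")\n",
    "\n",
    "# Lower function scores mean greater disruption, so LOF rows should average lower.\n",
    "mean_score_by_class = rows.groupby(\"class\")[\"score\"].mean()\n",
    "print(mean_score_by_class)\n",
    "```"
   ]
  },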
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def download_data(data_dir=\"brca1\", commit_hash=\"3819474bee6c24938016614411f1fa025e542bbe\"):\n",
    "    \"\"\"Download required data files if they don't exist locally.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    data_dir : str\n",
    "        Directory to store downloaded files\n",
    "    commit_hash : str\n",
    "        GitHub commit hash for data version\n",
    "    \"\"\"\n",
    "    if not os.path.exists(data_dir):\n",
    "        os.makedirs(data_dir)\n",
    "\n",
    "    excel_path = os.path.join(data_dir, \"41586_2018_461_MOESM3_ESM.xlsx\")\n",
    "    genome_path = os.path.join(data_dir, \"GRCh37.p13_chr17.fna.gz\")\n",
    "\n",
    "    if not os.path.exists(excel_path):\n",
    "        os.system(\n",
    "            f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/41586_2018_461_MOESM3_ESM.xlsx -O {excel_path}\"\n",
    "        )\n",
    "\n",
    "    if not os.path.exists(genome_path):\n",
    "        os.system(\n",
    "            f\"wget https://github.com/ArcInstitute/evo2/raw/{commit_hash}/notebooks/brca1/GRCh37.p13_chr17.fna.gz -O {genome_path}\"\n",
    "        )\n",
    "\n",
    "    return excel_path, genome_path\n",
    "\n",
    "\n",
    "def load_genome_sequence(genome_path):\n",
    "    \"\"\"Load genome sequence from FASTA file.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    genome_path : str\n",
    "        Path to the genome FASTA file\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    str\n",
    "        Genome sequence string\n",
    "    \"\"\"\n",
    "    with gzip.open(genome_path, \"rt\") as handle:\n",
    "        for record in SeqIO.parse(handle, \"fasta\"):\n",
    "            return str(record.seq)\n",
    "\n",
    "    raise ValueError(\"Failed to parse genome sequence\")\n",
    "\n",
    "\n",
    "def load_brca1_data(excel_path):\n",
    "    \"\"\"Load and preprocess BRCA1 data from Excel file.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    excel_path : str\n",
    "        Path to the Excel file\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    pandas.DataFrame\n",
    "        Processed BRCA1 dataframe\n",
    "    \"\"\"\n",
    "    # Load the dataframe\n",
    "    brca1_df = pd.read_excel(excel_path, header=2)\n",
    "\n",
    "    # Select and rename columns\n",
    "    brca1_df = brca1_df[\n",
    "        [\n",
    "            \"chromosome\",\n",
    "            \"position (hg19)\",\n",
    "            \"reference\",\n",
    "            \"alt\",\n",
    "            \"function.score.mean\",\n",
    "            \"func.class\",\n",
    "        ]\n",
    "    ].copy()  # explicit copy avoids pandas' SettingWithCopyWarning on the edits below\n",
    "\n",
    "    brca1_df.rename(\n",
    "        columns={\n",
    "            \"chromosome\": \"chrom\",\n",
    "            \"position (hg19)\": \"pos\",\n",
    "            \"reference\": \"ref\",\n",
    "            \"alt\": \"alt\",\n",
    "            \"function.score.mean\": \"score\",\n",
    "            \"func.class\": \"class\",\n",
    "        },\n",
    "        inplace=True,\n",
    "    )\n",
    "\n",
    "    # Convert to two-class system\n",
    "    brca1_df[\"class\"] = brca1_df[\"class\"].replace([\"FUNC\", \"INT\"], \"FUNC/INT\")\n",
    "\n",
    "    return brca1_df"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "\n",
    "# Configuration parameters\n",
    "DATA_DIR = \"brca1\"\n",
    "SAMPLE_CONFIG = {\"sample_frac\": 0.05, \"balanced\": True, \"disable\": False, \"random_state\": 42}\n",
    "\n",
    "# 1. Download the necessary data files if not present\n",
    "excel_path, genome_path = download_data(DATA_DIR)\n",
    "seq_chr17 = load_genome_sequence(genome_path)\n",
    "\n",
    "# 2. Load and preprocess BRCA1 data\n",
    "brca1_df = load_brca1_data(excel_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Note that `load_brca1_data` also groups the `FUNC` and `INT` classes of SNVs into a single category (`FUNC/INT`), leaving a binary `LOF` vs `FUNC/INT` comparison.\n",
    "\n"
   ]
  },
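  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The merge is a single `replace` call in `load_brca1_data`. A minimal sketch on hypothetical toy labels (not rows from the Findlay et al. table) shows the effect: three classes collapse into two.\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical toy labels using the three Findlay et al. categories.\n",
    "toy = pd.DataFrame({\"class\": [\"FUNC\", \"INT\", \"LOF\", \"FUNC\", \"LOF\"]})\n",
    "\n",
    "# Same call as in `load_brca1_data`: relabel FUNC and INT as FUNC/INT.\n",
    "toy[\"class\"] = toy[\"class\"].replace([\"FUNC\", \"INT\"], \"FUNC/INT\")\n",
    "class_counts = toy[\"class\"].value_counts()\n",
    "print(class_counts)\n",
    "```"
   ]
  },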
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We also build a function (`parse_sequences`, below) that extracts the reference and variant sequences of an 8,192-bp window around the genomic position of each SNV, using the reference sequence of human chromosome 17, where *BRCA1* is located.\n",
    "\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make things run faster, we look at only a balanced sample of the data. To run on the full dataset, set `\"disable\": True` in `SAMPLE_CONFIG`."
   ]
  },
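  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "In balanced mode, `sample_data` draws `ceil(sample_frac * n_LOF)` rows from each class, so the sample size is set by the minority `LOF` class rather than by the full table. A sketch of that logic on hypothetical toy data (10 `LOF` and 30 `FUNC/INT` rows, `sample_frac=0.25`):\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import pandas as pd\n",
    "\n",
    "# Hypothetical imbalanced toy data: 10 LOF rows and 30 FUNC/INT rows.\n",
    "df = pd.DataFrame({\"class\": [\"LOF\"] * 10 + [\"FUNC/INT\"] * 30})\n",
    "sample_frac = 0.25\n",
    "\n",
    "# As in `sample_data`: size both classes from the minority (LOF) class.\n",
    "n_per_class = math.ceil((df[\"class\"] == \"LOF\").sum() * sample_frac)\n",
    "balanced = (\n",
    "    pd.concat(\n",
    "        [\n",
    "            df[df[\"class\"] == \"LOF\"].sample(n=n_per_class, random_state=42),\n",
    "            df[df[\"class\"] == \"FUNC/INT\"].sample(n=n_per_class, random_state=42),\n",
    "        ]\n",
    "    )\n",
    "    .sample(frac=1.0, random_state=42)  # shuffle the combined rows\n",
    "    .reset_index(drop=True)\n",
    ")\n",
    "\n",
    "print(balanced[\"class\"].value_counts())\n",
    "```"
   ]
  },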
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def sample_data(df, sample_frac=1.0, balanced=True, disable=False, random_state=42):\n",
    "    \"\"\"Sample dataframe, optionally with balanced classes.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    df : pandas.DataFrame\n",
    "        Input dataframe\n",
    "    sample_frac : float\n",
    "        Fraction of data to sample\n",
    "    balanced : bool\n",
    "        Whether to balance classes\n",
    "    disable : bool\n",
    "        Whether to disable sampling\n",
    "    random_state : int\n",
    "        Random seed for reproducibility\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    pandas.DataFrame\n",
    "        Sampled dataframe\n",
    "    \"\"\"\n",
    "    if disable:\n",
    "        return df\n",
    "\n",
    "    if balanced:\n",
    "        # Get the number of rows in the dataframe\n",
    "        num_rows_minor_class = math.ceil(len(df[df[\"class\"] == \"LOF\"]) * sample_frac)\n",
    "        return (\n",
    "            pd.concat(\n",
    "                [\n",
    "                    df[df[\"class\"] == \"LOF\"].sample(n=num_rows_minor_class, random_state=random_state),\n",
    "                    df[df[\"class\"] == \"FUNC/INT\"].sample(n=num_rows_minor_class, random_state=random_state),\n",
    "                ]\n",
    "            )\n",
    "            .sample(frac=1.0, random_state=random_state)\n",
    "            .reset_index(drop=True)\n",
    "        )\n",
    "    else:\n",
    "        # Calculate the number of rows to sample\n",
    "        return df.sample(frac=sample_frac, random_state=random_state).reset_index(drop=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>chrom</th>\n",
       "      <th>pos</th>\n",
       "      <th>ref</th>\n",
       "      <th>alt</th>\n",
       "      <th>score</th>\n",
       "      <th>class</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>17</td>\n",
       "      <td>41199726</td>\n",
       "      <td>T</td>\n",
       "      <td>C</td>\n",
       "      <td>0.159762</td>\n",
       "      <td>FUNC/INT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>17</td>\n",
       "      <td>41209074</td>\n",
       "      <td>T</td>\n",
       "      <td>A</td>\n",
       "      <td>-2.065569</td>\n",
       "      <td>LOF</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>17</td>\n",
       "      <td>41256913</td>\n",
       "      <td>A</td>\n",
       "      <td>C</td>\n",
       "      <td>-0.847753</td>\n",
       "      <td>FUNC/INT</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>17</td>\n",
       "      <td>41219631</td>\n",
       "      <td>T</td>\n",
       "      <td>A</td>\n",
       "      <td>-2.053739</td>\n",
       "      <td>LOF</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17</td>\n",
       "      <td>41215965</td>\n",
       "      <td>G</td>\n",
       "      <td>A</td>\n",
       "      <td>-1.671525</td>\n",
       "      <td>LOF</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   chrom       pos ref alt     score     class\n",
       "0     17  41199726   T   C  0.159762  FUNC/INT\n",
       "1     17  41209074   T   A -2.065569       LOF\n",
       "2     17  41256913   A   C -0.847753  FUNC/INT\n",
       "3     17  41219631   T   A -2.053739       LOF\n",
       "4     17  41215965   G   A -1.671525       LOF"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "OUTPUT_DIR = \"brca1_fasta_files\"\n",
    "\n",
    "brca1_df = sample_data(\n",
    "    brca1_df,\n",
    "    sample_frac=SAMPLE_CONFIG[\"sample_frac\"],\n",
    "    balanced=SAMPLE_CONFIG[\"balanced\"],\n",
    "    disable=SAMPLE_CONFIG[\"disable\"],\n",
    "    random_state=SAMPLE_CONFIG[\"random_state\"],\n",
    ")\n",
    "\n",
    "brca1_df.head(5)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we write the unique reference and variant sequences to local FASTA files so that Evo 2 can score them below."
   ]
  },
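  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`generate_fasta_files` deduplicates reference windows: SNVs at the same position share one reference sequence, which is why the output below reports fewer unique reference sequences (79) than variant sequences (84). A minimal sketch of that bookkeeping with hypothetical toy sequences (not real *BRCA1* windows):\n",
    "\n",
    "```python\n",
    "# Hypothetical SNVs: two alternate bases at the same position share one reference\n",
    "# window, so they yield two variant records but only one reference record.\n",
    "variants = [\n",
    "    {\"pos\": 100, \"ref\": \"A\", \"alt\": \"G\", \"ref_seq\": \"CCACC\"},\n",
    "    {\"pos\": 100, \"ref\": \"A\", \"alt\": \"T\", \"ref_seq\": \"CCACC\"},\n",
    "    {\"pos\": 200, \"ref\": \"C\", \"alt\": \"A\", \"ref_seq\": \"GGCGG\"},\n",
    "]\n",
    "\n",
    "ref_records = {}  # reference sequence -> record name (first occurrence wins)\n",
    "var_names = []\n",
    "for v in variants:\n",
    "    ref_records.setdefault(v[\"ref_seq\"], f\"ref_pos_{v['pos']}_{v['ref']}\")\n",
    "    var_names.append(f\"var_pos_{v['pos']}_{v['ref']}to{v['alt']}\")\n",
    "\n",
    "# A FASTA record is a \">\" header line followed by the sequence on the next line.\n",
    "ref_fasta = \"\".join(f\">{name}\\n{seq}\\n\" for seq, name in ref_records.items())\n",
    "print(ref_fasta)\n",
    "print(f\"{len(ref_records)} reference records, {len(var_names)} variant records\")\n",
    "```"
   ]
  },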
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def parse_sequences(pos, ref, alt, seq_chr17, window_size=8192):\n",
    "    \"\"\"Parse reference and variant sequences from the reference genome sequence.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    pos : int\n",
    "        Position (1-indexed)\n",
    "    ref : str\n",
    "        Reference base\n",
    "    alt : str\n",
    "        Alternate base\n",
    "    seq_chr17 : str\n",
    "        Full chromosome 17 sequence\n",
    "    window_size : int\n",
    "        Size of the sequence window to extract\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    tuple\n",
    "        (reference_sequence, variant_sequence)\n",
    "    \"\"\"\n",
    "    p = pos - 1  # Convert to 0-indexed position\n",
    "    full_seq = seq_chr17\n",
    "\n",
    "    ref_seq_start = max(0, p - window_size // 2)\n",
    "    ref_seq_end = min(len(full_seq), p + window_size // 2)\n",
    "    ref_seq = seq_chr17[ref_seq_start:ref_seq_end]\n",
    "    snv_pos_in_ref = min(window_size // 2, p)\n",
    "    var_seq = ref_seq[:snv_pos_in_ref] + alt + ref_seq[snv_pos_in_ref + 1 :]\n",
    "\n",
    "    # Sanity checks\n",
    "    assert len(var_seq) == len(ref_seq)\n",
    "    assert ref_seq[snv_pos_in_ref] == ref\n",
    "    assert var_seq[snv_pos_in_ref] == alt\n",
    "\n",
    "    return ref_seq, var_seq\n",
    "\n",
    "\n",
    "def generate_fasta_files(df, seq_chr17, output_dir=\"brca1_fasta_files\", window_size=8192):\n",
    "    \"\"\"Generate FASTA files for reference and variant sequences.\n",
    "\n",
    "    Parameters:\n",
    "    -----------\n",
    "    df : pandas.DataFrame\n",
    "        Dataframe with variant information\n",
    "    seq_chr17 : str\n",
    "        Chromosome 17 sequence\n",
    "    output_dir : str\n",
    "        Output directory for FASTA files\n",
    "    window_size : int\n",
    "        Size of sequence window\n",
    "\n",
    "    Returns:\n",
    "    --------\n",
    "    pandas.DataFrame\n",
    "        Dataframe with added columns for FASTA names\n",
    "    \"\"\"\n",
    "    # Create output directory\n",
    "    output_dir = Path(output_dir)\n",
    "    output_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "    # Paths for output files\n",
    "    ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n",
    "    var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n",
    "\n",
    "    # Track unique sequences\n",
    "    ref_sequences = set()\n",
    "    var_sequences = set()\n",
    "    ref_seq_to_name = {}\n",
    "\n",
    "    # Store unique sequences with metadata for writing\n",
    "    ref_entries = []\n",
    "    var_entries = []\n",
    "    ref_names = []\n",
    "    var_names = []\n",
    "\n",
    "    # Collect unique reference and variant sequences\n",
    "    for _, row in df.iterrows():\n",
    "        ref_seq, var_seq = parse_sequences(row[\"pos\"], row[\"ref\"], row[\"alt\"], seq_chr17, window_size)\n",
    "\n",
    "        # Add to sets to ensure uniqueness\n",
    "        if ref_seq not in ref_sequences:\n",
    "            ref_sequences.add(ref_seq)\n",
    "            ref_name = f\"BRCA1_ref_pos_{row['pos']}_{row['ref']}_class_{row['class']}\"\n",
    "\n",
    "            ref_entries.append(f\">{ref_name}\\n{ref_seq}\\n\")\n",
    "            ref_names.append(ref_name)\n",
    "            ref_seq_to_name[ref_seq] = ref_name\n",
    "        else:\n",
    "            ref_name = ref_seq_to_name[ref_seq]\n",
    "            ref_names.append(ref_name)\n",
    "\n",
    "        if var_seq not in var_sequences:\n",
    "            var_sequences.add(var_seq)\n",
    "            var_name = f\"BRCA1_var_pos_{row['pos']}_{row['ref']}to{row['alt']}_class_{row['class']}\"\n",
    "\n",
    "            var_entries.append(f\">{var_name}\\n{var_seq}\\n\")\n",
    "            var_names.append(var_name)\n",
    "        else:\n",
    "            raise ValueError(f\"Duplicate variant sequence at position {row['pos']}\")\n",
    "\n",
    "    # Write unique sequences to FASTA files\n",
    "    with open(ref_fasta_path, \"w\") as f:\n",
    "        f.writelines(ref_entries)\n",
    "\n",
    "    with open(var_fasta_path, \"w\") as f:\n",
    "        f.writelines(var_entries)\n",
    "\n",
    "    # Add FASTA names to dataframe\n",
    "    df_with_names = df.copy()\n",
    "    df_with_names[\"ref_fasta_name\"] = ref_names\n",
    "    df_with_names[\"var_fasta_name\"] = var_names\n",
    "\n",
    "    print(f\"Total unique reference sequences: {len(ref_sequences)}\")\n",
    "    print(f\"Total unique variant sequences: {len(var_sequences)}\")\n",
    "\n",
    "    return df_with_names"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Total unique reference sequences: 79\n",
      "Total unique variant sequences: 84\n"
     ]
    }
   ],
   "source": [
    "brca1_df = generate_fasta_files(brca1_df, seq_chr17, output_dir=OUTPUT_DIR)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Load Evo 2 Checkpoints"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "\n",
    "Then, we load the Evo 2 1B model. By default, the Evo 2 weights are downloaded from Hugging Face and converted into a NeMo 2 checkpoint.\n",
    "\n",
    "*Note - for better performance, load the 7B model by setting `MODEL_SIZE=\"7b\"`, which also works well on GPUs that do not support FP8.*\n",
    "\n"
   ]
  },
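  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Whether a GPU supports FP8 comes down to its CUDA compute capability: the `check_fp8_support` helper defined below treats 8.9 (Ada Lovelace) and newer as FP8-capable. The same rule can be written as a tuple comparison, sketched here with a hypothetical `fp8_capable` helper; 8.6 is the RTX A6000 compute capability reported later in this notebook.\n",
    "\n",
    "```python\n",
    "# Hypothetical helper mirroring the check in `check_fp8_support`. Comparing\n",
    "# (major, minor) tuples is equivalent to `major > 8 or (major == 8 and minor >= 9)`.\n",
    "def fp8_capable(major: int, minor: int) -> bool:\n",
    "    return (major, minor) >= (8, 9)\n",
    "\n",
    "\n",
    "for cap in [(8, 6), (8, 9), (9, 0)]:\n",
    "    print(cap, fp8_capable(*cap))\n",
    "```"
   ]
  },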
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "EXPERIMENTAL_1b_CHECKPOINT: bool = False\n",
    "MODEL_SIZE = \"1b\"  # also try 7b if you have a GPU with more than 32GB of memory\n",
    "# Define checkpoint path\n",
    "if EXPERIMENTAL_1b_CHECKPOINT and MODEL_SIZE == \"1b\":\n",
    "    from bionemo.core.data.load import load\n",
    "\n",
    "    # This is a new 1b checkpoint that has been fine-tuned on BF16 hardware. It should be able to handle FP8 as well.\n",
    "    #  this line will download the checkpoint from NGC to your $HOME/.cache/bionemo directory and return the path.\n",
    "    #  alternatively you can use `CHECKPOINT_PATH=$(download_bionemo_data evo2/1b-8k-bf16:1.0)`  to do the same on the\n",
    "    #  command line.\n",
    "    checkpoint_path = load(\"evo2/1b-8k-bf16:1.0\")\n",
    "else:\n",
    "    checkpoint_path = Path(f\"nemo2_evo2_{MODEL_SIZE}_8k\")\n",
    "\n",
    "    # Check if the directory does not exist or is empty\n",
    "    if not checkpoint_path.exists() or not any(checkpoint_path.iterdir()):\n",
    "        !evo2_convert_to_nemo2 --model-path hf://arcinstitute/savanna_evo2_{MODEL_SIZE}_base --model-size {MODEL_SIZE} --output-dir nemo2_evo2_{MODEL_SIZE}_8k\n",
    "    else:\n",
    "        print(\"Checkpoint directory is not empty. Skipping command.\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Score Sequences"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next, we use the `predict_evo2` command-line tool to score the log-likelihoods of the reference and variant sequences of each SNV.\n"
   ]
  },
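  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, an autoregressive model such as Evo 2 scores a sequence by the log-probability it assigns to each nucleotide given the preceding ones. The sketch below computes this with PyTorch from hypothetical logits over a hypothetical 4-letter vocabulary (not Evo 2's tokenizer or outputs), reporting both the sum and the per-token mean. Which of the two `predict_evo2 --output-log-prob-seqs` reports is an assumption to check; the values near -1 in the results table further below are consistent with a per-token average over the 8,192-bp windows rather than a sum.\n",
    "\n",
    "```python\n",
    "import math\n",
    "\n",
    "import torch\n",
    "\n",
    "# Hypothetical 4-letter vocabulary and uniform (all-zero) logits standing in\n",
    "# for a language model's outputs over a short DNA sequence.\n",
    "vocab = {\"A\": 0, \"C\": 1, \"G\": 2, \"T\": 3}\n",
    "seq = \"ACGTTGCA\"\n",
    "tokens = torch.tensor([vocab[base] for base in seq])\n",
    "logits = torch.zeros(len(seq), len(vocab))\n",
    "\n",
    "# Autoregressive scoring: the logits at position i predict token i + 1.\n",
    "log_probs = torch.log_softmax(logits[:-1], dim=-1)\n",
    "token_log_probs = log_probs.gather(1, tokens[1:].unsqueeze(1)).squeeze(1)\n",
    "\n",
    "total_log_likelihood = token_log_probs.sum().item()\n",
    "mean_log_likelihood = token_log_probs.mean().item()\n",
    "print(total_log_likelihood, mean_log_likelihood, -math.log(4))\n",
    "```"
   ]
  },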
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def check_fp8_support():\n",
    "    \"\"\"Check if FP8 is supported on the current GPU.\n",
    "\n",
    "    FP8 requires compute capability 8.9+ (Ada Lovelace/Hopper architecture or newer).\n",
    "    \"\"\"\n",
    "    if not torch.cuda.is_available():\n",
    "        return False, \"CUDA not available\"\n",
    "\n",
    "    device_props = torch.cuda.get_device_properties(0)\n",
    "    compute_capability = f\"{device_props.major}.{device_props.minor}\"\n",
    "    device_name = device_props.name\n",
    "\n",
    "    # FP8 is supported on compute capability 8.9+ (Ada Lovelace/Hopper architecture)\n",
    "    is_supported = (device_props.major > 8) or (device_props.major == 8 and device_props.minor >= 9)\n",
    "\n",
    "    return is_supported, f\"Device: {device_name}, Compute Capability: {compute_capability}\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "FP8 Support: False\n",
      "Device: NVIDIA RTX A6000, Compute Capability: 8.6\n"
     ]
    }
   ],
   "source": [
    "# Define output directories for prediction results\n",
    "output_dir = Path(\"brca1_fasta_files\")\n",
    "output_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "# Paths to the reference and variant FASTA files written by generate_fasta_files\n",
    "ref_fasta_path = output_dir / \"brca1_reference_sequences.fasta\"\n",
    "var_fasta_path = output_dir / \"brca1_variant_sequences.fasta\"\n",
    "\n",
    "predict_ref_dir = output_dir / \"reference_predictions\"\n",
    "predict_var_dir = output_dir / \"variant_predictions\"\n",
    "predict_ref_dir.mkdir(parents=True, exist_ok=True)\n",
    "predict_var_dir.mkdir(parents=True, exist_ok=True)\n",
    "\n",
    "fp8_supported, gpu_info = check_fp8_support()\n",
    "print(f\"FP8 Support: {fp8_supported}\")\n",
    "print(gpu_info)\n",
    "\n",
    "# Note: If FP8 is not supported, you may want to disable it in the model config\n",
    "# The Evo2 config has 'use_fp8_input_projections: True' by default\n",
    "\n",
    "if FAST_CI_MODE:\n",
    "    model_subset_option = \"--num-layers 4 --hybrid-override-pattern SDH*\"\n",
    "else:\n",
    "    model_subset_option = \"\"\n",
    "\n",
    "fp8_option = \"--fp8\" if fp8_supported else \"\"\n",
    "\n",
    "# Build the predict_evo2 commands for the reference and variant FASTA files\n",
    "predict_ref_command = (\n",
    "    f\"predict_evo2 --fasta {ref_fasta_path} --ckpt-dir {checkpoint_path} \"\n",
    "    f\"--output-dir {predict_ref_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} \"\n",
    "    f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n",
    ")\n",
    "\n",
    "predict_var_command = (\n",
    "    f\"predict_evo2 --fasta {var_fasta_path} --ckpt-dir {checkpoint_path} \"\n",
    "    f\"--output-dir {predict_var_dir} --model-size {MODEL_SIZE} --tensor-parallel-size 1 {model_subset_option} \"\n",
    "    f\"--pipeline-model-parallel-size 1 --context-parallel-size 1 --output-log-prob-seqs {fp8_option}\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Score reference sequences:"
   ]
  },
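  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This notebook runs `predict_evo2` through IPython's `!` shell escape. As an alternative, the sketch below assembles the same command as a Python argument list with a hypothetical `build_predict_args` helper that uses only the flags from the command built above. Building a list makes each flag explicit, appends optional flags only when needed, and avoids shell glob expansion of the `SDH*` pattern.\n",
    "\n",
    "```python\n",
    "import shlex\n",
    "\n",
    "\n",
    "def build_predict_args(fasta, ckpt_dir, output_dir, model_size, fp8=False, extra_options=\"\"):\n",
    "    \"\"\"Hypothetical helper: assemble the predict_evo2 arguments used in this notebook.\"\"\"\n",
    "    args = [\n",
    "        \"predict_evo2\",\n",
    "        \"--fasta\", str(fasta),\n",
    "        \"--ckpt-dir\", str(ckpt_dir),\n",
    "        \"--output-dir\", str(output_dir),\n",
    "        \"--model-size\", model_size,\n",
    "        \"--tensor-parallel-size\", \"1\",\n",
    "        \"--pipeline-model-parallel-size\", \"1\",\n",
    "        \"--context-parallel-size\", \"1\",\n",
    "        \"--output-log-prob-seqs\",\n",
    "    ]\n",
    "    # Optional flags, e.g. the FAST_CI_MODE layer-subset options; an empty string adds nothing.\n",
    "    args += shlex.split(extra_options)\n",
    "    if fp8:\n",
    "        args.append(\"--fp8\")\n",
    "    return args\n",
    "\n",
    "\n",
    "args = build_predict_args(\"refs.fasta\", \"nemo2_evo2_1b_8k\", \"preds\", \"1b\")\n",
    "print(shlex.join(args))\n",
    "```\n",
    "\n",
    "Passing `args` to `subprocess.run(args, check=True)` would raise `CalledProcessError` on a non-zero exit code, whereas a failure inside a `%%capture` cell is easy to miss."
   ]
  },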
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "print(f\"Running command: {predict_ref_command}\")\n",
    "!{predict_ref_command}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Score variant sequences:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "%%capture\n",
    "print(f\"Running command: {predict_var_command}\")\n",
    "!{predict_var_command}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We then calculate the change in log-likelihood for each variant relative to the log-likelihood of its respective wild-type (reference) sequence.\n",
    "\n"
   ]
  },
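  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For a single SNV, the delta is simply the variant log-likelihood minus the reference log-likelihood. A worked example using the values reported for the first row of the results table further below (chr17:41199726 T>C, `FUNC/INT`):\n",
    "\n",
    "```python\n",
    "# Reference and variant log-likelihoods for the first SNV in the results table.\n",
    "ref_log_prob = -0.952916\n",
    "var_log_prob = -0.953258\n",
    "\n",
    "# Delta = variant - reference; negative means the variant is less likely than wild type.\n",
    "delta = var_log_prob - ref_log_prob\n",
    "print(f\"{delta:.6f}\")\n",
    "```"
   ]
  },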
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "First, we load the prediction files and sequence id maps:"
   ]
  },
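  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Each prediction file holds one score per FASTA record under the `log_probs_seqs` key, and `seq_idx_map.json` maps each record name to its index in that tensor. The per-row lookup in the following cells can also be vectorized; a sketch with hypothetical record names and values (not real predictions):\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "import torch\n",
    "\n",
    "# Hypothetical stand-ins for a loaded predictions file and its seq_idx_map.json.\n",
    "preds = {\"log_probs_seqs\": torch.tensor([-0.80, -0.95, -0.75])}\n",
    "seq_idx_map = {\"seq_b\": 0, \"seq_a\": 1, \"seq_c\": 2}\n",
    "\n",
    "df = pd.DataFrame({\"fasta_name\": [\"seq_a\", \"seq_c\", \"seq_a\"]})\n",
    "\n",
    "# Map record names to tensor indices, then gather all matching scores in one step.\n",
    "idx = torch.tensor(df[\"fasta_name\"].map(seq_idx_map).to_numpy())\n",
    "df[\"log_probs\"] = preds[\"log_probs_seqs\"][idx].tolist()\n",
    "print(df)\n",
    "```"
   ]
  },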
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Find and load prediction files\n",
    "ref_pred_files = glob.glob(os.path.join(predict_ref_dir, \"predictions__rank_*.pt\"))\n",
    "var_pred_files = glob.glob(os.path.join(predict_var_dir, \"predictions__rank_*.pt\"))\n",
    "\n",
    "# Load sequence ID maps (maps sequence ID -> prediction index)\n",
    "with open(os.path.join(predict_ref_dir, \"seq_idx_map.json\"), \"r\") as f:\n",
    "    ref_seq_idx_map = json.load(f)\n",
    "with open(os.path.join(predict_var_dir, \"seq_idx_map.json\"), \"r\") as f:\n",
    "    var_seq_idx_map = json.load(f)\n",
    "\n",
    "# Load predictions\n",
    "ref_preds = torch.load(ref_pred_files[0])\n",
    "var_preds = torch.load(var_pred_files[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Then, calculate the delta score:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>chrom</th>\n",
       "      <th>pos</th>\n",
       "      <th>ref</th>\n",
       "      <th>alt</th>\n",
       "      <th>score</th>\n",
       "      <th>class</th>\n",
       "      <th>ref_fasta_name</th>\n",
       "      <th>var_fasta_name</th>\n",
       "      <th>ref_log_probs</th>\n",
       "      <th>var_log_probs</th>\n",
       "      <th>evo2_delta_score</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>17</td>\n",
       "      <td>41199726</td>\n",
       "      <td>T</td>\n",
       "      <td>C</td>\n",
       "      <td>0.159762</td>\n",
       "      <td>FUNC/INT</td>\n",
       "      <td>BRCA1_ref_pos_41199726_T_class_FUNC/INT</td>\n",
       "      <td>BRCA1_var_pos_41199726_TtoC_class_FUNC/INT</td>\n",
       "      <td>-0.952916</td>\n",
       "      <td>-0.953258</td>\n",
       "      <td>-0.000342</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>17</td>\n",
       "      <td>41209074</td>\n",
       "      <td>T</td>\n",
       "      <td>A</td>\n",
       "      <td>-2.065569</td>\n",
       "      <td>LOF</td>\n",
       "      <td>BRCA1_ref_pos_41209074_T_class_LOF</td>\n",
       "      <td>BRCA1_var_pos_41209074_TtoA_class_LOF</td>\n",
       "      <td>-0.750398</td>\n",
       "      <td>-0.750437</td>\n",
       "      <td>-0.000039</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>17</td>\n",
       "      <td>41256913</td>\n",
       "      <td>A</td>\n",
       "      <td>C</td>\n",
       "      <td>-0.847753</td>\n",
       "      <td>FUNC/INT</td>\n",
       "      <td>BRCA1_ref_pos_41256913_A_class_FUNC/INT</td>\n",
       "      <td>BRCA1_var_pos_41256913_AtoC_class_FUNC/INT</td>\n",
       "      <td>-0.798164</td>\n",
       "      <td>-0.799011</td>\n",
       "      <td>-0.000847</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>17</td>\n",
       "      <td>41219631</td>\n",
       "      <td>T</td>\n",
       "      <td>A</td>\n",
       "      <td>-2.053739</td>\n",
       "      <td>LOF</td>\n",
       "      <td>BRCA1_ref_pos_41219631_T_class_LOF</td>\n",
       "      <td>BRCA1_var_pos_41219631_TtoA_class_LOF</td>\n",
       "      <td>-1.032126</td>\n",
       "      <td>-1.032697</td>\n",
       "      <td>-0.000571</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>17</td>\n",
       "      <td>41215965</td>\n",
       "      <td>G</td>\n",
       "      <td>A</td>\n",
       "      <td>-1.671525</td>\n",
       "      <td>LOF</td>\n",
       "      <td>BRCA1_ref_pos_41215965_G_class_LOF</td>\n",
       "      <td>BRCA1_var_pos_41215965_GtoA_class_LOF</td>\n",
       "      <td>-0.860847</td>\n",
       "      <td>-0.861287</td>\n",
       "      <td>-0.000441</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "   chrom       pos ref alt     score     class  \\\n",
       "0     17  41199726   T   C  0.159762  FUNC/INT   \n",
       "1     17  41209074   T   A -2.065569       LOF   \n",
       "2     17  41256913   A   C -0.847753  FUNC/INT   \n",
       "3     17  41219631   T   A -2.053739       LOF   \n",
       "4     17  41215965   G   A -1.671525       LOF   \n",
       "\n",
       "                            ref_fasta_name  \\\n",
       "0  BRCA1_ref_pos_41199726_T_class_FUNC/INT   \n",
       "1       BRCA1_ref_pos_41209074_T_class_LOF   \n",
       "2  BRCA1_ref_pos_41256913_A_class_FUNC/INT   \n",
       "3       BRCA1_ref_pos_41219631_T_class_LOF   \n",
       "4       BRCA1_ref_pos_41215965_G_class_LOF   \n",
       "\n",
       "                               var_fasta_name  ref_log_probs  var_log_probs  \\\n",
       "0  BRCA1_var_pos_41199726_TtoC_class_FUNC/INT      -0.952916      -0.953258   \n",
       "1       BRCA1_var_pos_41209074_TtoA_class_LOF      -0.750398      -0.750437   \n",
       "2  BRCA1_var_pos_41256913_AtoC_class_FUNC/INT      -0.798164      -0.799011   \n",
       "3       BRCA1_var_pos_41219631_TtoA_class_LOF      -1.032126      -1.032697   \n",
       "4       BRCA1_var_pos_41215965_GtoA_class_LOF      -0.860847      -0.861287   \n",
       "\n",
       "   evo2_delta_score  \n",
       "0         -0.000342  \n",
       "1         -0.000039  \n",
       "2         -0.000847  \n",
       "3         -0.000571  \n",
       "4         -0.000441  "
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# next, calculate change in likelihoods\n",
    "ref_log_probs = []\n",
    "var_log_probs = []\n",
    "for _, row in brca1_df.iterrows():\n",
    "    ref_name = row[\"ref_fasta_name\"]\n",
    "    var_name = row[\"var_fasta_name\"]\n",
    "    ref_log_probs.append(ref_preds[\"log_probs_seqs\"][ref_seq_idx_map[ref_name]].item())\n",
    "    var_log_probs.append(var_preds[\"log_probs_seqs\"][var_seq_idx_map[var_name]].item())\n",
    "brca1_df[\"ref_log_probs\"] = ref_log_probs\n",
    "brca1_df[\"var_log_probs\"] = var_log_probs\n",
    "# Ideally a damaging (LOF) variant is less likely under the model than its reference, so var - ref should be negative.\n",
    "brca1_df[\"evo2_delta_score\"] = brca1_df[\"var_log_probs\"] - brca1_df[\"ref_log_probs\"]\n",
    "brca1_df.head()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "This delta log-likelihood score (variant minus reference) should be predictive of how disruptive the SNV is to the protein's function: the lower (more negative) the delta, the more likely the SNV is to be disruptive. We can check this by comparing the distributions of delta scores for the two classes of SNVs (functional/intermediate vs. loss-of-function)."
   ]
  },
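  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Before plotting, the same comparison can be summarized numerically by the median delta score of each class. The cell below is a minimal sketch assuming only pandas: `median_delta_by_class` is a hypothetical helper (not part of BioNeMo or Evo 2), and `toy_df` holds made-up values for illustration, not Evo 2 outputs. If the model separates the classes, the LOF median should be lower (more negative) than the FUNC/INT median; on the real data you could call `median_delta_by_class(brca1_df)`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "\n",
    "def median_delta_by_class(df, score_col=\"evo2_delta_score\", class_col=\"class\"):\n",
    "    \"\"\"Return the median delta score of each class (hypothetical helper).\"\"\"\n",
    "    return df.groupby(class_col)[score_col].median()\n",
    "\n",
    "\n",
    "# Made-up toy values for illustration only; these are not Evo 2 outputs.\n",
    "toy_df = pd.DataFrame(\n",
    "    {\n",
    "        \"class\": [\"FUNC/INT\", \"FUNC/INT\", \"LOF\", \"LOF\"],\n",
    "        \"evo2_delta_score\": [-0.0002, -0.0004, -0.0006, -0.0010],\n",
    "    }\n",
    ")\n",
    "toy_medians = median_delta_by_class(toy_df)\n",
    "# A lower (more negative) LOF median is the pattern expected if Evo 2 separates the classes\n",
    "print(toy_medians)"
   ]
  },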
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def plot_strip_with_means(df, x_col=\"evo2_delta_score\", class_col=\"class\"):\n",
    "    \"\"\"Creates a strip plot with jittered points and median indicators for each class using Seaborn.\n",
    "\n",
    "    Parameters:\n",
    "    - df (pd.DataFrame): The input DataFrame containing data.\n",
    "    - x_col (str): The column name representing the x-axis values (e.g., evo2_delta_score).\n",
    "    - class_col (str): The column name representing the class labels.\n",
    "\n",
    "    Returns:\n",
    "    - None: the strip plot with median indicators is drawn on the current matplotlib figure.\n",
    "    \"\"\"\n",
    "    # NVIDIA theme colors\n",
    "    NVIDIA_GREEN = \"#76B900\"\n",
    "    BACKGROUND_COLOR = \"#F8F8F8\"\n",
    "    GRID_COLOR = \"#DDDDDD\"\n",
    "    FONT_COLOR = \"#333333\"\n",
    "\n",
    "    # Sort the class labels alphabetically for a deterministic order\n",
    "    unique_classes = sorted(df[class_col].unique())\n",
    "\n",
    "    # Reset to the default style before creating the figure so no pre-existing style leaks in\n",
    "    plt.style.use(\"default\")\n",
    "    # Set up the plot with NVIDIA theme\n",
    "    plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n",
    "\n",
    "    # Create strip plot\n",
    "    p = sns.stripplot(\n",
    "        data=df,\n",
    "        x=x_col,\n",
    "        y=class_col,\n",
    "        hue=class_col,\n",
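    "        order=unique_classes,\n",
    "        hue_order=unique_classes,  # map palette colors to classes in sorted order, independent of row order\n",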
    "        palette=[NVIDIA_GREEN, \"red\"],\n",
    "        size=6,\n",
    "        jitter=0.3,\n",
    "        alpha=0.6,\n",
    "    )\n",
    "\n",
    "    # Add median indicators using boxplot\n",
    "    sns.boxplot(\n",
    "        showmeans=True,\n",
    "        meanline=True,\n",
    "        meanprops={\"visible\": False},\n",
    "        medianprops={\"color\": \"black\", \"ls\": \"-\", \"lw\": 2},\n",
    "        whiskerprops={\"visible\": False},\n",
    "        zorder=10,\n",
    "        x=x_col,\n",
    "        y=class_col,\n",
    "        data=df,\n",
    "        order=unique_classes,\n",
    "        showfliers=False,\n",
    "        showbox=False,\n",
    "        showcaps=False,\n",
    "        ax=p,\n",
    "    )\n",
    "\n",
    "    # Customize plot appearance\n",
    "    plt.title(\n",
    "        \"Distribution of Delta Likelihood Scores\\nComparing Evo 2 likelihood scores across BRCA1 SNV classes\",\n",
    "        color=FONT_COLOR,\n",
    "        fontsize=12,\n",
    "        loc=\"left\",\n",
    "    )\n",
    "    plt.xlabel(\"Delta Likelihood Score, Evo 2\", color=FONT_COLOR)\n",
    "    plt.ylabel(\"BRCA1 SNV Class\", color=FONT_COLOR)\n",
    "\n",
    "    # Customize grid and tick colors\n",
    "    plt.grid(color=GRID_COLOR, axis=\"x\", linestyle=\"--\", linewidth=0.5)\n",
    "    plt.tick_params(colors=FONT_COLOR)\n",
    "\n",
    "    # Set background color\n",
    "    plt.gca().set_facecolor(BACKGROUND_COLOR)\n",
    "    plt.gcf().set_facecolor(BACKGROUND_COLOR)\n",
    "\n",
    "    plt.tight_layout()\n",
    "\n",
    "    # return plt.gcf()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 900x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_strip_with_means(brca1_df, x_col=\"evo2_delta_score\", class_col=\"class\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We can also calculate the area under the receiver operating characteristic curve (AUROC) of this zero-shot prediction method. Note that the results are close to random (AUROC near 0.5) unless you are using one of the following configurations:\n",
    "* `--fp8` on an FP8-capable GPU with either the 1B or 7B model. The 40B model likely works as well.\n",
    "* The 7B model appears to work well even without `--fp8`, so on an older GPU without FP8 support it should still produce robust results. In that case, change `MODEL_SIZE` earlier in this tutorial and rerun.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Zero-shot prediction AUROC: 0.77\n"
     ]
    }
   ],
   "source": [
    "# Calculate AUROC of zero-shot predictions\n",
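    "# LOF (the damaging class) is the positive class. LOF variants are expected to have more negative delta\n",
    "# scores, so the score is negated: roc_auc_score treats higher scores as more indicative of the positive class.\n",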
    "y_true = brca1_df[\"class\"] == \"LOF\"\n",
    "auroc = roc_auc_score(y_true, -brca1_df[\"evo2_delta_score\"])\n",
    "print(f\"Zero-shot prediction AUROC: {auroc:.2f}\")"
   ]
  },
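  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The AUROC above has a simple pairwise interpretation: it is the probability that a randomly chosen LOF variant receives a lower (more negative) delta score than a randomly chosen FUNC/INT variant, with ties counted as one half. This is why the delta score is negated before calling `roc_auc_score`, which treats higher scores as more indicative of the positive (LOF) class. The cell below is a minimal sketch of that definition using NumPy broadcasting; `pairwise_auroc` is a hypothetical helper and the toy scores are made up, not Evo 2 outputs. On the toy data, both printed values should agree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import roc_auc_score\n",
    "\n",
    "\n",
    "def pairwise_auroc(delta_scores, is_lof):\n",
    "    \"\"\"AUROC as P(LOF delta < FUNC/INT delta), counting ties as 0.5 (hypothetical helper).\"\"\"\n",
    "    delta_scores = np.asarray(delta_scores, dtype=float)\n",
    "    is_lof = np.asarray(is_lof, dtype=bool)\n",
    "    lof = delta_scores[is_lof]\n",
    "    func = delta_scores[~is_lof]\n",
    "    # Broadcast to compare every LOF score with every FUNC/INT score\n",
    "    lower = (lof[:, None] < func[None, :]).sum()\n",
    "    ties = (lof[:, None] == func[None, :]).sum()\n",
    "    return (lower + 0.5 * ties) / (lof.size * func.size)\n",
    "\n",
    "\n",
    "# Made-up toy values for illustration only; these are not Evo 2 outputs.\n",
    "toy_delta = np.array([-0.0003, -0.00004, -0.0008, -0.0006, -0.0004])\n",
    "toy_is_lof = np.array([False, True, False, True, True])\n",
    "print(pairwise_auroc(toy_delta, toy_is_lof))\n",
    "print(roc_auc_score(toy_is_lof, -toy_delta))"
   ]
  },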
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "jupyter": {
     "source_hidden": true
    },
    "tags": [
     "hide-cell"
    ]
   },
   "outputs": [],
   "source": [
    "def plot_roc_curve(df):\n",
    "    \"\"\"Plots an ROC curve using matplotlib with a light NVIDIA-themed design.\n",
    "\n",
    "    The function assumes:\n",
    "    - `class` column as the true labels (binary, 'LOF' = 1, else 0).\n",
    "    - `evo2_delta_score` as the prediction score.\n",
    "\n",
    "    Parameters:\n",
    "    - df (pd.DataFrame): DataFrame containing `class` and `evo2_delta_score`.\n",
    "\n",
    "    Returns:\n",
    "    - None: the ROC curve is drawn on the current matplotlib figure.\n",
    "    \"\"\"\n",
    "    # NVIDIA theme colors\n",
    "    NVIDIA_GREEN = \"#76B900\"\n",
    "    BACKGROUND_COLOR = \"#F8F8F8\"\n",
    "    GRID_COLOR = \"#DDDDDD\"\n",
    "    FONT_COLOR = \"#333333\"\n",
    "\n",
    "    # Validate required columns\n",
    "    if \"class\" not in df.columns or \"evo2_delta_score\" not in df.columns:\n",
    "        raise ValueError(\"DataFrame must contain 'class' and 'evo2_delta_score' columns.\")\n",
    "\n",
    "    # Convert 'class' to binary labels: Assume 'LOF' = 1, anything else = 0\n",
    "    y_true = (df[\"class\"] == \"LOF\").astype(int)\n",
    "\n",
    "    # Compute ROC curve\n",
    "    # Negate the delta score so that higher values indicate LOF, consistent with the roc_auc_score call\n",
    "    fpr, tpr, _ = roc_curve(y_true, -df[\"evo2_delta_score\"])\n",
    "    roc_auc = auc(fpr, tpr)\n",
    "\n",
    "    # Reset to the default style before creating the figure so no pre-existing style leaks in\n",
    "    plt.style.use(\"default\")\n",
    "    # Set up the plot with NVIDIA theme\n",
    "    plt.figure(figsize=(9, 5), facecolor=BACKGROUND_COLOR)\n",
    "\n",
    "    # Plot ROC curve\n",
    "    plt.plot(fpr, tpr, color=NVIDIA_GREEN, lw=3, label=f\"ROC curve (AUROC = {roc_auc:.2f})\")\n",
    "\n",
    "    # Plot diagonal reference line for random guessing\n",
    "    plt.plot([0, 1], [0, 1], color=\"gray\", lw=2, linestyle=\"--\")\n",
    "\n",
    "    # Customize plot appearance\n",
    "    plt.xlim([0.0, 1.0])\n",
    "    plt.ylim([0.0, 1.05])\n",
    "    plt.xlabel(\"False Positive Rate\", color=FONT_COLOR, fontsize=12)\n",
    "    plt.ylabel(\"True Positive Rate\", color=FONT_COLOR, fontsize=12)\n",
    "    plt.title(\n",
    "        \"Zero-shot ROC Curve\\nEvaluating the discriminative performance of Evo 2 predictions\",\n",
    "        color=FONT_COLOR,\n",
    "        fontsize=16,\n",
    "        loc=\"left\",\n",
    "    )\n",
    "\n",
    "    # Customize grid and tick colors\n",
    "    plt.grid(color=GRID_COLOR, linestyle=\"--\", linewidth=0.5)\n",
    "    plt.tick_params(colors=FONT_COLOR)\n",
    "\n",
    "    # Set background color\n",
    "    plt.gca().set_facecolor(BACKGROUND_COLOR)\n",
    "\n",
    "    # Add legend\n",
    "    plt.legend(loc=\"lower right\", frameon=True, facecolor=BACKGROUND_COLOR, edgecolor=GRID_COLOR)"
   ]
  },
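  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`plot_roc_curve` builds the curve with `roc_curve` and computes its area with `auc`, which integrates the curve using the trapezoidal rule. For the same labels and negated scores, that area equals `roc_auc_score`, so the AUROC in the plot legend should match the value printed earlier. The cell below is a minimal check on made-up toy data (not Evo 2 outputs); the `toy_*` names are hypothetical."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from sklearn.metrics import auc, roc_auc_score, roc_curve\n",
    "\n",
    "# Made-up toy labels (1 = LOF) and delta scores; these are not Evo 2 outputs.\n",
    "toy_labels = np.array([0, 1, 0, 1, 1, 0])\n",
    "toy_delta = np.array([-0.0003, -0.0009, -0.0001, -0.0006, -0.0002, -0.0005])\n",
    "\n",
    "# Negate the delta scores so that higher values indicate LOF, as plot_roc_curve does\n",
    "toy_fpr, toy_tpr, _ = roc_curve(toy_labels, -toy_delta)\n",
    "toy_curve_area = auc(toy_fpr, toy_tpr)\n",
    "toy_auroc = roc_auc_score(toy_labels, -toy_delta)\n",
    "print(f\"auc(fpr, tpr) = {toy_curve_area:.3f}, roc_auc_score = {toy_auroc:.3f}\")"
   ]
  },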
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<Figure size 900x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "plot_roc_curve(brca1_df)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Full Sample Performance\n",
    "\n",
    "The analysis above may have used only a subset of the available BRCA1 variants.\n",
    "\n",
    "For comparison, the table below reports zero-shot AUROC scores for different model sizes evaluated on the *full dataset* (100% sample fraction).\n",
    "\n",
    "| Model Size  | Dataset Sample Fraction | AUROC |\n",
    "|------------|------------------------|-------|\n",
    "| Evo 2 1B   | 100%                   | 0.76  |\n",
    "| Evo 2 7B   | 100%                   | 0.87  |\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}