--- a
+++ b/tests/regression_test.py
@@ -0,0 +1,423 @@
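+"""Regression tests for Sybil risk predictions.
+
+Covers an end-to-end demo-data prediction check, a prediction regression over
+NLST series, and a calibrator comparison. The long-running tests are skipped
+unless the environment variable SYBIL_TEST_RUN_REGRESSION is set to "true".
+Inference runs locally by default, or against an ARK server when
+SYBIL_TEST_USE_ARK=true (host from SYBIL_ARK_HOST, default http://localhost:5000).
+"""
+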
+import datetime
+import io
+import json
+import math
+import os
+import shutil
+import time
+import unittest
+import warnings
+import zipfile
+
+import numpy as np
+import requests
+import tqdm
+
+warnings.filterwarnings("ignore", category=DeprecationWarning)
+
+import sybil.model
+import sybil.models.calibrator
+from sybil import Serie, Sybil, visualize_attentions
+from sybil.utils import device_utils
+
+script_directory = os.path.dirname(os.path.abspath(__file__))
+PROJECT_DIR = os.path.dirname(script_directory)
+
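+# SeriesInstanceUIDs of NLST series used for the regression test; each series is
+# downloaded on demand from the NBIA API by _get_nlst() below.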
+nlst_test_series_uids = """
+1.2.840.113654.2.55.117165331353985769278030759027968557175
+1.2.840.113654.2.55.125761431810488169605478704683628260210
+1.2.840.113654.2.55.141145605876336438705007116410698504988
+1.2.840.113654.2.55.172973285539665405130180217312651302726
+1.2.840.113654.2.55.177114075868256371370044474147630945288
+1.2.840.113654.2.55.210451208063625047828616019396666958685
+1.2.840.113654.2.55.22343358537878328490619391877977879745
+1.2.840.113654.2.55.250355771186119178528311921318050236359
+1.2.840.113654.2.55.264036959200244122726184171100390477201
+1.2.840.113654.2.55.270666838959776453521953970167166965589
+1.2.840.113654.2.55.5405951206377419400128917954731813327
+1.2.840.113654.2.55.83074092506605340087865221843273784687
+1.2.840.113654.2.55.9114064256331314804445563449996729696
+1.3.6.1.4.1.14519.5.2.1.7009.9004.102050757680671140089992182963
+1.3.6.1.4.1.14519.5.2.1.7009.9004.140916852551836049221836980755
+1.3.6.1.4.1.14519.5.2.1.7009.9004.145444099046834219014840219889
+1.3.6.1.4.1.14519.5.2.1.7009.9004.160633847701259284025259919227
+1.3.6.1.4.1.14519.5.2.1.7009.9004.219693265059595773200467950221
+1.3.6.1.4.1.14519.5.2.1.7009.9004.228293333306602707645036607751
+1.3.6.1.4.1.14519.5.2.1.7009.9004.230644512623268816899910856967
+1.3.6.1.4.1.14519.5.2.1.7009.9004.234524223570882184991800514748
+1.3.6.1.4.1.14519.5.2.1.7009.9004.252281466173937391895189766240
+1.3.6.1.4.1.14519.5.2.1.7009.9004.310293448890324961317272491664
+1.3.6.1.4.1.14519.5.2.1.7009.9004.330739122093904668699523188451
+1.3.6.1.4.1.14519.5.2.1.7009.9004.338644625343131376124729421878
+1.3.6.1.4.1.14519.5.2.1.7009.9004.646014655040104355409047679769
+"""
+
+test_series_uids = nlst_test_series_uids
+
+
+def myprint(instr):
+    print(f"{datetime.datetime.now()} - {instr}")
+
+
+def download_file(url, filename):
+    response = requests.get(url)
+
+    target_dir = os.path.dirname(filename)
+    if target_dir and not os.path.exists(target_dir):
+        os.makedirs(target_dir)
+
+    # Check if the request was successful
+    if response.status_code == 200:
+        with open(filename, 'wb') as file:
+            file.write(response.content)
+    else:
+        print(f"Failed to download file. Status code: {response.status_code}")
+
+    return filename
+
+
+def download_and_extract_zip(zip_file_name, cache_dir, url, demo_data_dir):
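+    """Download `url` into `cache_dir/zip_file_name` if not already cached, and
+    extract it into `demo_data_dir` if that directory does not yet exist."""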
+    # Check and construct the full path of the zip file
+    zip_file_path = os.path.join(cache_dir, zip_file_name)
+
+    # 1. Check if the zip file exists
+    if not os.path.exists(zip_file_path):
+        # 2. Download the file
+        response = requests.get(url)
+        with open(zip_file_path, 'wb') as file:
+            file.write(response.content)
+
+    # 3. Extract the zip file unless the output directory already exists
+    if not os.path.exists(demo_data_dir):
+        with zipfile.ZipFile(zip_file_path, 'r') as zip_ref:
+            zip_ref.extractall(demo_data_dir)
+
+
+def get_sybil_model_path(model_name_or_path, cache_dir="~/.sybil"):
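+    """Resolve a model name (via sybil.model.NAME_TO_FILE) or an explicit path to a
+    local checkpoint (.ckpt) file under `cache_dir`."""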
+    cache_dir = os.path.expanduser(cache_dir)
+    if os.path.exists(model_name_or_path):
+        path = model_name_or_path
+    elif model_name_or_path in sybil.model.NAME_TO_FILE:
+        paths = sybil.model.NAME_TO_FILE[model_name_or_path]["checkpoint"]
+        assert len(paths) == 1, "Can only save 1 model at a time, no ensembles"
+        path = os.path.join(cache_dir, paths[0] + ".ckpt")
+    else:
+        raise ValueError(f"Model name or path not found: {model_name_or_path}")
+
+    return path
+
+
+class TestPredict(unittest.TestCase):
+    def test_demo_data(self):
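+        """Download the demo DICOM series, run Sybil end to end, and check the
+        predicted risk scores against known-good values."""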
+        if os.environ.get("SYBIL_TEST_RUN_REGRESSION", "false").lower() != "true":
+            import pytest
+            pytest.skip(f"Skipping long-running test in {type(self).__name__}.")
+
+        # Download demo data
+        demo_data_url = "https://www.dropbox.com/scl/fi/covbvo6f547kak4em3cjd/sybil_example.zip?rlkey=7a13nhlc9uwga9x7pmtk1cf1c&st=dqi0cf9k&dl=1"
+        expected_scores = [
+            0.021628819563619374,
+            0.03857256315036462,
+            0.07191945816622261,
+            0.07926975188037134,
+            0.09584583525781108,
+            0.13568094038444453
+        ]
+
+        zip_file_name = "sybil_example.zip"
+        cache_dir = os.path.expanduser("~/.sybil")
+        demo_data_dir = os.path.join(cache_dir, "sybil_example")
+        image_data_dir = os.path.join(demo_data_dir, "sybil_demo_data")
+        os.makedirs(cache_dir, exist_ok=True)
+        download_and_extract_zip(zip_file_name, cache_dir, demo_data_url, demo_data_dir)
+
+        dicom_files = os.listdir(image_data_dir)
+        dicom_files = [os.path.join(image_data_dir, x) for x in dicom_files]
+        num_files = len(dicom_files)
+
+        # Load a trained model
+        model = Sybil()
+
+        myprint(f"Beginning prediction using {num_files} files from {image_data_dir}")
+
+        # Get risk scores
+        serie = Serie(dicom_files)
+        series = [serie]
+        prediction = model.predict(series, return_attentions=True)
+        actual_scores = prediction.scores[0]
+        count = len(actual_scores)
+
+        myprint(f"Prediction finished. Results\n{actual_scores}")
+
+        assert len(expected_scores) == len(actual_scores), f"Unexpected score length {count}"
+
+        for exp_score, act_score in zip(expected_scores, actual_scores):
+            assert math.isclose(exp_score, act_score, rel_tol=1e-6), \
+                f"Mismatched scores. {exp_score} != {act_score}"
+
+        print(f"Data URL: {demo_data_url}\nAll {count} elements match.")
+
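+        # Exercise the attention visualization path; output is written to save_directory.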
+        series_with_attention = visualize_attentions(
+            series,
+            attentions=prediction.attentions,
+            save_directory="regression_test_output",
+            gain=3,
+        )
+
+
+def _get_nlst(series_instance_uid, cache_dir=".cache"):
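+    """Download one NLST series from the TCIA/NBIA API and extract it into
+    `cache_dir/<SeriesInstanceUID>`, reusing the directory if it already exists."""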
+    base_url = "https://nlst.cancerimagingarchive.net/nbia-api/services/v1"
+    series_dir = os.path.join(cache_dir, series_instance_uid)
+    if os.path.exists(series_dir):
+        return series_dir
+
+    action = "getImage"
+    remote_url = f"{base_url}/{action}"
+    print(f"Downloading {series_instance_uid} from {remote_url}")
+    response = requests.get(remote_url, params={"SeriesInstanceUID": series_instance_uid})
+    # The response is a zip archive; extract it into the series directory.
+    # Only create the directory on success so a failed download is retried next time.
+    if response.status_code == 200:
+        os.makedirs(series_dir, exist_ok=True)
+        zip_file_bytes = io.BytesIO(response.content)
+        with zipfile.ZipFile(zip_file_bytes) as zip_file:
+            zip_file.extractall(series_dir)
+        print(f"Files extracted to {series_dir}")
+    else:
+        print(f"Failed to download file. Status code: {response.status_code}")
+
+    return series_dir
+
+
+class TestPredictionRegression(unittest.TestCase):
+
+    def test_nlst_predict(self, allow_resume=True, delete_downloaded_files=False):
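+        """Run sybil_ensemble (or an ARK server) over the NLST test series and write
+        per-series predictions to tests/nlst_predictions/ as JSON, saving as it goes."""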
+        if os.environ.get("SYBIL_TEST_RUN_REGRESSION", "false").lower() != "true":
+            import pytest
+            pytest.skip(f"Skipping long-running test in {type(self).__name__}.")
+
+        test_series_list = test_series_uids.split("\n")
+        test_series_list = [x.strip() for x in test_series_list if x.strip()]
+        print(f"About to test {len(test_series_list)} series")
+
+        # `allow_resume` controls whether results from a previous run are reused
+        # (per series) or the existing results file is overwritten.
+        model_name = "sybil_ensemble"
+
+        # True  -> send web requests to an ARK server (must be launched separately).
+        # False -> run inference directly.
+        use_ark = os.environ.get("SYBIL_TEST_USE_ARK", "false").lower() == "true"
+        ark_host = os.environ.get("SYBIL_ARK_HOST", "http://localhost:5000")
+
+        version = sybil.__version__
+
+        out_fi_name = f"nlst_predictions_{model_name}_v{version}.json"
+        info_data = {}
+        if use_ark:
+            # Query the ARK server to get the version
+            print(f"Will use ark server {ark_host} for prediction")
+            resp = requests.get(f"{ark_host}/info")
+            info_data = resp.json()["data"]
+            print(f"ARK server response: {resp.text}")
+            version = info_data["modelVersion"]
+            out_fi_name = f"nlst_predictions_ark_v{version}.json"
+
+        output_dir = os.path.join(PROJECT_DIR, "tests", "nlst_predictions")
+
+        metadata = {
+                "modelName": model_name,
+                "modelVersion": version,
+                "start_time": datetime.datetime.now().isoformat(),
+            }
+        metadata.update(info_data)
+        all_results = {"__metadata__": metadata}
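+        # Results layout: "__metadata__" plus one entry per SeriesInstanceUID with
+        # num_slices, the predicted scores, and the runtime.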
+
+        os.makedirs(output_dir, exist_ok=True)
+        cur_pred_results = os.path.join(output_dir, out_fi_name)
+        cache_dir = os.path.join(PROJECT_DIR, ".cache")
+
+        if os.path.exists(cur_pred_results):
+            if allow_resume:
+                with open(cur_pred_results, 'r') as f:
+                    all_results = json.load(f)
+            else:
+                os.remove(cur_pred_results)
+
+        if use_ark:
+            model = None
+        else:
+            model = Sybil(model_name)
+
+        device = device_utils.get_default_device()
+        if model is not None and device is not None:
+            model.to(device)
+
+        num_to_process = len(test_series_list)
+        for idx, series_uid in enumerate(tqdm.tqdm(test_series_list)):
+            print(f"{datetime.datetime.now()} Processing {series_uid} ({idx + 1}/{num_to_process})")
+            if series_uid in all_results:
+                print(f"Already processed {series_uid}, skipping")
+                continue
+
+            series_dir = _get_nlst(series_uid, cache_dir=cache_dir)
+            dicom_files = os.listdir(series_dir)
+            dicom_files = sorted([os.path.join(series_dir, x) for x in dicom_files if x.endswith(".dcm")])
+
+            if len(dicom_files) < 20:
+                print(f"Skipping {series_uid} due to insufficient files ({len(dicom_files)})")
+                continue
+
+            try:
+                prediction = all_results.get(series_uid, {})
+                if use_ark:
+                    # Submit prediction to ARK server.
+                    files = [('dicom', open(file_path, 'rb')) for file_path in dicom_files]
+                    try:
+                        r = requests.post(f"{ark_host}/dicom/files", files=files)
+                    finally:
+                        for _, file_handle in files:
+                            file_handle.close()
+                    if r.status_code != 200:
+                        # Record the error so it is saved with the rest of the results.
+                        print(f"An error occurred while processing {series_uid}: {r.text}")
+                        prediction["error"] = r.text
+                    else:
+                        r_json = r.json()
+                        prediction = r_json["data"]
+                        prediction["runtime"] = r_json["runtime"]
+                        prediction["predictions"] = prediction["predictions"][0]
+                else:
+                    serie = Serie(dicom_files)
+                    start_time = time.time()
+                    pred_result = model.predict([serie], return_attentions=False)
+                    runtime = "{:.2f}s".format(time.time() - start_time)
+
+                    scores = pred_result.scores
+                    prediction = {"predictions": scores, "runtime": runtime}
+
+                if delete_downloaded_files:
+                    shutil.rmtree(series_dir)
+
+            except Exception as e:
+                print(f"Failed to process {series_uid}: {e}")
+                continue
+
+            cur_dict = {
+                "series_uid": series_uid,
+                "num_slices": len(dicom_files),
+            }
+
+            if prediction:
+                cur_dict.update(prediction)
+
+            all_results[series_uid] = cur_dict
+
+            # Save as we go
+            with open(cur_pred_results, 'w') as f:
+                json.dump(all_results, f, indent=2)
+
+    def test_compare_predict_scores(self):
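+        """Compare per-year risk scores between a baseline predictions file (ARK
+        v1.4.0 by default) and a newly generated one, to within 1e-6."""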
+        if os.environ.get("SYBIL_TEST_RUN_REGRESSION", "false").lower() != "true":
+            import pytest
+            pytest.skip(f"Skipping long-running test in {type(self).__name__}.")
+
+        default_baseline_preds_path = os.path.join(PROJECT_DIR, "tests",
+                                                   "nlst_predictions", "nlst_predictions_ark_v1.4.0.json")
+        baseline_preds_path = os.environ.get("SYBIL_TEST_BASELINE_PATH", default_baseline_preds_path)
+
+        version = sybil.__version__
+        default_new_preds_path = os.path.join(PROJECT_DIR, "tests",
+                                                "nlst_predictions", f"nlst_predictions_sybil_ensemble_v{version}.json")
+        new_preds_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", default_new_preds_path)
+        assert new_preds_path, "SYBIL_TEST_COMPARE_PATH must be set to the path of the new predictions file."
+        pred_key = "predictions"
+        num_compared = 0
+
+        with open(baseline_preds_path, 'r') as f:
+            baseline_preds = json.load(f)
+        with open(new_preds_path, 'r') as f:
+            new_preds = json.load(f)
+
+        ignore_keys = {"__metadata__"}
+        overlap_keys = set(baseline_preds.keys()).intersection(new_preds.keys()) - ignore_keys
+        union_keys = set(baseline_preds.keys()).union(new_preds.keys()) - ignore_keys
+        print(f"{len(overlap_keys)} / {len(union_keys)} patients in common between the two prediction files.")
+
+        all_mismatches = []
+        for series_uid_key in overlap_keys:
+            if pred_key not in baseline_preds[series_uid_key]:
+                print(f"{pred_key} not found in baseline predictions for {series_uid_key}")
+                assert pred_key not in new_preds[series_uid_key]
+                continue
+
+            compare_keys = [pred_key]
+            for comp_key in compare_keys:
+                cur_baseline_preds = baseline_preds[series_uid_key][comp_key][0]
+                cur_new_preds = new_preds[series_uid_key][comp_key][0]
+                for ind in range(len(cur_baseline_preds)):
+                    year = ind + 1
+                    baseline_score = cur_baseline_preds[ind]
+                    new_score = cur_new_preds[ind]
+                    does_match = math.isclose(baseline_score, new_score, abs_tol=1e-6)
+                    if not does_match:
+                        err_str = f"Scores for {series_uid_key}, {comp_key} differ for year {year}.\n"
+                        err_str += f"Diff: {new_score - baseline_score:0.4e}. Baseline: {baseline_score:0.4e}, New: {new_score:0.4e}"
+                        all_mismatches.append(err_str)
+
+            num_compared += 1
+
+        assert num_compared > 0
+        print(f"Compared {num_compared} patients.")
+
+        if all_mismatches:
+            print(f"Found {len(all_mismatches)} mismatches.")
+            for err in all_mismatches:
+                print(err)
+
+        num_mismatches = len(all_mismatches)
+        assert num_mismatches == 0, f"Found {num_mismatches} mismatches between the two prediction files."
+
+    def test_calibrator(self):
+        """
+        Test the calibrator against previous known calibrations.
+        """
+
+        default_baseline_path = os.path.join(PROJECT_DIR, "tests", "sybil_ensemble_v1.4.0_calibrations.json")
+        baseline_path = os.environ.get("SYBIL_TEST_BASELINE_PATH", default_baseline_path)
+
+        if not os.path.exists(baseline_path) and baseline_path == default_baseline_path:
+            os.makedirs(os.path.dirname(default_baseline_path), exist_ok=True)
+            reference_calibrations_url = "https://www.dropbox.com/scl/fi/2fx6ukmozia7y3u8mie97/sybil_ensemble_v1.4.0_calibrations.json?rlkey=tquids9qo4mkkuf315nqdq0o7&dl=1"
+            download_file(reference_calibrations_url, default_baseline_path)
+
+        default_cal_dict_path = os.path.expanduser("~/.sybil/sybil_ensemble_simple_calibrator.json")
+        compare_calibrator_path = os.environ.get("SYBIL_TEST_COMPARE_PATH", default_cal_dict_path)
+        compare_calibrator_path = os.path.expanduser(compare_calibrator_path)
+        if not os.path.exists(compare_calibrator_path) and compare_calibrator_path == default_cal_dict_path:
+            # Instantiating the model ensures the calibrator file is downloaded to ~/.sybil.
+            _ = Sybil("sybil_ensemble")
+
+        with open(compare_calibrator_path, "r") as f:
+            raw_calibrator_dict = json.load(f)
+        new_calibrator_dict = {}
+        for key, val in raw_calibrator_dict.items():
+            new_calibrator_dict[key] = sybil.models.calibrator.SimpleClassifierGroup.from_json(val)
+
+        with open(baseline_path, "r") as f:
+            baseline_preds = json.load(f)
+        test_probs = np.array(baseline_preds["x"]).reshape(-1, 1)
+        year_keys = [key for key in baseline_preds.keys() if key.startswith("Year")]
+        for year_key in year_keys:
+            baseline_scores = np.array(baseline_preds[year_key]).flatten()
+            new_cal = new_calibrator_dict[year_key]
+            new_scores = new_cal.predict_proba(test_probs).flatten()
+
+            self.assertTrue(np.allclose(baseline_scores, new_scores, atol=1e-10), f"Calibration mismatch for {year_key}")
+
+
+if __name__ == "__main__":
+    unittest.main()