Tutorial

Zero-Shot Medical Image Segmentation with MedSAM

Difficulty: Beginner | Time: 15 min read

Introduction

In this cookbook, we cover how to load MedSAM, a foundation model for medical image segmentation, and use it for zero-shot segmentation of radiology images, with no fine-tuning.

Prerequisites

  • Python 3.10+
  • PyTorch and Segment Anything (SAM) installed
  • A sample MRI or CT scan (DICOM or PNG)
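
The loading code in Step 3 reads a PNG. If your sample is a DICOM slice instead, its raw intensities are usually wider than 8 bits, while SAM's predictor expects an HxWx3 uint8 image. The sketch below shows one way to bridge that gap with an intensity window, under some assumptions: `window_to_uint8_rgb` is a hypothetical helper, not part of MedSAM or SAM, and the window center and width are illustrative values that you would choose per modality. The raw array could come from, for example, `pydicom.dcmread(path).pixel_array` (pydicom is not installed in Step 1).

```python
import numpy as np


def window_to_uint8_rgb(pixels, center, width):
    """Map one 2D slice of raw intensities to an HxWx3 uint8 image.

    Hypothetical helper: values outside the window are clipped, values
    inside it are rescaled linearly to 0..255.
    """
    pixels = np.asarray(pixels, dtype=np.float32)
    if pixels.ndim != 2:
        raise ValueError("expected a single 2D slice")
    if width <= 0:
        raise ValueError("window width must be positive")
    low = center - width / 2.0
    high = center + width / 2.0
    # Clip to the window, then rescale linearly to the 0..1 range.
    scaled = (np.clip(pixels, low, high) - low) / (high - low)
    gray = np.round(scaled * 255.0).astype(np.uint8)
    # Repeat the gray channel three times to get the HxWx3 layout.
    return np.stack([gray, gray, gray], axis=-1)


# Synthetic 2x2 "slice" standing in for real scan data.
demo = np.array([[-1000.0, 40.0], [240.0, 3000.0]])
rgb = window_to_uint8_rgb(demo, center=40, width=400)
print(rgb.shape, rgb.dtype)
```

The returned array can be passed to `predictor.set_image` in place of the image loaded with OpenCV, with no BGR-to-RGB conversion needed since all three channels are identical.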

Step 1: Install Dependencies

pip install git+https://github.com/facebookresearch/segment-anything.git
pip install torch torchvision opencv-python matplotlib

Step 2: Load the MedSAM Model

First, download the MedSAM weights and initialize the model. The code below assumes the checkpoint is saved as medsam_vit_b.pth in the working directory.

from segment_anything import sam_model_registry, SamPredictor
import torch

medsam_checkpoint = "medsam_vit_b.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry["vit_b"](checkpoint=medsam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)
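
A missing or partially downloaded checkpoint is a common reason this step fails. A small guard before building the model can make the cause obvious. This is a minimal sketch using only the standard library; `require_checkpoint` and its `min_bytes` parameter are hypothetical names introduced here, not part of SAM or MedSAM.

```python
from pathlib import Path


def require_checkpoint(path, min_bytes=1):
    """Return the checkpoint as a Path, or raise a clear error if it is unusable.

    Hypothetical helper: it only checks that the file exists and is not
    smaller than `min_bytes`; it does not validate the weights themselves.
    """
    ckpt = Path(path).expanduser()
    if not ckpt.is_file():
        raise FileNotFoundError(f"MedSAM checkpoint not found at {ckpt}; download it first.")
    size = ckpt.stat().st_size
    if size < min_bytes:
        raise ValueError(f"{ckpt} looks empty or truncated ({size} bytes)")
    return ckpt


# Report the problem instead of crashing when the file is absent.
try:
    require_checkpoint("medsam_vit_b.pth")
except (FileNotFoundError, ValueError) as err:
    print(f"Checkpoint problem: {err}")
```

In the loading code above, the checked path could then be passed along as `checkpoint=str(require_checkpoint(medsam_checkpoint))`.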

Step 3: Perform Segmentation

With the predictor ready, we pass a bounding box prompt drawn around the organ or tumor of interest, and MedSAM returns a binary mask for that structure.

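import cv2
import matplotlib.pyplot as plt
import numpy as np

# Load your medical image (cv2.imread returns None if the file cannot be read)
image = cv2.imread('sample_mri.png')
if image is None:
    raise FileNotFoundError("Could not read sample_mri.png")
# OpenCV loads images as BGR; the predictor expects RGB
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)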

predictor.set_image(image)

# Define a bounding box prompt in pixel coordinates [x_min, y_min, x_max, y_max]
# (example values; adjust them to the structure in your image)
input_box = np.array([100, 150, 300, 400])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,
)

# Visualize the mask
plt.imshow(image)
plt.imshow(masks[0], alpha=0.5)
plt.show()
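
Beyond looking at the overlay, a few numeric checks can help catch a misplaced box or an empty mask. The sketch below is one possible set of checks, not part of MedSAM: `clip_box`, `mask_summary`, and `dice` are hypothetical helpers, and the Dice score only applies if you happen to have a reference annotation to compare against. The demo uses synthetic arrays standing in for `image.shape` and `masks[0]`.

```python
import numpy as np


def clip_box(box, image_shape):
    """Clip [x_min, y_min, x_max, y_max] to the image and reject empty boxes."""
    height, width = image_shape[:2]
    x0, y0, x1, y1 = (float(v) for v in box)
    x0, x1 = np.clip([x0, x1], 0, width - 1)
    y0, y1 = np.clip([y0, y1], 0, height - 1)
    if x1 <= x0 or y1 <= y0:
        raise ValueError(f"box {list(box)} has no area inside a {width}x{height} image")
    return np.array([x0, y0, x1, y1])


def mask_summary(mask):
    """Return pixel area, fraction of the image, and tight bounding box of a mask."""
    mask = np.asarray(mask, dtype=bool)
    area = int(mask.sum())
    if area == 0:
        return {"area": 0, "fraction": 0.0, "bbox": None}
    ys, xs = np.nonzero(mask)
    bbox = [int(xs.min()), int(ys.min()), int(xs.max()), int(ys.max())]
    return {"area": area, "fraction": area / mask.size, "bbox": bbox}


def dice(mask, reference):
    """Dice overlap between two binary masks (1.0 means identical)."""
    a = np.asarray(mask, dtype=bool)
    b = np.asarray(reference, dtype=bool)
    total = a.sum() + b.sum()
    if total == 0:
        return 1.0  # both masks empty: treated here as perfect agreement
    return float(2.0 * np.logical_and(a, b).sum() / total)


# Synthetic 8x8 mask standing in for masks[0].
demo_mask = np.zeros((8, 8), dtype=bool)
demo_mask[2:5, 3:7] = True
print(clip_box([-5, 1, 20, 6], demo_mask.shape))
print(mask_summary(demo_mask))
print(dice(demo_mask, demo_mask))
```

With the tutorial's variables, `clip_box(input_box, image.shape)` would run before `predictor.predict`, and `mask_summary(masks[0])` after it. A mask whose tight bounding box is far smaller than the prompt, or whose area is zero, is a hint that the box or the image preprocessing deserves a second look.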

Conclusion

MedSAM provides zero-shot segmentation driven by a single prompt: given a bounding box around a structure, it returns a mask for that structure without any task-specific training. Because the result depends on where the box is placed, check the overlay before relying on a mask.