Zero-Shot Medical Image Segmentation with MedSAM
Introduction
This cookbook shows how to use MedSAM, a foundation model for medical image segmentation, to run zero-shot segmentation on radiology images with bounding-box prompts, without fine-tuning it on your own data.
Prerequisites
- Python 3.10+
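- PyTorch and Segment Anything (SAM) installed
- The MedSAM ViT-B checkpoint (medsam_vit_b.pth) from the MedSAM project
- A sample MRI or CT slice saved as PNG (DICOM files must first be converted to an 8-bit image, because cv2.imread does not read DICOM)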
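If your scan is in DICOM format, one way to turn a slice into the 8-bit RGB array that SamPredictor expects is sketched below. It assumes the separately installed pydicom package (pip install pydicom) for reading the file; the function names and the optional intensity window are our own illustrative choices, not part of MedSAM.

```python
import numpy as np


def window_to_uint8(pixels, center=None, width=None):
    """Map raw scanner intensities to an 8-bit, 3-channel image.

    If center and width are given, apply a linear intensity window
    (for example a CT soft-tissue window); otherwise min-max scale the slice.
    """
    pixels = np.asarray(pixels, dtype=np.float32)
    if center is not None and width is not None:
        low, high = center - width / 2.0, center + width / 2.0
    else:
        low, high = float(pixels.min()), float(pixels.max())
    # Clip to the window, then scale to [0, 1]; guard against a zero range
    scaled = (np.clip(pixels, low, high) - low) / max(high - low, 1e-8)
    gray = np.round(scaled * 255.0).astype(np.uint8)
    # Replicate the single channel so the array has shape (H, W, 3) like RGB
    return np.stack([gray] * 3, axis=-1)


def load_dicom_slice(path, center=None, width=None):
    """Read one DICOM slice with pydicom (assumed to be installed)."""
    import pydicom  # imported lazily so window_to_uint8 works without it

    ds = pydicom.dcmread(path)
    pixels = ds.pixel_array.astype(np.float32)
    # Convert stored values to modality units (Hounsfield units for CT)
    # when the rescale tags are present; default to identity otherwise
    slope = float(getattr(ds, "RescaleSlope", 1.0))
    intercept = float(getattr(ds, "RescaleIntercept", 0.0))
    return window_to_uint8(pixels * slope + intercept, center, width)
```

For a CT slice, load_dicom_slice("slice.dcm", center=40, width=400) applies a commonly used soft-tissue window; the path is hypothetical. The result can be passed straight to predictor.set_image or saved as a PNG with cv2.imwrite (channel order does not matter here because all three channels are identical).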
Step 1: Install Dependencies
pip install git+https://github.com/facebookresearch/segment-anything.git
pip install torch torchvision opencv-python matplotlib
Step 2: Load the MedSAM Model
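Download the MedSAM checkpoint, build a ViT-B SAM model from it, and wrap the model in a SamPredictor. MedSAM uses SAM's ViT-B architecture, so the standard sam_model_registry["vit_b"] builder can load its weights.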
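from segment_anything import sam_model_registry, SamPredictor
import torch

# Path to the downloaded MedSAM checkpoint (ViT-B backbone)
medsam_checkpoint = "medsam_vit_b.pth"
device = "cuda" if torch.cuda.is_available() else "cpu"

# MedSAM shares SAM's ViT-B architecture, so the stock "vit_b" builder loads it
sam = sam_model_registry["vit_b"](checkpoint=medsam_checkpoint)
sam.to(device=device)
sam.eval()  # inference mode

# SamPredictor caches the image embedding so several prompts can reuse it
predictor = SamPredictor(sam)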
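SamPredictor prepares images the way SAM does: it resizes the longest side to 1024 pixels, normalizes with SAM's default pixel mean and standard deviation, and pads to a square. As we understand the MedSAM repository's reference inference script, it instead resizes each slice to exactly 1024×1024, min-max scales intensities to [0, 1], calls the image encoder directly, and rescales the box by 1024/W and 1024/H. If masks look worse than expected, this preprocessing difference is worth checking. The numeric part of that alternative pipeline can be sketched as follows; the function names are ours, and the resize itself is left to your image library.

```python
import numpy as np

MEDSAM_SIZE = 1024  # side of the square input assumed for the ViT-B encoder


def minmax_to_unit(image):
    """Scale intensities to [0, 1], guarding against constant images."""
    image = np.asarray(image, dtype=np.float32)
    lo, hi = float(image.min()), float(image.max())
    return (image - lo) / max(hi - lo, 1e-8)


def box_to_1024(box, height, width, size=MEDSAM_SIZE):
    """Map an [x_min, y_min, x_max, y_max] box from original pixel coordinates
    to the coordinates of the same image stretched to size x size."""
    box = np.asarray(box, dtype=np.float32)
    # x values scale with the width, y values with the height
    scale = np.array([width, height, width, height], dtype=np.float32)
    return box / scale * size
```

For the 512×512 case, box_to_1024([100, 150, 300, 400], height=512, width=512) doubles every coordinate. With SamPredictor none of this is needed, because predict performs its own box transform.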
Step 3: Perform Segmentation
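With the model loaded, pass a bounding-box prompt that encloses the organ or lesion you want to segment. The box is given in pixel coordinates of the original image, in [x_min, y_min, x_max, y_max] order (x runs along columns, y along rows); SamPredictor rescales it internally to the model's input resolution.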
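If you already have a rough mask of the target (for example a coarse manual annotation or the segmentation of a neighboring slice), a box prompt can be derived from it instead of being typed by hand. The helper below is a hypothetical sketch, not part of MedSAM or segment_anything; the default 5-pixel margin is an arbitrary assumption.

```python
import numpy as np


def box_from_mask(mask, margin=5):
    """Return [x_min, y_min, x_max, y_max] enclosing the True pixels of a 2D
    mask, expanded by `margin` pixels and clipped to the image bounds."""
    mask = np.asarray(mask, dtype=bool)
    ys, xs = np.nonzero(mask)  # row indices are y, column indices are x
    if xs.size == 0:
        raise ValueError("mask is empty; cannot derive a box")
    h, w = mask.shape
    x_min = max(int(xs.min()) - margin, 0)
    y_min = max(int(ys.min()) - margin, 0)
    x_max = min(int(xs.max()) + margin, w - 1)
    y_max = min(int(ys.max()) + margin, h - 1)
    return np.array([x_min, y_min, x_max, y_max])
```

With a hypothetical boolean array rough_mask of the same height and width as the image, input_box = box_from_mask(rough_mask) produces a box in the format predictor.predict expects.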
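import cv2
import numpy as np
import matplotlib.pyplot as plt

# Load the medical image (a PNG slice); cv2.imread returns None on failure
image = cv2.imread("sample_mri.png")
if image is None:
    raise FileNotFoundError("Could not read sample_mri.png")
# OpenCV loads images as BGR; SamPredictor expects RGB by default
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Compute the image embedding once; it is reused for every prompt on this image
predictor.set_image(image)

# Bounding box prompt in pixel coordinates: [x_min, y_min, x_max, y_max]
input_box = np.array([100, 150, 300, 400])

masks, _, _ = predictor.predict(
    point_coords=None,
    point_labels=None,
    box=input_box[None, :],
    multimask_output=False,  # return a single mask for the box prompt
)
# masks has shape (1, H, W) and dtype bool

# Visualize: overlay only the masked pixels and draw the prompt box
plt.imshow(image)
plt.imshow(np.ma.masked_where(~masks[0], masks[0]), alpha=0.5, cmap="autumn")
x_min, y_min, x_max, y_max = input_box
plt.gca().add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                  fill=False, edgecolor="lime", linewidth=2))
plt.axis("off")
plt.show()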
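To check how closely a zero-shot mask matches a reference annotation, a widely used overlap metric in medical image segmentation is the Dice similarity coefficient, 2|A∩B| / (|A| + |B|). A minimal NumPy sketch follows; treating two empty masks as a perfect match is our own convention.

```python
import numpy as np


def dice_score(pred, target):
    """Dice similarity coefficient 2|A∩B| / (|A| + |B|) for binary masks."""
    pred = np.asarray(pred, dtype=bool)
    target = np.asarray(target, dtype=bool)
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {pred.shape} vs {target.shape}")
    intersection = np.logical_and(pred, target).sum()
    total = pred.sum() + target.sum()
    if total == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return float(2.0 * intersection / total)
```

With a hypothetical boolean reference mask ground_truth of shape (H, W), dice_score(masks[0], ground_truth) returns a value between 0 (no overlap) and 1 (identical masks).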
Conclusion
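MedSAM can segment organs and lesions on new medical images without task-specific training: a single bounding-box prompt per structure is enough to produce a mask. Prompt placement and image preprocessing both affect the result, so inspect the masks visually and, where reference annotations exist, measure their overlap before relying on them.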