# View file src/colab/midas.py - Download
# -*- coding: utf-8 -*-
"""midas.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1IrH6Qh_A4aUG6GqSPWzFeBU69VlYkacv
https://pytorch.org/hub/intelisl_midas_v2/
Model Description
MiDaS computes relative inverse depth from a single image. The repository provides multiple models that cover different use cases, ranging from a small, high-speed model to a very large model that provides the highest accuracy. The models have been trained on 10 distinct datasets using multi-objective optimization to ensure high quality on a wide range of inputs.
Dependencies
MiDaS depends on timm. Install with
"""
# ``!pip install timm`` is IPython shell-escape syntax and a SyntaxError in a
# plain Python file. Invoking pip through the running interpreter works both
# inside a notebook and as a regular script, and installs into the same
# environment that will import the package.
import subprocess
import sys

# check=False keeps the step best-effort, like the shell escape it replaces:
# a non-zero pip exit status does not abort the rest of the script.
subprocess.run([sys.executable, "-m", "pip", "install", "timm"], check=False)
"""Example Usage
Download an image from the PyTorch homepage
"""
import cv2
import torch
import urllib.request
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
# Download the sample picture and remember its local path in ``filename``;
# the later cv2.imread step reads the same path.
url = "https://github.com/pytorch/hub/raw/master/images/dog.jpg"
filename = "dog.jpg"
urllib.request.urlretrieve(url, filename)

# Preview the file that was just downloaded. Reading through ``filename``
# (instead of repeating the "dog.jpg" literal) keeps the preview in sync if
# the target path is ever changed.
img = mpimg.imread(filename)
plt.imshow(img)
plt.axis('off')
plt.show()
"""Load a model (see https://github.com/intel-isl/MiDaS/#Accuracy for an overview)"""
# Entry-point name handed to torch.hub.load. Available choices:
#   "DPT_Large"   - MiDaS v3 - Large (highest accuracy, slowest inference speed)
#   "DPT_Hybrid"  - MiDaS v3 - Hybrid (medium accuracy, medium inference speed)
#   "MiDaS_small" - MiDaS v2.1 - Small (lowest accuracy, highest inference speed)
model_type = "DPT_Large"
midas = torch.hub.load("intel-isl/MiDaS", model_type)

"""Move model to GPU if available"""
# Prefer CUDA when torch reports it as available, otherwise stay on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas.to(device)
# Switch the module to evaluation mode before running inference.
midas.eval()
"""Load transforms to resize and normalize the image for large or small model"""
# Load the "transforms" entry point from the same hub repository as the model.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")

# Both DPT variants use dpt_transform; every other model_type falls back to
# small_transform.
transform = (
    midas_transforms.dpt_transform
    if model_type in ("DPT_Large", "DPT_Hybrid")
    else midas_transforms.small_transform
)
"""Load image and apply transforms"""
img = cv2.imread(filename)
if img is None:
    # cv2.imread reports a missing or undecodable file by returning None
    # rather than raising; fail here with a clear message instead of letting
    # cvtColor raise an opaque assertion error.
    raise OSError(f"cv2.imread could not read or decode {filename!r}")

# OpenCV hands back BGR channel order; reorder to RGB before the transform.
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

# Apply the selected preprocessing and move the result to the model's device.
input_batch = transform(img).to(device)
"""Predict and resize to original resolution"""
# Inference only: run the forward pass and the resize without tracking
# gradients.
with torch.no_grad():
    prediction = midas(input_batch)

    # unsqueeze(1) inserts an axis at position 1 before interpolating
    # (presumably the channel axis interpolate() expects -- TODO confirm the
    # model's output shape). The target size is the first two dimensions of
    # the image loaded above.
    resized = torch.nn.functional.interpolate(
        prediction.unsqueeze(1),
        size=img.shape[:2],
        mode="bicubic",
        align_corners=False,
    )

    # Drop every singleton dimension again.
    prediction = resized.squeeze()

# Bring the result back to host memory as a NumPy array for plotting.
output = prediction.cpu().numpy()
"""Show result"""
# Display the input image first and the prediction second, one after the
# other, each with the axes hidden.
for picture in (img, output):
    plt.imshow(picture)
    plt.axis('off')
    plt.show()
"""References
Towards Robust Monocular Depth Estimation: Mixing Datasets for Zero-shot Cross-dataset Transfer https://arxiv.org/abs/1907.01341
Vision Transformers for Dense Prediction https://arxiv.org/abs/2103.13413
"""