-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmidas_depth_extraction.py
117 lines (93 loc) · 3.18 KB
/
midas_depth_extraction.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import torch
import cv2
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
import argparse
import time
# Load the model
# Pick the GPU when available; the model and all inputs are moved to this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print('Using device:', device)
#Additional Info when using cuda
if device.type == 'cuda':
    print(torch.cuda.get_device_name(0))
    print('Memory Usage:')
    print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
    print('Cached: ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
# Download (or load from the local torch hub cache) the MiDaS depth-estimation model.
# NOTE(review): this hits the network on first run — requires connectivity.
model = torch.hub.load("intel-isl/MiDaS", "MiDaS", pretrained=True)
if device.type == 'cuda':
    model = model.to(device)
model.eval()  # inference mode: freezes dropout/batch-norm behavior
# Transform the input
# Preprocessing applied to every frame before inference: resize, tensor
# conversion, then normalization with the standard ImageNet mean/std.
transform = transforms.Compose(
    [
        transforms.Resize((320, 480)),  # Resize to dimensions divisible by 32
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)
def depth_extract_image(path):
    """Load the image at *path*, run depth estimation, and display the result.

    Blocks until a key is pressed in the OpenCV window.

    Args:
        path: Filesystem path to an image readable by OpenCV.
    """
    # BUG FIX: the original passed the path string straight to depth_extract,
    # which expects a numpy image array (it calls Image.fromarray on it) and
    # would raise. Read the image from disk first, and fail gracefully if
    # OpenCV cannot decode it (cv2.imread returns None on failure).
    frame = cv2.imread(path)
    if frame is None:
        print("Error reading image file:", path)
        return
    depth_map = depth_extract(frame)
    # Display the depth map
    cv2.imshow('Depth Map', depth_map)
    cv2.waitKey(0)
    cv2.destroyAllWindows()
def depth_extract(frame):
    """Compute an 8-bit depth map for a single image array.

    The frame goes through the module-level ``transform`` (resize, tensor,
    ImageNet normalization), is pushed through the global MiDaS ``model``,
    and the raw prediction is min-max scaled into the 0-255 range.

    Args:
        frame: Image as a numpy array (as produced by OpenCV).

    Returns:
        2-D ``uint8`` numpy array suitable for ``cv2.imshow``.
    """
    # Preprocess: PIL conversion, transform pipeline, add batch dim, move to device.
    pil_image = Image.fromarray(frame)
    batch = transform(pil_image).unsqueeze(0).to(device)
    # Inference only — no gradients needed.
    with torch.no_grad():
        prediction = model(batch)
    # Drop the batch dimension and bring the prediction back to host memory.
    raw_depth = prediction[0].cpu().numpy()
    # Rescale raw (relative) depth values into a displayable 8-bit image.
    return cv2.normalize(raw_depth, None, 255, 0, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
def depth_extract_video(path):
    """Play back the video at *path*, showing each frame and its depth map.

    Opens two OpenCV windows ('Frame' and 'Depth'); press 'q' to stop early.
    When running on CUDA, prints allocator statistics for each frame.

    Args:
        path: Filesystem path to a video readable by OpenCV.
    """
    video = cv2.VideoCapture(path)
    # Check if the video is opened successfully.
    if not video.isOpened():
        print("Error opening video file")
        return
    while True:
        # Read a frame from the video.
        ret, frame = video.read()
        # BUG FIX: check the read result BEFORE touching the frame — the
        # original resized first, so cv2.resize crashed on the None frame
        # returned at end of stream.
        if not ret:
            break
        frame = cv2.resize(frame, (480, 320))
        cv2.imshow('Depth', depth_extract(frame))
        cv2.imshow('Frame', frame)
        if cv2.waitKey(25) & 0xFF == ord('q'):
            break
        # FIX: only query CUDA allocator stats when actually on a GPU,
        # consistent with the device guard used at module load; these
        # calls can fail on CPU-only setups.
        if device.type == 'cuda':
            print('Allocated:', round(torch.cuda.memory_allocated(0)/1024**3,1), 'GB')
            print('Cached: ', round(torch.cuda.memory_reserved(0)/1024**3,1), 'GB')
    # Release the video file and close any open windows.
    video.release()
    cv2.destroyAllWindows()
def main(video, image, path):
    """Dispatch to image or video depth extraction based on CLI flags.

    Args:
        video: Truthy when --video was passed.
        image: Truthy when --image was passed (takes precedence over video).
        path: Filesystem path to the input image/video.
    """
    # ROBUSTNESS: the original silently did nothing when no flag (or no
    # path) was supplied — tell the user instead of exiting quietly.
    if path is None:
        print("Error: --path is required")
        return
    if image:
        depth_extract_image(path)
    elif video:
        depth_extract_video(path)
    else:
        print("Nothing to do: pass --image or --video")
if __name__ == "__main__":
    # Parse command line arguments.
    # FIX: the description previously read "Contact Sheet Generator" — a
    # copy-paste leftover from an unrelated script; corrected to describe
    # what this tool actually does.
    parser = argparse.ArgumentParser(description="MiDaS depth map extraction")
    parser.add_argument(
        "--path", type=str, default=None, help="Absolute path to the image/video"
    )
    # store_true flags default to False; image takes precedence in main().
    parser.add_argument(
        "--image",
        help="image depth extraction",
        action="store_true")
    parser.add_argument(
        "--video",
        help="video frame depth extraction",
        action="store_true")
    args = parser.parse_args()
    # Run the main function with the provided arguments
    main(args.video, args.image, args.path)