# utils.py
# Forked from yashbhalgat/HashNeRF-pytorch.
import json
import numpy as np
import pdb
import torch
from ray_utils import get_rays, get_ray_directions, get_ndc_rays
# Integer offsets of the 8 corners of a unit voxel relative to its minimum
# corner, shaped (1, 8, 3) and ordered 000, 001, 010, ..., 111 (the last
# coordinate varies fastest).
# The tensor lives on the GPU when one is available and falls back to the CPU
# otherwise, so that merely importing this module does not fail on machines
# without CUDA.
BOX_OFFSETS = torch.tensor(
    [[[i, j, k] for i in [0, 1] for j in [0, 1] for k in [0, 1]]],
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
def hash(coords, log2_hashmap_size):
    '''
    Spatial hash of integer grid coordinates.

    coords: integer coordinates; the last dimension holds x, y, z (e.g. B x 3)
    log2_hashmap_size: base-2 logarithm of the hash table size

    Returns one table index per coordinate triple (the last dimension of
    coords is consumed), masked to the low log2_hashmap_size bits.

    NOTE: the name shadows the builtin hash() inside this module; it is kept
    because callers refer to it by this name.
    '''
    # Bit mask that keeps only the low log2_hashmap_size bits.
    table_mask = (1 << log2_hashmap_size) - 1

    # Scale every axis by its own large constant, then combine with XOR.
    mixed_x = coords[..., 0] * 73856093
    mixed_y = coords[..., 1] * 19349663
    mixed_z = coords[..., 2] * 83492791

    return table_mask & (mixed_x ^ mixed_y ^ mixed_z)
def get_bbox3d_for_blenderobj(camera_transforms, H, W, near=2.0, far=6.0):
    '''
    Estimate an axis-aligned bounding box for a Blender-style scene.

    For every camera frame, the four image-corner rays are evaluated at the
    near and far distances; the box is the min/max over all those points,
    padded by 1.0 on every axis.

    camera_transforms: dict with 'camera_angle_x' and a 'frames' list whose
        entries carry a 'transform_matrix' (camera-to-world)
    H, W: image height and width in pixels
    near, far: distances along each ray at which the corner points are taken

    Returns (min_bound, max_bound), two tensors of 3 values each.
    '''
    camera_angle_x = float(camera_transforms['camera_angle_x'])
    focal = 0.5 * W / np.tan(0.5 * camera_angle_x)

    # ray directions in camera coordinates
    directions = get_ray_directions(H, W, focal)

    # Running extremes, seeded with +/-100 sentinels.
    min_bound = [100, 100, 100]
    max_bound = [-100, -100, -100]

    def find_min_max(pt):
        # Widen the running bounds so that they contain pt.
        for axis in range(3):
            if min_bound[axis] > pt[axis]:
                min_bound[axis] = pt[axis]
            if max_bound[axis] < pt[axis]:
                max_bound[axis] = pt[axis]

    # Rays are indexed as a flat list of H*W entries; these four indices are
    # the first/last pixel of the first and last image rows
    # (assumes row-major ordering - TODO confirm against get_rays).
    corner_ids = (0, W - 1, H * W - W, H * W - 1)

    for frame in camera_transforms["frames"]:
        c2w = torch.FloatTensor(frame["transform_matrix"])
        rays_o, rays_d = get_rays(directions, c2w)

        for idx in corner_ids:
            find_min_max(rays_o[idx] + near * rays_d[idx])
            find_min_max(rays_o[idx] + far * rays_d[idx])

    margin = torch.tensor([1.0, 1.0, 1.0])
    return (torch.tensor(min_bound) - margin, torch.tensor(max_bound) + margin)
def get_bbox3d_for_llff(poses, hwf, near=0.0, far=1.0):
    '''
    Estimate an axis-aligned bounding box for an LLFF scene in NDC space.

    For every pose, the four image-corner rays are converted with
    get_ndc_rays and evaluated at the near and far distances; the box is the
    min/max over all those points, padded by (0.1, 0.1, 0.0001).

    poses: camera poses, convertible to a torch.FloatTensor and iterated one
        pose at a time
    hwf: (height, width, focal); height and width are cast to int
    near, far: distances along each NDC ray at which the corner points are taken

    Returns (min_bound, max_bound), two tensors of 3 values each.
    '''
    H, W, focal = hwf
    H, W = int(H), int(W)

    # ray directions in camera coordinates
    directions = get_ray_directions(H, W, focal)

    # Running extremes, seeded with +/-100 sentinels.
    min_bound = [100, 100, 100]
    max_bound = [-100, -100, -100]

    def find_min_max(pt):
        # Widen the running bounds so that they contain pt.
        for axis in range(3):
            if min_bound[axis] > pt[axis]:
                min_bound[axis] = pt[axis]
            if max_bound[axis] < pt[axis]:
                max_bound[axis] = pt[axis]

    # Rays are indexed as a flat list of H*W entries; these four indices are
    # the first/last pixel of the first and last image rows
    # (assumes row-major ordering - TODO confirm against get_rays).
    corner_ids = (0, W - 1, H * W - W, H * W - 1)

    poses = torch.FloatTensor(poses)
    for pose in poses:
        rays_o, rays_d = get_rays(directions, pose)
        # The literal 1.0 is presumably the NDC near plane - verify against
        # get_ndc_rays in ray_utils.
        rays_o, rays_d = get_ndc_rays(H, W, focal, 1.0, rays_o, rays_d)

        for idx in corner_ids:
            find_min_max(rays_o[idx] + near * rays_d[idx])
            find_min_max(rays_o[idx] + far * rays_d[idx])

    margin = torch.tensor([0.1, 0.1, 0.0001])
    return (torch.tensor(min_bound) - margin, torch.tensor(max_bound) + margin)
def get_voxel_vertices(xyz, bounding_box, resolution, log2_hashmap_size):
    '''
    Locate the grid voxel containing each sample and hash its 8 corners.

    xyz: 3D coordinates of samples. B x 3
    bounding_box: min and max x,y,z coordinates of object bbox
    resolution: number of voxels per axis
    log2_hashmap_size: base-2 logarithm of the hash table size, forwarded to hash()

    Returns (voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices):
    the min and max corners (B x 3 each) of the voxel containing every sample,
    and the hashed indices of that voxel's 8 corners (B x 8), ordered
    000,001,010,011,100,101,110,111.
    '''
    box_min, box_max = bounding_box

    # Samples outside the bounding box are clipped onto it. This used to drop
    # into the debugger, which stalls any non-interactive run.
    inside = torch.all(xyz <= box_max) and torch.all(xyz >= box_min)
    if not inside:
        xyz = torch.clamp(xyz, min=box_min, max=box_max)

    # Edge lengths of a single voxel along each axis.
    grid_size = (box_max - box_min) / resolution

    bottom_left_idx = torch.floor((xyz - box_min) / grid_size).int()
    voxel_min_vertex = bottom_left_idx * grid_size + box_min
    # The opposite corner is exactly one voxel away along every axis.
    voxel_max_vertex = voxel_min_vertex + grid_size

    # Keep the corner offsets on the same device as the indices (a no-op when
    # they already match).
    corner_offsets = BOX_OFFSETS.to(bottom_left_idx.device)
    voxel_indices = bottom_left_idx.unsqueeze(1) + corner_offsets
    hashed_voxel_indices = hash(voxel_indices, log2_hashmap_size)

    return voxel_min_vertex, voxel_max_vertex, hashed_voxel_indices
if __name__ == "__main__":
    # Manual check: compute the bounding box of the synthetic "chair" scene
    # from its training camera transforms at 800 x 800 resolution.
    transforms_path = "data/nerf_synthetic/chair/transforms_train.json"
    with open(transforms_path, "r") as transforms_file:
        camera_transforms = json.load(transforms_file)

    bounding_box = get_bbox3d_for_blenderobj(camera_transforms, 800, 800)