-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathba_example.py
More file actions
109 lines (85 loc) · 3.08 KB
/
ba_example.py
File metadata and controls
109 lines (85 loc) · 3.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from time import perf_counter
import pypose as pp
import torch
import torch.nn as nn
from pypose.autograd.function import psjac
from datapipes.bal_loader import get_problem
from bae.optim import LM
from bae.utils.pysolvers import PCG
# Which BAL problem `main` fetches via get_problem(TARGET_PROBLEM, TARGET_DATASET).
TARGET_DATASET = "trafalgar"
TARGET_PROBLEM = "problem-257-65132-pre"
# Other dataset/problem pairs that can be swapped in:
#   TARGET_DATASET = "ladybug";   TARGET_PROBLEM = "problem-1723-156502-pre"
#   TARGET_DATASET = "dubrovnik"; TARGET_PROBLEM = "problem-356-226730-pre"

# Device every dataset tensor and the model are moved to.
DEVICE = "cuda"

# How many leading camera-parameter columns are kept as optimisation variables:
# 7 columns consumed as a pp.SE3 pose, plus 3 more (read by `project` as f, k1, k2)
# when intrinsics are optimised.
# NOTE(review): `project` always reads f/k1/k2 from columns [-3], [-2], [-1]; with
# OPTIMIZE_INTRINSICS = False only 7 columns remain, so those reads would hit SE3
# columns instead of intrinsics — confirm before disabling.
OPTIMIZE_INTRINSICS = True
if OPTIMIZE_INTRINSICS:
    NUM_CAMERA_PARAMS = 10
else:
    NUM_CAMERA_PARAMS = 7
@psjac
def project(points, camera_params):
    """Project 3D points through their matching cameras into 2D image coordinates.

    Args:
        points: 3D points, one row per observation (callers pass them already
            gathered by observation index).
        camera_params: matching camera rows. Columns ``[:7]`` are wrapped as a
            ``pp.SE3`` transform; columns ``[-3]``, ``[-2]``, ``[-1]`` are used as
            the scale ``f`` and the polynomial coefficients ``k1``, ``k2``.

    Returns:
        ``p * (1 + k1 * |p|^2 + k2 * |p|^4) * f`` where ``p = -xy / z`` of the
        transformed point; the last dimension has size 2.
    """
    # Transform the points by the SE3 part of the camera parameters.
    projection = pp.SE3(camera_params[..., :7]).Act(points)
    # Negated perspective divide (presumably the BAL camera convention, which
    # looks down -z — confirm against the dataset loader).
    # NOTE(review): no guard for z == 0; such points produce inf/nan.
    projection = -projection[..., :2] / projection[..., [2]]
    # Indexing with a list ([-3] rather than -3) keeps a trailing size-1 dim so
    # these broadcast against the (..., 2) projection.
    f = camera_params[..., [-3]]
    k1 = camera_params[..., [-2]]
    k2 = camera_params[..., [-1]]
    # Squared radius of each projected point; use `dim=` like the rest of the file.
    n = torch.sum(projection**2, dim=-1, keepdim=True)
    r = 1 + k1 * n + k2 * n**2
    return projection * r * f
class Residual(nn.Module):
    """Reprojection-residual model whose camera and point tensors are the
    variables being optimised.

    Args:
        camera_params: per-camera parameter rows, forwarded to `project`.
        points: 3D points, gathered by `pidx` in `forward`.
    """

    def __init__(self, camera_params, points):
        super().__init__()
        # Both tensors become pypose Parameters with sjac=True (presumably
        # opting into the sparse-Jacobian path used with @psjac — confirm).
        self.pose = pp.Parameter(camera_params, sjac=True)
        self.points = pp.Parameter(points, sjac=True)
        # NOTE(review): flag consumed by the pypose/bae stack; its exact effect
        # is defined there, not in this file.
        self.pose.trim_SE3_grad = True

    def forward(self, observes, cidx, pidx):
        """Return projected-minus-observed 2D coordinates, one row per observation."""
        gathered_points = self.points[pidx]
        gathered_cameras = self.pose[cidx]
        return project(gathered_points, gathered_cameras) - observes
def least_square_error(camera_params, points, cidx, pidx, observes):
    """Return the mean over observations of the squared reprojection-residual norm.

    A fresh `Residual` is built from `camera_params` and `points`, evaluated on
    the observations, and each residual row's squared entries are summed over
    the last dimension before averaging.
    """
    residuals = Residual(camera_params, points)(observes, cidx, pidx)
    squared_norms = residuals.pow(2).sum(dim=-1)
    return squared_norms.mean()
def main(num_iterations=20):
    """Optimise the configured BAL problem with bae's LM and report loss and timing.

    Loads TARGET_PROBLEM from TARGET_DATASET, keeps only the tensor entries
    (moved to DEVICE), builds a `Residual` from the first NUM_CAMERA_PARAMS
    camera columns and the 3D points, then runs `num_iterations` LM steps with
    a trust-region strategy and a PCG linear solver.

    Args:
        num_iterations: number of `optimizer.step` calls (default 20, as before).
    """
    dataset = get_problem(TARGET_PROBLEM, TARGET_DATASET)
    print(f"Fetched {TARGET_PROBLEM} from {TARGET_DATASET}")
    # Non-tensor entries are dropped; tensors are moved to the target device.
    dataset = {
        key: value.to(DEVICE)
        for key, value in dataset.items()
        if isinstance(value, torch.Tensor)
    }

    # Keyword arguments for Residual.forward (named `inputs` so the builtin
    # `input` is not shadowed).
    inputs = {
        "observes": dataset["points_2d"],
        "cidx": dataset["camera_index_of_observations"],
        "pidx": dataset["point_index_of_observations"],
    }

    model = Residual(
        dataset["camera_params"][:, :NUM_CAMERA_PARAMS].clone(),
        dataset["points_3d"].clone(),
    ).to(DEVICE)

    strategy = pp.optim.strategy.TrustRegion(up=2.0, down=0.5**4)
    solver = PCG(tol=1e-4, maxiter=250)
    optimizer = LM(model, strategy=strategy, solver=solver, reject=30)

    def report_error(label):
        # Independent mean-squared reprojection error of the current parameters.
        error = least_square_error(
            model.pose,
            model.points,
            inputs["cidx"],
            inputs["pidx"],
            inputs["observes"],
        )
        print(label, error.item())

    report_error('Loss:')
    print("Initial loss", optimizer.model.loss(inputs, None).item())

    start = perf_counter()
    for idx in range(num_iterations):
        loss = optimizer.step(inputs)
        # "time" is cumulative since `start`, not per-iteration.
        print("Iteration", idx, "loss", loss.item(), "time", perf_counter() - start)
    # Wait for queued CUDA work so the total time is accurate; only meaningful
    # (and only safe) when running on a CUDA device.
    if torch.device(DEVICE).type == "cuda":
        torch.cuda.synchronize()
    end = perf_counter()
    print("Time", end - start)

    report_error('Ending loss:')
# Run the example only when executed as a script, not on import.
if __name__ == "__main__":
    main()