Update transformation.py

This commit is contained in:
Baek JeongHun 2019-09-27 18:06:16 +09:00 committed by GitHub
parent fc62d2232d
commit 235095fc9a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 0 deletions

View File

@ -90,8 +90,12 @@ class GridGenerator(nn.Module):
self.F = F self.F = F
self.C = self._build_C(self.F) # F x 2 self.C = self._build_C(self.F) # F x 2
self.P = self._build_P(self.I_r_width, self.I_r_height) self.P = self._build_P(self.I_r_width, self.I_r_height)
## for multi-gpu, you need register buffer
self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
## for fine-tuning with different image width, you may use below instead of self.register_buffer
#self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float().cuda() # F+3 x F+3
#self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float().cuda() # n x F+3
def _build_C(self, F): def _build_C(self, F):
""" Return coordinates of fiducial points in I_r; C """ """ Return coordinates of fiducial points in I_r; C """