GridGenerator update
This commit is contained in:
parent
193686a267
commit
cb254fa63a
|
@ -89,12 +89,8 @@ class GridGenerator(nn.Module):
|
||||||
self.F = F
|
self.F = F
|
||||||
self.C = self._build_C(self.F) # F x 2
|
self.C = self._build_C(self.F) # F x 2
|
||||||
self.P = self._build_P(self.I_r_width, self.I_r_height)
|
self.P = self._build_P(self.I_r_width, self.I_r_height)
|
||||||
self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3
|
self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
|
||||||
self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3
|
self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
|
||||||
|
|
||||||
# batch_size+1 : +1 sometimes happen with specific batch_ratio and multi-GPU setting.
|
|
||||||
self.register_buffer("batch_inv_delta_C", self.inv_delta_C.repeat(batch_size + 1, 1, 1))
|
|
||||||
self.register_buffer("batch_P_hat", self.P_hat.repeat(batch_size + 1, 1, 1))
|
|
||||||
|
|
||||||
def _build_C(self, F):
|
def _build_C(self, F):
|
||||||
""" Return coordinates of fiducial points in I_r; C """
|
""" Return coordinates of fiducial points in I_r; C """
|
||||||
|
@ -150,8 +146,10 @@ class GridGenerator(nn.Module):
|
||||||
def build_P_prime(self, batch_C_prime):
|
def build_P_prime(self, batch_C_prime):
|
||||||
""" Generate Grid from batch_C_prime [batch_size x F x 2] """
|
""" Generate Grid from batch_C_prime [batch_size x F x 2] """
|
||||||
batch_size = batch_C_prime.size(0)
|
batch_size = batch_C_prime.size(0)
|
||||||
|
batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
|
||||||
|
batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
|
||||||
batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
|
batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
|
||||||
batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2
|
batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2
|
||||||
batch_T = torch.bmm(self.batch_inv_delta_C[:batch_size, :, :], batch_C_prime_with_zeros) # batch_size x F+3 x 2
|
batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
|
||||||
batch_P_prime = torch.bmm(self.batch_P_hat[:batch_size, :, :], batch_T) # batch_size x n x 2
|
batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
|
||||||
return batch_P_prime # batch_size x n x 2
|
return batch_P_prime # batch_size x n x 2
|
||||||
|
|
Loading…
Reference in New Issue