GridGenerator update

This commit is contained in:
Baek JeongHun 2019-04-15 09:41:24 +00:00
parent 193686a267
commit cb254fa63a
1 changed files with 6 additions and 8 deletions

View File

@ -89,12 +89,8 @@ class GridGenerator(nn.Module):
self.F = F self.F = F
self.C = self._build_C(self.F) # F x 2 self.C = self._build_C(self.F) # F x 2
self.P = self._build_P(self.I_r_width, self.I_r_height) self.P = self._build_P(self.I_r_width, self.I_r_height)
self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 self.register_buffer("inv_delta_C", torch.tensor(self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3
self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3
# batch_size+1 : +1 sometimes happen with specific batch_ratio and multi-GPU setting.
self.register_buffer("batch_inv_delta_C", self.inv_delta_C.repeat(batch_size + 1, 1, 1))
self.register_buffer("batch_P_hat", self.P_hat.repeat(batch_size + 1, 1, 1))
def _build_C(self, F): def _build_C(self, F):
""" Return coordinates of fiducial points in I_r; C """ """ Return coordinates of fiducial points in I_r; C """
@ -150,8 +146,10 @@ class GridGenerator(nn.Module):
def build_P_prime(self, batch_C_prime): def build_P_prime(self, batch_C_prime):
""" Generate Grid from batch_C_prime [batch_size x F x 2] """ """ Generate Grid from batch_C_prime [batch_size x F x 2] """
batch_size = batch_C_prime.size(0) batch_size = batch_C_prime.size(0)
batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1)
batch_P_hat = self.P_hat.repeat(batch_size, 1, 1)
batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros( batch_C_prime_with_zeros = torch.cat((batch_C_prime, torch.zeros(
batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2 batch_size, 3, 2).float().cuda()), dim=1) # batch_size x F+3 x 2
batch_T = torch.bmm(self.batch_inv_delta_C[:batch_size, :, :], batch_C_prime_with_zeros) # batch_size x F+3 x 2 batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2
batch_P_prime = torch.bmm(self.batch_P_hat[:batch_size, :, :], batch_T) # batch_size x n x 2 batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2
return batch_P_prime # batch_size x n x 2 return batch_P_prime # batch_size x n x 2