From 81d381701ffdd3e32b9d57dc7dd0dab3a8752328 Mon Sep 17 00:00:00 2001 From: Baek JeongHun Date: Fri, 12 Apr 2019 11:57:00 +0000 Subject: [PATCH] fix minor --- modules/transformation.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/transformation.py b/modules/transformation.py index 8c2c054..6cc0e54 100755 --- a/modules/transformation.py +++ b/modules/transformation.py @@ -92,8 +92,9 @@ class GridGenerator(nn.Module): self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 - self.register_buffer("batch_inv_delta_C", self.inv_delta_C.repeat(batch_size, 1, 1)) - self.register_buffer("batch_P_hat", self.P_hat.repeat(batch_size, 1, 1)) + # batch_size+1 : +1 sometimes happen with specific batch_ratio and multi-GPU setting. + self.register_buffer("batch_inv_delta_C", self.inv_delta_C.repeat(batch_size + 1, 1, 1)) + self.register_buffer("batch_P_hat", self.P_hat.repeat(batch_size + 1, 1, 1)) def _build_C(self, F): """ Return coordinates of fiducial points in I_r; C """