diff --git a/modules/transformation.py b/modules/transformation.py index 7c4dd5a..c43f173 100755 --- a/modules/transformation.py +++ b/modules/transformation.py @@ -4,7 +4,6 @@ import torch.nn as nn import torch.nn.functional as F device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') - class TPS_SpatialTransformerNetwork(nn.Module): """ Rectification Network of RARE, namely TPS based STN """ @@ -30,7 +29,11 @@ class TPS_SpatialTransformerNetwork(nn.Module): batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) - batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corner=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') return batch_I_r