Handling breaking change in pytorch grid_sample >=1.4.0 (https://pytorch.org/docs/stable/nn.functional.html?highlight=grid_sample#torch.nn.functional.grid_sample)
This commit is contained in:
parent
24749cd1a3
commit
c529003649
|
@ -4,7 +4,6 @@ import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||||
|
|
||||||
|
|
||||||
class TPS_SpatialTransformerNetwork(nn.Module):
|
class TPS_SpatialTransformerNetwork(nn.Module):
|
||||||
""" Rectification Network of RARE, namely TPS based STN """
|
""" Rectification Network of RARE, namely TPS based STN """
|
||||||
|
|
||||||
|
@ -30,7 +29,11 @@ class TPS_SpatialTransformerNetwork(nn.Module):
|
||||||
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
|
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
|
||||||
build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
|
build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
|
||||||
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
|
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
|
||||||
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
|
|
||||||
|
if torch.__version__ > "1.2.0":
|
||||||
|
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corner=True)
|
||||||
|
else:
|
||||||
|
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
|
||||||
|
|
||||||
return batch_I_r
|
return batch_I_r
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue