This commit is contained in:
akarazniewicz 2020-01-16 14:58:59 +01:00
parent 24749cd1a3
commit c529003649
1 changed files with 5 additions and 2 deletions

View File

@ -4,7 +4,6 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class TPS_SpatialTransformerNetwork(nn.Module): class TPS_SpatialTransformerNetwork(nn.Module):
""" Rectification Network of RARE, namely TPS based STN """ """ Rectification Network of RARE, namely TPS based STN """
@ -30,6 +29,10 @@ class TPS_SpatialTransformerNetwork(nn.Module):
batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2
build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2 build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) # batch_size x n (= I_r_width x I_r_height) x 2
build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2])
if torch.__version__ > "1.2.0":
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corner=True)
else:
batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border')
return batch_I_r return batch_I_r