make the function signature consistent

This commit is contained in:
任广辉 2019-04-24 16:01:07 +08:00
parent d56bfd94cf
commit c9fa20b72f
4 changed files with 10 additions and 10 deletions

View File

@ -21,11 +21,11 @@ _model_factory = {
'hourglass': get_large_hourglass_net, 'hourglass': get_large_hourglass_net,
} }
def create_model(arch, head, head_conv): def create_model(arch, heads, head_conv):
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
arch = arch[:arch.find('_')] if '_' in arch else arch arch = arch[:arch.find('_')] if '_' in arch else arch
get_model = _model_factory[arch] get_model = _model_factory[arch]
model = get_model(num_layers, head, head_conv) model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
return model return model
def load_model(model, model_path, optimizer=None, resume=False, def load_model(model, model_path, optimizer=None, resume=False,

View File

@ -532,7 +532,7 @@ def fill_fc_weights(layers):
class DLASeg(nn.Module): class DLASeg(nn.Module):
def __init__(self, base_name, heads, def __init__(self, base_name, heads,
pretrained=True, down_ratio=4, add_conv=256): pretrained=True, down_ratio=4, head_conv=256):
super(DLASeg, self).__init__() super(DLASeg, self).__init__()
assert down_ratio in [2, 4, 8, 16] assert down_ratio in [2, 4, 8, 16]
self.heads = heads self.heads = heads
@ -551,12 +551,12 @@ class DLASeg(nn.Module):
for head in self.heads: for head in self.heads:
classes = self.heads[head] classes = self.heads[head]
if add_conv > 0: if head_conv > 0:
fc = nn.Sequential( fc = nn.Sequential(
nn.Conv2d(channels[self.first_level], add_conv, nn.Conv2d(channels[self.first_level], head_conv,
kernel_size=3, padding=1, bias=True), kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True), nn.ReLU(inplace=True),
nn.Conv2d(add_conv, classes, nn.Conv2d(head_conv, classes,
kernel_size=1, stride=1, kernel_size=1, stride=1,
padding=0, bias=True)) padding=0, bias=True))
if 'hm' in head: if 'hm' in head:
@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs):
return model return model
''' '''
def get_pose_net(heads, down_ratio=4, add_conv=256): def get_pose_net(heads, down_ratio=4, head_conv=256):
model = DLASeg('dla34', heads, model = DLASeg('dla34', heads,
pretrained=True, pretrained=True,
down_ratio=down_ratio, down_ratio=down_ratio,
add_conv=add_conv) head_conv=head_conv)
return model return model

View File

@ -295,6 +295,6 @@ class HourglassNet(exkp):
kp_layer=residual, cnv_dim=256 kp_layer=residual, cnv_dim=256
) )
def get_large_hourglass_net(_, heads, __): def get_large_hourglass_net(num_layers, heads, head_conv):
model = HourglassNet(heads, 2) model = HourglassNet(heads, 2)
return model return model

View File

@ -482,7 +482,7 @@ class DLASeg(nn.Module):
return [z] return [z]
def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256): def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
model = DLASeg('dla{}'.format(num_layers), heads, model = DLASeg('dla{}'.format(num_layers), heads,
pretrained=True, pretrained=True,
down_ratio=down_ratio, down_ratio=down_ratio,