make the function signature consistent

This commit is contained in:
任广辉 2019-04-24 16:01:07 +08:00
parent d56bfd94cf
commit c9fa20b72f
4 changed files with 10 additions and 10 deletions

View File

@ -21,11 +21,11 @@ _model_factory = {
'hourglass': get_large_hourglass_net,
}
def create_model(arch, head, head_conv):
def create_model(arch, heads, head_conv):
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
arch = arch[:arch.find('_')] if '_' in arch else arch
get_model = _model_factory[arch]
model = get_model(num_layers, head, head_conv)
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
return model
def load_model(model, model_path, optimizer=None, resume=False,

View File

@ -532,7 +532,7 @@ def fill_fc_weights(layers):
class DLASeg(nn.Module):
def __init__(self, base_name, heads,
pretrained=True, down_ratio=4, add_conv=256):
pretrained=True, down_ratio=4, head_conv=256):
super(DLASeg, self).__init__()
assert down_ratio in [2, 4, 8, 16]
self.heads = heads
@ -551,12 +551,12 @@ class DLASeg(nn.Module):
for head in self.heads:
classes = self.heads[head]
if add_conv > 0:
if head_conv > 0:
fc = nn.Sequential(
nn.Conv2d(channels[self.first_level], add_conv,
nn.Conv2d(channels[self.first_level], head_conv,
kernel_size=3, padding=1, bias=True),
nn.ReLU(inplace=True),
nn.Conv2d(add_conv, classes,
nn.Conv2d(head_conv, classes,
kernel_size=1, stride=1,
padding=0, bias=True))
if 'hm' in head:
@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs):
return model
'''
def get_pose_net(heads, down_ratio=4, add_conv=256):
def get_pose_net(heads, down_ratio=4, head_conv=256):
model = DLASeg('dla34', heads,
pretrained=True,
down_ratio=down_ratio,
add_conv=add_conv)
head_conv=head_conv)
return model

View File

@ -295,6 +295,6 @@ class HourglassNet(exkp):
kp_layer=residual, cnv_dim=256
)
def get_large_hourglass_net(_, heads, __):
def get_large_hourglass_net(num_layers, heads, head_conv):
model = HourglassNet(heads, 2)
return model

View File

@ -482,7 +482,7 @@ class DLASeg(nn.Module):
return [z]
def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256):
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
model = DLASeg('dla{}'.format(num_layers), heads,
pretrained=True,
down_ratio=down_ratio,