make the function signature consistent
This commit is contained in:
parent
d56bfd94cf
commit
c9fa20b72f
|
@ -21,11 +21,11 @@ _model_factory = {
|
|||
'hourglass': get_large_hourglass_net,
|
||||
}
|
||||
|
||||
def create_model(arch, head, head_conv):
|
||||
def create_model(arch, heads, head_conv):
|
||||
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
|
||||
arch = arch[:arch.find('_')] if '_' in arch else arch
|
||||
get_model = _model_factory[arch]
|
||||
model = get_model(num_layers, head, head_conv)
|
||||
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
|
||||
return model
|
||||
|
||||
def load_model(model, model_path, optimizer=None, resume=False,
|
||||
|
|
|
@ -532,7 +532,7 @@ def fill_fc_weights(layers):
|
|||
|
||||
class DLASeg(nn.Module):
|
||||
def __init__(self, base_name, heads,
|
||||
pretrained=True, down_ratio=4, add_conv=256):
|
||||
pretrained=True, down_ratio=4, head_conv=256):
|
||||
super(DLASeg, self).__init__()
|
||||
assert down_ratio in [2, 4, 8, 16]
|
||||
self.heads = heads
|
||||
|
@ -551,12 +551,12 @@ class DLASeg(nn.Module):
|
|||
|
||||
for head in self.heads:
|
||||
classes = self.heads[head]
|
||||
if add_conv > 0:
|
||||
if head_conv > 0:
|
||||
fc = nn.Sequential(
|
||||
nn.Conv2d(channels[self.first_level], add_conv,
|
||||
nn.Conv2d(channels[self.first_level], head_conv,
|
||||
kernel_size=3, padding=1, bias=True),
|
||||
nn.ReLU(inplace=True),
|
||||
nn.Conv2d(add_conv, classes,
|
||||
nn.Conv2d(head_conv, classes,
|
||||
kernel_size=1, stride=1,
|
||||
padding=0, bias=True))
|
||||
if 'hm' in head:
|
||||
|
@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs):
|
|||
return model
|
||||
'''
|
||||
|
||||
def get_pose_net(heads, down_ratio=4, add_conv=256):
|
||||
def get_pose_net(heads, down_ratio=4, head_conv=256):
|
||||
model = DLASeg('dla34', heads,
|
||||
pretrained=True,
|
||||
down_ratio=down_ratio,
|
||||
add_conv=add_conv)
|
||||
head_conv=head_conv)
|
||||
return model
|
|
@ -295,6 +295,6 @@ class HourglassNet(exkp):
|
|||
kp_layer=residual, cnv_dim=256
|
||||
)
|
||||
|
||||
def get_large_hourglass_net(_, heads, __):
|
||||
def get_large_hourglass_net(num_layers, heads, head_conv):
|
||||
model = HourglassNet(heads, 2)
|
||||
return model
|
||||
|
|
|
@ -482,7 +482,7 @@ class DLASeg(nn.Module):
|
|||
return [z]
|
||||
|
||||
|
||||
def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256):
|
||||
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
|
||||
model = DLASeg('dla{}'.format(num_layers), heads,
|
||||
pretrained=True,
|
||||
down_ratio=down_ratio,
|
||||
|
|
Loading…
Reference in New Issue