make the function signature consistent
This commit is contained in:
parent
d56bfd94cf
commit
c9fa20b72f
|
@ -21,11 +21,11 @@ _model_factory = {
|
||||||
'hourglass': get_large_hourglass_net,
|
'hourglass': get_large_hourglass_net,
|
||||||
}
|
}
|
||||||
|
|
||||||
def create_model(arch, head, head_conv):
|
def create_model(arch, heads, head_conv):
|
||||||
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
|
num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0
|
||||||
arch = arch[:arch.find('_')] if '_' in arch else arch
|
arch = arch[:arch.find('_')] if '_' in arch else arch
|
||||||
get_model = _model_factory[arch]
|
get_model = _model_factory[arch]
|
||||||
model = get_model(num_layers, head, head_conv)
|
model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def load_model(model, model_path, optimizer=None, resume=False,
|
def load_model(model, model_path, optimizer=None, resume=False,
|
||||||
|
|
|
@ -532,7 +532,7 @@ def fill_fc_weights(layers):
|
||||||
|
|
||||||
class DLASeg(nn.Module):
|
class DLASeg(nn.Module):
|
||||||
def __init__(self, base_name, heads,
|
def __init__(self, base_name, heads,
|
||||||
pretrained=True, down_ratio=4, add_conv=256):
|
pretrained=True, down_ratio=4, head_conv=256):
|
||||||
super(DLASeg, self).__init__()
|
super(DLASeg, self).__init__()
|
||||||
assert down_ratio in [2, 4, 8, 16]
|
assert down_ratio in [2, 4, 8, 16]
|
||||||
self.heads = heads
|
self.heads = heads
|
||||||
|
@ -551,12 +551,12 @@ class DLASeg(nn.Module):
|
||||||
|
|
||||||
for head in self.heads:
|
for head in self.heads:
|
||||||
classes = self.heads[head]
|
classes = self.heads[head]
|
||||||
if add_conv > 0:
|
if head_conv > 0:
|
||||||
fc = nn.Sequential(
|
fc = nn.Sequential(
|
||||||
nn.Conv2d(channels[self.first_level], add_conv,
|
nn.Conv2d(channels[self.first_level], head_conv,
|
||||||
kernel_size=3, padding=1, bias=True),
|
kernel_size=3, padding=1, bias=True),
|
||||||
nn.ReLU(inplace=True),
|
nn.ReLU(inplace=True),
|
||||||
nn.Conv2d(add_conv, classes,
|
nn.Conv2d(head_conv, classes,
|
||||||
kernel_size=1, stride=1,
|
kernel_size=1, stride=1,
|
||||||
padding=0, bias=True))
|
padding=0, bias=True))
|
||||||
if 'hm' in head:
|
if 'hm' in head:
|
||||||
|
@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs):
|
||||||
return model
|
return model
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def get_pose_net(heads, down_ratio=4, add_conv=256):
|
def get_pose_net(heads, down_ratio=4, head_conv=256):
|
||||||
model = DLASeg('dla34', heads,
|
model = DLASeg('dla34', heads,
|
||||||
pretrained=True,
|
pretrained=True,
|
||||||
down_ratio=down_ratio,
|
down_ratio=down_ratio,
|
||||||
add_conv=add_conv)
|
head_conv=head_conv)
|
||||||
return model
|
return model
|
|
@ -295,6 +295,6 @@ class HourglassNet(exkp):
|
||||||
kp_layer=residual, cnv_dim=256
|
kp_layer=residual, cnv_dim=256
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_large_hourglass_net(_, heads, __):
|
def get_large_hourglass_net(num_layers, heads, head_conv):
|
||||||
model = HourglassNet(heads, 2)
|
model = HourglassNet(heads, 2)
|
||||||
return model
|
return model
|
||||||
|
|
|
@ -482,7 +482,7 @@ class DLASeg(nn.Module):
|
||||||
return [z]
|
return [z]
|
||||||
|
|
||||||
|
|
||||||
def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256):
|
def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4):
|
||||||
model = DLASeg('dla{}'.format(num_layers), heads,
|
model = DLASeg('dla{}'.format(num_layers), heads,
|
||||||
pretrained=True,
|
pretrained=True,
|
||||||
down_ratio=down_ratio,
|
down_ratio=down_ratio,
|
||||||
|
|
Loading…
Reference in New Issue