diff --git a/src/lib/models/model.py b/src/lib/models/model.py index 5fcc227..d11ad8f 100644 --- a/src/lib/models/model.py +++ b/src/lib/models/model.py @@ -21,11 +21,11 @@ _model_factory = { 'hourglass': get_large_hourglass_net, } -def create_model(arch, head, head_conv): +def create_model(arch, heads, head_conv): num_layers = int(arch[arch.find('_') + 1:]) if '_' in arch else 0 arch = arch[:arch.find('_')] if '_' in arch else arch get_model = _model_factory[arch] - model = get_model(num_layers, head, head_conv) + model = get_model(num_layers=num_layers, heads=heads, head_conv=head_conv) return model def load_model(model, model_path, optimizer=None, resume=False, diff --git a/src/lib/models/networks/dlav0.py b/src/lib/models/networks/dlav0.py index 0f8743b..e7ca539 100644 --- a/src/lib/models/networks/dlav0.py +++ b/src/lib/models/networks/dlav0.py @@ -532,7 +532,7 @@ def fill_fc_weights(layers): class DLASeg(nn.Module): def __init__(self, base_name, heads, - pretrained=True, down_ratio=4, add_conv=256): + pretrained=True, down_ratio=4, head_conv=256): super(DLASeg, self).__init__() assert down_ratio in [2, 4, 8, 16] self.heads = heads @@ -551,12 +551,12 @@ class DLASeg(nn.Module): for head in self.heads: classes = self.heads[head] - if add_conv > 0: + if head_conv > 0: fc = nn.Sequential( - nn.Conv2d(channels[self.first_level], add_conv, + nn.Conv2d(channels[self.first_level], head_conv, kernel_size=3, padding=1, bias=True), nn.ReLU(inplace=True), - nn.Conv2d(add_conv, classes, + nn.Conv2d(head_conv, classes, kernel_size=1, stride=1, padding=0, bias=True)) if 'hm' in head: @@ -639,9 +639,9 @@ def dla169up(classes, pretrained_base=None, **kwargs): return model ''' -def get_pose_net(heads, down_ratio=4, add_conv=256): +def get_pose_net(heads, down_ratio=4, head_conv=256): model = DLASeg('dla34', heads, pretrained=True, down_ratio=down_ratio, - add_conv=add_conv) + head_conv=head_conv) return model \ No newline at end of file diff --git a/src/lib/models/networks/large_hourglass.py b/src/lib/models/networks/large_hourglass.py index 0a13789..b40ba72 100644 --- a/src/lib/models/networks/large_hourglass.py +++ b/src/lib/models/networks/large_hourglass.py @@ -295,6 +295,6 @@ class HourglassNet(exkp): kp_layer=residual, cnv_dim=256 ) -def get_large_hourglass_net(_, heads, __): +def get_large_hourglass_net(num_layers, heads, head_conv): model = HourglassNet(heads, 2) return model diff --git a/src/lib/models/networks/pose_dla_dcn.py b/src/lib/models/networks/pose_dla_dcn.py index b2bff1a..7cb6869 100644 --- a/src/lib/models/networks/pose_dla_dcn.py +++ b/src/lib/models/networks/pose_dla_dcn.py @@ -482,7 +482,7 @@ class DLASeg(nn.Module): return [z] -def get_pose_net(num_layers, heads, version, down_ratio=4, head_conv=256): +def get_pose_net(num_layers, heads, head_conv=256, down_ratio=4): model = DLASeg('dla{}'.format(num_layers), heads, pretrained=True, down_ratio=down_ratio,