diff --git a/examples/dnn_introduction3_ex.cpp b/examples/dnn_introduction3_ex.cpp index 2f4c0140f..3b5b0ece3 100644 --- a/examples/dnn_introduction3_ex.cpp +++ b/examples/dnn_introduction3_ex.cpp @@ -29,7 +29,7 @@ namespace model using net_type = loss_metric< fc_no_bias<128, avg_pool_everything< - typename resnet::template backbone_50< + typename resnet::def::template backbone_50< input_rgb_image >>>>; @@ -75,7 +75,7 @@ int main() try { // Now, let's define the classic ResNet50 network and load the pretrained model on // ImageNet. - resnet::n50 resnet50; + resnet::train_50 resnet50; std::vector labels; deserialize("resnet50_1000_imagenet_classifier.dnn") >> resnet50 >> labels; diff --git a/examples/resnet.h b/examples/resnet.h index 57e98c4ec..61759f5d5 100644 --- a/examples/resnet.h +++ b/examples/resnet.h @@ -3,99 +3,104 @@ #include -// BATCHNORM must be bn_con or affine layer -template class BATCHNORM> -struct resnet +namespace resnet { - // the resnet basic block, where BN is bn_con or affine - template class BN, int stride, typename SUBNET> - using basicblock = BN>>>>; + using namespace dlib; + // BN is bn_con or affine layer + template class BN> + struct def + { + // the resnet basic block, where BN is bn_con or affine + template + using basicblock = BN>>>>; - // the resnet bottleneck block - template class BN, int stride, typename SUBNET> - using bottleneck = BN>>>>>>>; + // the resnet bottleneck block + template + using bottleneck = BN>>>>>>>; - // the resnet residual - template< - template class, int, typename> class BLOCK, // basicblock or bottleneck - long num_filters, - template class BN, // bn_con or affine - typename SUBNET - > // adds the block to the result of tag1 (the subnet) - using residual = dlib::add_prev1>>; + // the resnet residual, where BLOCK is either basicblock or bottleneck + template class BLOCK, long num_filters, typename SUBNET> + using residual = add_prev1>>; - // a resnet residual that does subsampling on both paths - template< - template class, int, typename> class BLOCK, // basicblock or bottleneck - long num_filters, - template class BN, // bn_con or affine - typename SUBNET - > - using residual_down = dlib::add_prev2>>>>>; + // a resnet residual that does subsampling on both paths + template class BLOCK, long num_filters, typename SUBNET> + using residual_down = add_prev2>>>>>; - // residual block with optional downsampling and custom regularization (bn_con or affine) - template< - template class, int, typename> class, long, templateclass, typename> class RESIDUAL, - template class, int, typename> class BLOCK, - long num_filters, - template class BN, // bn_con or affine - typename SUBNET - > - using residual_block = dlib::relu>; + // residual block with optional downsampling + template< + template class, long, typename> class RESIDUAL, + template class BLOCK, + long num_filters, + typename SUBNET + > + using residual_block = relu>; - template - using resbasicblock_down = residual_block; - template - using resbottleneck_down = residual_block; + template + using resbasicblock_down = residual_block; + template + using resbottleneck_down = residual_block; - // some definitions to allow the use of the repeat layer - template using resbasicblock_512 = residual_block; - template using resbasicblock_256 = residual_block; - template using resbasicblock_128 = residual_block; - template using resbasicblock_64 = residual_block; - template using resbottleneck_512 = residual_block; - template using resbottleneck_256 = residual_block; - template using resbottleneck_128 = residual_block; - template using resbottleneck_64 = residual_block; + // some definitions to allow the use of the repeat layer + template using resbasicblock_512 = residual_block; + template using resbasicblock_256 = residual_block; + template using resbasicblock_128 = residual_block; + template using resbasicblock_64 = residual_block; + template using resbottleneck_512 = residual_block; + template using resbottleneck_256 = residual_block; + template using resbottleneck_128 = residual_block; + template using resbottleneck_64 = residual_block; - // common processing for standard resnet inputs - template class BN, typename INPUT> - using input_processing = dlib::max_pool<3, 3, 2, 2, dlib::relu>>>; + // common processing for standard resnet inputs + template + using input_processing = max_pool<3, 3, 2, 2, relu>>>; - // the resnet backbone with basicblocks - template - using backbone_basicblock = - dlib::repeat>>>>>>>; + // the resnet backbone with basicblocks + template + using backbone_basicblock = + repeat>>>>>>>; - // the resnet backbone with bottlenecks - template - using backbone_bottleneck = - dlib::repeat>>>>>>>; + // the resnet backbone with bottlenecks + template + using backbone_bottleneck = + repeat>>>>>>>; - // the backbones for the classic architectures - template using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>; - template using backbone_34 = backbone_basicblock<2, 5, 3, 3, INPUT>; - template using backbone_50 = backbone_bottleneck<2, 5, 3, 3, INPUT>; - template using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>; - template using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>; + // the backbones for the classic architectures + template using backbone_18 = backbone_basicblock<1, 1, 1, 2, INPUT>; + template using backbone_34 = backbone_basicblock<2, 5, 3, 3, INPUT>; + template using backbone_50 = backbone_bottleneck<2, 5, 3, 3, INPUT>; + template using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>; + template using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>; - // the typical classifier models - using n18 = dlib::loss_multiclass_log>>>; - using n34 = dlib::loss_multiclass_log>>>; - using n50 = dlib::loss_multiclass_log>>>; - using n101 = dlib::loss_multiclass_log>>>; - using n152 = dlib::loss_multiclass_log>>>; -}; + // the typical classifier models + using n18 = loss_multiclass_log>>>; + using n34 = loss_multiclass_log>>>; + using n50 = loss_multiclass_log>>>; + using n101 = loss_multiclass_log>>>; + using n152 = loss_multiclass_log>>>; + }; + + using train_18 = def::n18; + using train_34 = def::n34; + using train_50 = def::n50; + using train_101 = def::n101; + using train_152 = def::n152; + + using infer_18 = def::n18; + using infer_34 = def::n34; + using infer_50 = def::n50; + using infer_101 = def::n101; + using infer_152 = def::n152; +} #endif // ResNet_H