// dlib/examples/resnet.h
//
// 107 lines
// 5.4 KiB
// C++

#ifndef ResNet_H
#define ResNet_H
#include <dlib/dnn.h>
namespace resnet
{
using namespace dlib;
// Collection of ResNet building blocks, backbones and classifier types,
// parameterised on the normalisation layer template BN. BN is meant to be
// either bn_con or affine.
template <template <typename> class BN>
struct def
{
    // ---------------------------------------------------------------------
    // Convolutional blocks (read from the innermost SUBNET outwards).
    // ---------------------------------------------------------------------

    // Basic block: con(3x3, stride) -> BN -> relu -> con(3x3, stride 1) -> BN.
    // Only the convolution closest to SUBNET applies the requested stride.
    template <long num_filters, int stride, typename SUBNET>
    using basicblock =
        BN<con<num_filters, 3, 3, 1, 1,
        relu<BN<con<num_filters, 3, 3, stride, stride,
        SUBNET>>>>>;

    // Bottleneck block: con(1x1) -> BN -> relu -> con(3x3, stride) -> BN
    // -> relu -> con(1x1) -> BN. The last convolution has 4 * num_filters
    // filters; the other two have num_filters.
    template <long num_filters, int stride, typename SUBNET>
    using bottleneck =
        BN<con<4 * num_filters, 1, 1, 1, 1,
        relu<BN<con<num_filters, 3, 3, stride, stride,
        relu<BN<con<num_filters, 1, 1, 1, 1,
        SUBNET>>>>>>>>;

    // ---------------------------------------------------------------------
    // Residual wrappers. BLOCK is either basicblock or bottleneck.
    // ---------------------------------------------------------------------

    // Plain residual: SUBNET is tagged with tag1, BLOCK is applied with
    // stride 1, and add_prev1 combines the BLOCK output with the tag1 layer.
    template <template <long, int, typename> class BLOCK, long num_filters, typename SUBNET>
    using residual = add_prev1<BLOCK<num_filters, 1, tag1<SUBNET>>>;

    // Residual that subsamples on both paths: BLOCK is applied with stride 2
    // and tagged with tag2, while the tag1 layer (reached through skip1) goes
    // through a 2x2 avg_pool with stride 2; add_prev2 combines the two.
    template <template <long, int, typename> class BLOCK, long num_filters, typename SUBNET>
    using residual_down =
        add_prev2<
        avg_pool<2, 2, 2, 2,
        skip1<
        tag2<BLOCK<num_filters, 2,
        tag1<SUBNET>>>>>>;

    // A complete residual unit: relu applied on top of RESIDUAL, where
    // RESIDUAL is residual (no downsampling) or residual_down (downsampling).
    template <
        template <template <long, int, typename> class, long, typename> class RESIDUAL,
        template <long, int, typename> class BLOCK,
        long num_filters,
        typename SUBNET>
    using residual_block = relu<RESIDUAL<BLOCK, num_filters, SUBNET>>;

    // Downsampling residual units for each block flavour.
    template <long num_filters, typename SUBNET>
    using resbasicblock_down = residual_block<residual_down, basicblock, num_filters, SUBNET>;

    template <long num_filters, typename SUBNET>
    using resbottleneck_down = residual_block<residual_down, bottleneck, num_filters, SUBNET>;

    // ---------------------------------------------------------------------
    // Single-parameter aliases with the filter count fixed. These exist so
    // the units can be handed to the repeat layer in the backbones below.
    // ---------------------------------------------------------------------

    template <typename SUBNET> using resbasicblock_64  = residual_block<residual, basicblock, 64, SUBNET>;
    template <typename SUBNET> using resbasicblock_128 = residual_block<residual, basicblock, 128, SUBNET>;
    template <typename SUBNET> using resbasicblock_256 = residual_block<residual, basicblock, 256, SUBNET>;
    template <typename SUBNET> using resbasicblock_512 = residual_block<residual, basicblock, 512, SUBNET>;

    template <typename SUBNET> using resbottleneck_64  = residual_block<residual, bottleneck, 64, SUBNET>;
    template <typename SUBNET> using resbottleneck_128 = residual_block<residual, bottleneck, 128, SUBNET>;
    template <typename SUBNET> using resbottleneck_256 = residual_block<residual, bottleneck, 256, SUBNET>;
    template <typename SUBNET> using resbottleneck_512 = residual_block<residual, bottleneck, 512, SUBNET>;

    // ---------------------------------------------------------------------
    // Stem and backbones.
    // ---------------------------------------------------------------------

    // Shared stem for standard resnet inputs:
    // con(64 filters, 7x7, stride 2) -> BN -> relu -> max_pool(3x3, stride 2).
    template <typename INPUT>
    using input_processing =
        max_pool<3, 3, 2, 2,
        relu<BN<con<64, 7, 7, 2, 2,
        INPUT>>>>;

    // Backbone made of basic blocks. Starting from the stem there are four
    // stages (64, 128, 256, 512 filters). nb_N is the number of repeated
    // non-downsampling units in the N-filter stage; the 128, 256 and 512
    // stages each additionally start with one downsampling unit.
    template <long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
    using backbone_basicblock =
        repeat<nb_512, resbasicblock_512, resbasicblock_down<512,
        repeat<nb_256, resbasicblock_256, resbasicblock_down<256,
        repeat<nb_128, resbasicblock_128, resbasicblock_down<128,
        repeat<nb_64, resbasicblock_64,
        input_processing<INPUT>>>>>>>>;

    // Backbone made of bottleneck blocks; same stage layout as
    // backbone_basicblock.
    template <long nb_512, long nb_256, long nb_128, long nb_64, typename INPUT>
    using backbone_bottleneck =
        repeat<nb_512, resbottleneck_512, resbottleneck_down<512,
        repeat<nb_256, resbottleneck_256, resbottleneck_down<256,
        repeat<nb_128, resbottleneck_128, resbottleneck_down<128,
        repeat<nb_64, resbottleneck_64,
        input_processing<INPUT>>>>>>>>;

    // Backbones for the classic architectures. The arguments are the repeat
    // counts for the 512, 256, 128 and 64 filter stages, in that order.
    template <typename INPUT> using backbone_18  = backbone_basicblock<1, 1, 1, 2, INPUT>;
    template <typename INPUT> using backbone_34  = backbone_basicblock<2, 5, 3, 3, INPUT>;
    template <typename INPUT> using backbone_50  = backbone_bottleneck<2, 5, 3, 3, INPUT>;
    template <typename INPUT> using backbone_101 = backbone_bottleneck<2, 22, 3, 3, INPUT>;
    template <typename INPUT> using backbone_152 = backbone_bottleneck<2, 35, 7, 3, INPUT>;

    // ---------------------------------------------------------------------
    // Typical classifier models: backbone on input_rgb_image, followed by
    // avg_pool_everything, a 1000-output fc layer and loss_multiclass_log.
    // ---------------------------------------------------------------------

    using n18  = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_18<input_rgb_image>>>>;
    using n34  = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_34<input_rgb_image>>>>;
    using n50  = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_50<input_rgb_image>>>>;
    using n101 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_101<input_rgb_image>>>>;
    using n152 = loss_multiclass_log<fc<1000, avg_pool_everything<backbone_152<input_rgb_image>>>>;
};
// Ready-to-use classifier types, grouped by depth. The train_* aliases
// instantiate def with bn_con and the infer_* aliases instantiate it with
// affine.
using train_18  = def<bn_con>::n18;
using infer_18  = def<affine>::n18;

using train_34  = def<bn_con>::n34;
using infer_34  = def<affine>::n34;

using train_50  = def<bn_con>::n50;
using infer_50  = def<affine>::n50;

using train_101 = def<bn_con>::n101;
using infer_101 = def<affine>::n101;

using train_152 = def<bn_con>::n152;
using infer_152 = def<affine>::n152;
}
#endif // ResNet_H