diff --git a/include/darknet.h b/include/darknet.h index 0a1451e3..fb62fc2a 100644 --- a/include/darknet.h +++ b/include/darknet.h @@ -518,7 +518,7 @@ struct layer { // network.h typedef enum { - CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM + CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM, SGDR } learning_rate_policy; // network.h @@ -534,6 +534,9 @@ typedef struct network { learning_rate_policy policy; float learning_rate; + float learning_rate_min; + float learning_rate_max; + int batches_per_cycle; float momentum; float decay; float gamma; diff --git a/src/network.c b/src/network.c index cfc747cb..32e5e96e 100644 --- a/src/network.c +++ b/src/network.c @@ -117,6 +117,12 @@ float get_current_rate(network net) return net.learning_rate * pow(rand_uniform(0,1), net.power); case SIG: return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step)))); + case SGDR: + rate = net.learning_rate_min + + 0.5*(net.learning_rate_max-net.learning_rate_min) + * (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle)); + + return rate; default: fprintf(stderr, "Policy is weird!\n"); return net.learning_rate; diff --git a/src/parser.c b/src/parser.c index 11cc4bfc..8e3af25f 100644 --- a/src/parser.c +++ b/src/parser.c @@ -629,6 +629,7 @@ learning_rate_policy get_policy(char *s) if (strcmp(s, "exp")==0) return EXP; if (strcmp(s, "sigmoid")==0) return SIG; if (strcmp(s, "steps")==0) return STEPS; + if (strcmp(s, "sgdr")==0) return SGDR; fprintf(stderr, "Couldn't find policy %s, going with constant\n", s); return CONSTANT; } @@ -637,6 +638,9 @@ void parse_net_options(list *options, network *net) { net->batch = option_find_int(options, "batch",1); net->learning_rate = option_find_float(options, "learning_rate", .001); + net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001); + net->learning_rate_max = option_find_float_quiet(options, "learning_rate_max", .001); + net->batches_per_cycle = option_find_int(options, "sgdr_cycle", 500); net->momentum = option_find_float(options, "momentum", .9); net->decay = option_find_float(options, "decay", .0001); int subdivs = option_find_int(options, "subdivisions",1);