mirror of https://github.com/AlexeyAB/darknet.git
add SGDR policy
This commit is contained in:
parent
8bcba6c105
commit
d64693eb77
|
@ -518,7 +518,7 @@ struct layer {
|
||||||
|
|
||||||
// network.h
|
// network.h
|
||||||
typedef enum {
|
typedef enum {
|
||||||
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM
|
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM, SGDR
|
||||||
} learning_rate_policy;
|
} learning_rate_policy;
|
||||||
|
|
||||||
// network.h
|
// network.h
|
||||||
|
@ -534,6 +534,9 @@ typedef struct network {
|
||||||
learning_rate_policy policy;
|
learning_rate_policy policy;
|
||||||
|
|
||||||
float learning_rate;
|
float learning_rate;
|
||||||
|
float learning_rate_min;
|
||||||
|
float learning_rate_max;
|
||||||
|
int batches_per_cycle;
|
||||||
float momentum;
|
float momentum;
|
||||||
float decay;
|
float decay;
|
||||||
float gamma;
|
float gamma;
|
||||||
|
|
|
@ -117,6 +117,12 @@ float get_current_rate(network net)
|
||||||
return net.learning_rate * pow(rand_uniform(0,1), net.power);
|
return net.learning_rate * pow(rand_uniform(0,1), net.power);
|
||||||
case SIG:
|
case SIG:
|
||||||
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
|
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
|
||||||
|
case SGDR:
|
||||||
|
rate = net.learning_rate_min +
|
||||||
|
0.5*(net.learning_rate_max-net.learning_rate_min)
|
||||||
|
* (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle));
|
||||||
|
|
||||||
|
return rate;
|
||||||
default:
|
default:
|
||||||
fprintf(stderr, "Policy is weird!\n");
|
fprintf(stderr, "Policy is weird!\n");
|
||||||
return net.learning_rate;
|
return net.learning_rate;
|
||||||
|
|
|
@ -629,6 +629,7 @@ learning_rate_policy get_policy(char *s)
|
||||||
if (strcmp(s, "exp")==0) return EXP;
|
if (strcmp(s, "exp")==0) return EXP;
|
||||||
if (strcmp(s, "sigmoid")==0) return SIG;
|
if (strcmp(s, "sigmoid")==0) return SIG;
|
||||||
if (strcmp(s, "steps")==0) return STEPS;
|
if (strcmp(s, "steps")==0) return STEPS;
|
||||||
|
if (strcmp(s, "sgdr")==0) return SGDR;
|
||||||
fprintf(stderr, "Couldn't find policy %s, going with constant\n", s);
|
fprintf(stderr, "Couldn't find policy %s, going with constant\n", s);
|
||||||
return CONSTANT;
|
return CONSTANT;
|
||||||
}
|
}
|
||||||
|
@ -637,6 +638,9 @@ void parse_net_options(list *options, network *net)
|
||||||
{
|
{
|
||||||
net->batch = option_find_int(options, "batch",1);
|
net->batch = option_find_int(options, "batch",1);
|
||||||
net->learning_rate = option_find_float(options, "learning_rate", .001);
|
net->learning_rate = option_find_float(options, "learning_rate", .001);
|
||||||
|
net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001);
|
||||||
|
net->learning_rate_max = option_find_float_quiet(options, "learning_rate_max", .001);
|
||||||
|
net->batches_per_cycle = option_find_int(options, "sgdr_cycle", 500);
|
||||||
net->momentum = option_find_float(options, "momentum", .9);
|
net->momentum = option_find_float(options, "momentum", .9);
|
||||||
net->decay = option_find_float(options, "decay", .0001);
|
net->decay = option_find_float(options, "decay", .0001);
|
||||||
int subdivs = option_find_int(options, "subdivisions",1);
|
int subdivs = option_find_int(options, "subdivisions",1);
|
||||||
|
|
Loading…
Reference in New Issue