Merge pull request #2651 from jveitchmichaelis/sgdr

Implement stochastic gradient descent with warm restarts
This commit is contained in:
Alexey 2019-03-30 14:25:45 +03:00 committed by GitHub
commit 6231b748c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 1 deletion

View File

@@ -518,7 +518,7 @@ struct layer {
// network.h // network.h
typedef enum { typedef enum {
CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM CONSTANT, STEP, EXP, POLY, STEPS, SIG, RANDOM, SGDR
} learning_rate_policy; } learning_rate_policy;
// network.h // network.h
@@ -534,6 +534,9 @@ typedef struct network {
learning_rate_policy policy; learning_rate_policy policy;
float learning_rate; float learning_rate;
float learning_rate_min;
float learning_rate_max;
int batches_per_cycle;
float momentum; float momentum;
float decay; float decay;
float gamma; float gamma;

View File

@@ -117,6 +117,12 @@ float get_current_rate(network net)
return net.learning_rate * pow(rand_uniform(0,1), net.power); return net.learning_rate * pow(rand_uniform(0,1), net.power);
case SIG: case SIG:
return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step)))); return net.learning_rate * (1./(1.+exp(net.gamma*(batch_num - net.step))));
case SGDR:
rate = net.learning_rate_min +
0.5*(net.learning_rate-net.learning_rate_min)
* (1. + cos( (float) (batch_num % net.batches_per_cycle)*3.14159265 / net.batches_per_cycle));
return rate;
default: default:
fprintf(stderr, "Policy is weird!\n"); fprintf(stderr, "Policy is weird!\n");
return net.learning_rate; return net.learning_rate;

View File

@@ -629,6 +629,7 @@ learning_rate_policy get_policy(char *s)
if (strcmp(s, "exp")==0) return EXP; if (strcmp(s, "exp")==0) return EXP;
if (strcmp(s, "sigmoid")==0) return SIG; if (strcmp(s, "sigmoid")==0) return SIG;
if (strcmp(s, "steps")==0) return STEPS; if (strcmp(s, "steps")==0) return STEPS;
if (strcmp(s, "sgdr")==0) return SGDR;
fprintf(stderr, "Couldn't find policy %s, going with constant\n", s); fprintf(stderr, "Couldn't find policy %s, going with constant\n", s);
return CONSTANT; return CONSTANT;
} }
@@ -637,6 +638,8 @@ void parse_net_options(list *options, network *net)
{ {
net->batch = option_find_int(options, "batch",1); net->batch = option_find_int(options, "batch",1);
net->learning_rate = option_find_float(options, "learning_rate", .001); net->learning_rate = option_find_float(options, "learning_rate", .001);
net->learning_rate_min = option_find_float_quiet(options, "learning_rate_min", .00001);
net->batches_per_cycle = option_find_int_quiet(options, "sgdr_cycle", 500);
net->momentum = option_find_float(options, "momentum", .9); net->momentum = option_find_float(options, "momentum", .9);
net->decay = option_find_float(options, "decay", .0001); net->decay = option_find_float(options, "decay", .0001);
int subdivs = option_find_int(options, "subdivisions",1); int subdivs = option_find_int(options, "subdivisions",1);