This commit is contained in:
Joseph Redmon 2016-10-26 08:35:44 -07:00
parent 91f95c715b
commit 352ae7e65b
11 changed files with 114 additions and 26 deletions

View File

@ -78,6 +78,7 @@ void mult_add_into_gpu(int num, float *a, float *b, float *c);
void reorg_ongpu(float *x, int w, int h, int c, int batch, int stride, int forward, float *out);
void softmax_gpu(float *input, int n, int offset, int groups, float temp, float *output);
void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t);
#endif
#endif

View File

@ -140,6 +140,21 @@ void backward_bias_gpu(float *bias_updates, float *delta, int batch, int n, int
}
__global__ void adam_kernel(int N, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (index >= N) return;
x[index] = x[index] - (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps));
//if(index == 0) printf("%f %f %f %f\n", m[index], v[index], (rate * sqrt(1.-pow(B2, t)) / (1.-pow(B1, t)) * m[index] / (sqrt(v[index]) + eps)));
}
extern "C" void adam_gpu(int n, float *x, float *m, float *v, float B1, float B2, float rate, float eps, int t)
{
adam_kernel<<<cuda_gridsize(n), BLOCK>>>(n, x, m, v, B1, B2, rate, eps, t);
check_error(cudaPeekAtLastError());
}
__global__ void normalize_kernel(int N, float *x, float *mean, float *variance, int batch, int filters, int spatial)
{
int index = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;

View File

@ -41,7 +41,7 @@ list *read_data_cfg(char *filename)
return options;
}
void hierarchy_predictions(float *predictions, int n, tree *hier)
void hierarchy_predictions(float *predictions, int n, tree *hier, int only_leaves)
{
int j;
for(j = 0; j < n; ++j){
@ -50,8 +50,10 @@ void hierarchy_predictions(float *predictions, int n, tree *hier)
predictions[j] *= predictions[parent];
}
}
for(j = 0; j < n; ++j){
if(!hier->leaf[j]) predictions[j] = 0;
if(only_leaves){
for(j = 0; j < n; ++j){
if(!hier->leaf[j]) predictions[j] = 0;
}
}
}
@ -410,7 +412,7 @@ void validate_classifier_10(char *datacfg, char *filename, char *weightfile)
float *pred = calloc(classes, sizeof(float));
for(j = 0; j < 10; ++j){
float *p = network_predict(net, images[j].data);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
axpy_cpu(classes, 1, p, 1, pred, 1);
free_image(images[j]);
}
@ -471,7 +473,7 @@ void validate_classifier_full(char *datacfg, char *filename, char *weightfile)
//show_image(crop, "cropped");
//cvWaitKey(0);
float *pred = network_predict(net, resized.data);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy, 1);
free_image(im);
free_image(resized);
@ -486,6 +488,26 @@ void validate_classifier_full(char *datacfg, char *filename, char *weightfile)
}
}
void change_leaves(tree *t, char *leaf_list)
{
list *llist = get_paths(leaf_list);
char **leaves = (char **)list_to_array(llist);
int n = llist->size;
int i,j;
int found = 0;
for(i = 0; i < t->n; ++i){
t->leaf[i] = 0;
for(j = 0; j < n; ++j){
if (0==strcmp(t->name[i], leaves[j])){
t->leaf[i] = 1;
++found;
break;
}
}
}
fprintf(stderr, "Found %d leaves.\n", found);
}
void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
{
@ -500,6 +522,8 @@ void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
list *options = read_data_cfg(datacfg);
char *label_list = option_find_str(options, "labels", "data/labels.list");
char *leaf_list = option_find_str(options, "leaves", 0);
if(leaf_list) change_leaves(net.hierarchy, leaf_list);
char *valid_list = option_find_str(options, "valid", "data/train.list");
int classes = option_find_int(options, "classes", 2);
int topk = option_find_int(options, "top", 1);
@ -531,7 +555,7 @@ void validate_classifier_single(char *datacfg, char *filename, char *weightfile)
//show_image(crop, "cropped");
//cvWaitKey(0);
float *pred = network_predict(net, crop.data);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(pred, net.outputs, net.hierarchy, 1);
if(resized.data != im.data) free_image(resized);
free_image(im);
@ -592,7 +616,7 @@ void validate_classifier_multi(char *datacfg, char *filename, char *weightfile)
image r = resize_min(im, scales[j]);
resize_network(&net, r.w, r.h);
float *p = network_predict(net, r.data);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(p, net.outputs, net.hierarchy, 1);
axpy_cpu(classes, 1, p, 1, pred, 1);
flip_image(r);
p = network_predict(net, r.data);
@ -692,7 +716,7 @@ void try_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filena
}
}
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename)
void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *filename, int top)
{
network net = parse_network_cfg(cfgfile);
if(weightfile){
@ -705,7 +729,7 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
char *name_list = option_find_str(options, "names", 0);
if(!name_list) name_list = option_find_str(options, "labels", "data/labels.list");
int top = option_find_int(options, "top", 1);
if(top == 0) top = option_find_int(options, "top", 1);
int i = 0;
char **names = get_labels(name_list);
@ -732,7 +756,7 @@ void predict_classifier(char *datacfg, char *cfgfile, char *weightfile, char *fi
float *X = r.data;
time=clock();
float *predictions = network_predict(net, X);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 0);
top_k(predictions, net.outputs, top, indexes);
printf("%s: Predicted in %f seconds.\n", input, sec(clock()-time));
for(i = 0; i < top; ++i){
@ -1113,7 +1137,7 @@ void demo_classifier(char *datacfg, char *cfgfile, char *weightfile, int cam_ind
show_image(in, "Classifier");
float *predictions = network_predict(net, in_s.data);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy);
if(net.hierarchy) hierarchy_predictions(predictions, net.outputs, net.hierarchy, 1);
top_predictions(net, top, indexes);
printf("\033[2J");
@ -1165,6 +1189,7 @@ void run_classifier(int argc, char **argv)
}
int cam_index = find_int_arg(argc, argv, "-c", 0);
int top = find_int_arg(argc, argv, "-t", 0);
int clear = find_arg(argc, argv, "-clear");
char *data = argv[3];
char *cfg = argv[4];
@ -1172,7 +1197,7 @@ void run_classifier(int argc, char **argv)
char *filename = (argc > 6) ? argv[6]: 0;
char *layer_s = (argc > 7) ? argv[7]: 0;
int layer = layer_s ? atoi(layer_s) : -1;
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename);
if(0==strcmp(argv[2], "predict")) predict_classifier(data, cfg, weights, filename, top);
else if(0==strcmp(argv[2], "try")) try_classifier(data, cfg, weights, filename, atoi(layer_s));
else if(0==strcmp(argv[2], "train")) train_classifier(data, cfg, weights, clear);
else if(0==strcmp(argv[2], "trainm")) train_classifier_multi(data, cfg, weights, gpus, ngpus, clear);

View File

@ -233,7 +233,6 @@ void push_convolutional_layer(convolutional_layer layer)
void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float learning_rate, float momentum, float decay)
{
int size = layer.size*layer.size*layer.c*layer.n;
axpy_ongpu(layer.n, learning_rate/batch, layer.bias_updates_gpu, 1, layer.biases_gpu, 1);
scal_ongpu(layer.n, momentum, layer.bias_updates_gpu, 1);
@ -242,9 +241,23 @@ void update_convolutional_layer_gpu(convolutional_layer layer, int batch, float
scal_ongpu(layer.n, momentum, layer.scale_updates_gpu, 1);
}
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
if(layer.adam){
scal_ongpu(size, layer.B1, layer.m_gpu, 1);
scal_ongpu(size, layer.B2, layer.v_gpu, 1);
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, -(1-layer.B1), layer.weight_updates_gpu, 1, layer.m_gpu, 1);
mul_ongpu(size, layer.weight_updates_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, (1-layer.B2), layer.weight_updates_gpu, 1, layer.v_gpu, 1);
adam_gpu(size, layer.weights_gpu, layer.m_gpu, layer.v_gpu, layer.B1, layer.B2, learning_rate/batch, layer.eps, layer.t+1);
fill_ongpu(size, 0, layer.weight_updates_gpu, 1);
}else{
axpy_ongpu(size, -decay*batch, layer.weights_gpu, 1, layer.weight_updates_gpu, 1);
axpy_ongpu(size, learning_rate/batch, layer.weight_updates_gpu, 1, layer.weights_gpu, 1);
scal_ongpu(size, momentum, layer.weight_updates_gpu, 1);
}
}

View File

@ -171,7 +171,7 @@ void cudnn_convolutional_setup(layer *l)
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor)
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam)
{
int i;
convolutional_layer l = {0};
@ -242,6 +242,12 @@ convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int
l.update_gpu = update_convolutional_layer_gpu;
if(gpu_index >= 0){
if (adam) {
l.adam = 1;
l.m_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
l.v_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
}
l.weights_gpu = cuda_make_array(l.weights, c*n*size*size);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, c*n*size*size);
@ -312,7 +318,7 @@ void denormalize_convolutional_layer(convolutional_layer l)
void test_convolutional_layer()
{
convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0);
convolutional_layer l = make_convolutional_layer(1, 5, 5, 3, 2, 5, 2, 1, LEAKY, 1, 0, 0, 0);
l.batch_normalize = 1;
float data[] = {1,1,1,1,1,
1,1,1,1,1,

View File

@ -24,7 +24,7 @@ void cudnn_convolutional_setup(layer *l);
#endif
#endif
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int pad, ACTIVATION activation, int batch_normalization, int binary, int xnor);
convolutional_layer make_convolutional_layer(int batch, int h, int w, int c, int n, int size, int stride, int padding, ACTIVATION activation, int batch_normalize, int binary, int xnor, int adam);
void denormalize_convolutional_layer(convolutional_layer l);
void resize_convolutional_layer(convolutional_layer *layer, int w, int h);
void forward_convolutional_layer(const convolutional_layer layer, network_state state);

View File

@ -48,17 +48,17 @@ layer make_crnn_layer(int batch, int h, int w, int c, int hidden_filters, int ou
l.input_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0);
*(l.input_layer) = make_convolutional_layer(batch*steps, h, w, c, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
l.input_layer->batch = batch;
l.self_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0);
*(l.self_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, hidden_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
l.self_layer->batch = batch;
l.output_layer = malloc(sizeof(layer));
fprintf(stderr, "\t\t");
*(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0);
*(l.output_layer) = make_convolutional_layer(batch*steps, h, w, hidden_filters, output_filters, 3, 1, 1, activation, batch_normalize, 0, 0, 0);
l.output_layer->batch = batch;
l.output = l.output_layer->output;

View File

@ -94,6 +94,14 @@ struct layer{
int reorg;
int log;
int adam;
float B1;
float B2;
float eps;
float *m_gpu;
float *v_gpu;
int t;
tree *softmax_tree;
float alpha;

View File

@ -37,6 +37,11 @@ typedef struct network{
int num_steps;
int burn_in;
int adam;
float B1;
float B2;
float eps;
int inputs;
int h, w, c;
int max_crop;

View File

@ -83,6 +83,7 @@ void update_network_gpu(network net)
float rate = get_current_rate(net);
for(i = 0; i < net.n; ++i){
layer l = net.layers[i];
l.t = get_current_batch(net);
if(l.update_gpu){
l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
}
@ -134,7 +135,6 @@ void *train_thread(void *ptr)
free(ptr);
cuda_set_device(args.net.gpu_index);
*args.err = train_network(args.net, args.d);
printf("%d\n", args.net.gpu_index);
return 0;
}
@ -177,6 +177,7 @@ void update_layer(layer l, network net)
{
int update_batch = net.batch*net.subdivisions;
float rate = get_current_rate(net);
l.t = get_current_batch(net);
if(l.update_gpu){
l.update_gpu(l, update_batch, rate, net.momentum, net.decay);
}

View File

@ -111,6 +111,7 @@ typedef struct size_params{
int c;
int index;
int time_steps;
network net;
} size_params;
local_layer parse_local(list *options, size_params params)
@ -156,9 +157,14 @@ convolutional_layer parse_convolutional(list *options, size_params params)
int binary = option_find_int_quiet(options, "binary", 0);
int xnor = option_find_int_quiet(options, "xnor", 0);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor);
convolutional_layer layer = make_convolutional_layer(batch,h,w,c,n,size,stride,padding,activation, batch_normalize, binary, xnor, params.net.adam);
layer.flipped = option_find_int_quiet(options, "flipped", 0);
layer.dot = option_find_float_quiet(options, "dot", 0);
if(params.net.adam){
layer.B1 = params.net.B1;
layer.B2 = params.net.B2;
layer.eps = params.net.eps;
}
return layer;
}
@ -482,6 +488,13 @@ void parse_net_options(list *options, network *net)
net->batch *= net->time_steps;
net->subdivisions = subdivs;
net->adam = option_find_int_quiet(options, "adam", 0);
if(net->adam){
net->B1 = option_find_float(options, "B1", .9);
net->B2 = option_find_float(options, "B2", .999);
net->eps = option_find_float(options, "eps", .000001);
}
net->h = option_find_int_quiet(options, "height",0);
net->w = option_find_int_quiet(options, "width",0);
net->c = option_find_int_quiet(options, "channels",0);
@ -564,6 +577,7 @@ network parse_network_cfg(char *filename)
params.inputs = net.inputs;
params.batch = net.batch;
params.time_steps = net.time_steps;
params.net = net;
size_t workspace_size = 0;
n = n->next;
@ -779,7 +793,7 @@ void save_weights_upto(network net, char *filename, int cutoff)
{
#ifdef GPU
if(net.gpu_index >= 0){
cuda_set_device(net.gpu_index);
cuda_set_device(net.gpu_index);
}
#endif
fprintf(stderr, "Saving weights to %s\n", filename);
@ -947,7 +961,7 @@ void load_weights_upto(network *net, char *filename, int cutoff)
{
#ifdef GPU
if(net->gpu_index >= 0){
cuda_set_device(net->gpu_index);
cuda_set_device(net->gpu_index);
}
#endif
fprintf(stderr, "Loading weights from %s...", filename);