Need to fix line reads

This commit is contained in:
Joseph Redmon 2014-12-28 09:42:35 -08:00
parent 4ab366a805
commit f26da0ad5c
6 changed files with 59 additions and 22 deletions

View File

@ -84,11 +84,15 @@ void train_detection_net()
list *plist = get_paths("/home/pjreddie/data/imagenet/horse.txt");
char **paths = (char **)list_to_array(plist);
printf("%d\n", plist->size);
data train, buffer;
pthread_t load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer);
clock_t time;
while(1){
i += 1;
time=clock();
data train = load_data_detection_jitter_random(imgs, paths, plist->size, 256, 256, 7, 7, 256);
pthread_join(load_thread, 0);
train = buffer;
load_thread = load_data_detection_thread(imgs, paths, plist->size, 256, 256, 7, 7, 256, &buffer);
//data train = load_data_detection_random(imgs, paths, plist->size, 224, 224, 7, 7, 256);
/*
@ -102,7 +106,7 @@ void train_detection_net()
float loss = train_network(net, train);
avg_loss = avg_loss*.9 + loss*.1;
printf("%d: %f, %f avg, %lf seconds, %d images\n", i, loss, avg_loss, sec(clock()-time), i*imgs*net.batch);
if(i%10==0){
if(i%100==0){
char buff[256];
sprintf(buff, "/home/pjreddie/imagenet_backup/detnet_%d.cfg", i);
save_network(net, buff);
@ -155,10 +159,10 @@ void train_imagenet(char *cfgfile)
//network net = parse_network_cfg("/home/pjreddie/imagenet_backup/alexnet_1270.cfg");
srand(time(0));
network net = parse_network_cfg(cfgfile);
set_learning_network(&net, net.learning_rate/10., .5, .0005);
//set_learning_network(&net, net.learning_rate, 0, .0005);
printf("Learning Rate: %g, Momentum: %g, Decay: %g\n", net.learning_rate, net.momentum, net.decay);
int imgs = 1024;
int i = 44700;
int i = 47900;
char **labels = get_labels("/home/pjreddie/data/imagenet/cls.labels.list");
list *plist = get_paths("/data/imagenet/cls.train.list");
char **paths = (char **)list_to_array(plist);

View File

@ -6,6 +6,20 @@
#include <stdlib.h>
#include <string.h>
struct load_args{
char **paths;
int n;
int m;
char **labels;
int k;
int h;
int w;
int nh;
int nw;
float scale;
data *d;
};
list *get_paths(char *filename)
{
char *path;
@ -165,11 +179,36 @@ data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w,
jitter_image(a,224,224,dy,dx);
}
d.X.cols = 224*224*3;
// print_matrix(d.y);
free(random_paths);
return d;
}
void *load_detection_thread(void *ptr)
{
struct load_args a = *(struct load_args*)ptr;
*a.d = load_data_detection_jitter_random(a.n, a.paths, a.m, a.h, a.w, a.nh, a.nw, a.scale);
free(ptr);
return 0;
}
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d)
{
pthread_t thread;
struct load_args *args = calloc(1, sizeof(struct load_args));
args->n = n;
args->paths = paths;
args->m = m;
args->h = h;
args->w = w;
args->nh = nh;
args->nw = nw;
args->scale = scale;
args->d = d;
if(pthread_create(&thread, 0, load_detection_thread, args)) {
error("Thread creation failed");
}
return thread;
}
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale)
{
@ -193,21 +232,11 @@ data load_data(char **paths, int n, int m, char **labels, int k, int h, int w)
return d;
}
struct load_args{
char **paths;
int n;
int m;
char **labels;
int k;
int h;
int w;
data *d;
};
void *load_in_thread(void *ptr)
{
struct load_args a = *(struct load_args*)ptr;
*a.d = load_data(a.paths, a.n, a.m, a.labels, a.k, a.h, a.w);
free(ptr);
return 0;
}

View File

@ -17,6 +17,8 @@ void free_data(data d);
data load_data(char **paths, int n, int m, char **labels, int k, int h, int w);
pthread_t load_data_thread(char **paths, int n, int m, char **labels, int k, int h, int w, data *d);
pthread_t load_data_detection_thread(int n, char **paths, int m, int h, int w, int nh, int nw, float scale, data *d);
data load_data_detection_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
data load_data_detection_jitter_random(int n, char **paths, int m, int h, int w, int nh, int nw, float scale);
data load_data_image_pathfile(char *filename, char **labels, int k, int h, int w);

View File

@ -103,8 +103,8 @@ void update_network(network net)
}
else if(net.types[i] == CONNECTED){
connected_layer layer = *(connected_layer *)net.layers[i];
secret_update_connected_layer((connected_layer *)net.layers[i]);
//update_connected_layer(layer);
//secret_update_connected_layer((connected_layer *)net.layers[i]);
update_connected_layer(layer);
}
}
}

View File

@ -416,6 +416,7 @@ list *read_cfg(char *filename)
strip(line);
switch(line[0]){
case '[':
printf("%s\n", line);
current = malloc(sizeof(section));
list_insert(sections, current);
current->options = make_list();

View File

@ -106,16 +106,17 @@ void strip_char(char *s, char bad)
char *fgetl(FILE *fp)
{
if(feof(fp)) return 0;
unsigned long size = 512;
size_t size = 512;
char *line = malloc(size*sizeof(char));
if(!fgets(line, size, fp)){
free(line);
return 0;
}
int curr = strlen(line);
size_t curr = strlen(line);
while(line[curr-1]!='\n'){
while((line[curr-1] != '\n') && !feof(fp)){
printf("%ld %ld\n", curr, size);
size *= 2;
line = realloc(line, size*sizeof(char));
if(!line) {
@ -125,7 +126,7 @@ char *fgetl(FILE *fp)
fgets(&line[curr], size-curr, fp);
curr = strlen(line);
}
line[curr-1] = '\0';
if(line[curr-1] == '\n') line[curr-1] = '\0';
return line;
}