diff --git a/src/data.c b/src/data.c index 883585fc..2053030b 100644 --- a/src/data.c +++ b/src/data.c @@ -1507,7 +1507,7 @@ data load_data_super(char **paths, int n, int m, int w, int h, int scale) return d; } -data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps) +data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int use_mixup, int use_blur, int show_imgs, float label_smooth_eps) { char **paths_stored = paths; if(m) paths = get_random_paths(paths, n, m); @@ -1516,7 +1516,7 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d.X = load_image_augment_paths(paths, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); d.y = load_labels_paths(paths, n, labels, k, hierarchy, label_smooth_eps); - if (mixup && rand_int(0, 1)) { + if (use_mixup && rand_int(0, 1)) { char **paths_mix = get_random_paths(paths_stored, n, m); data d2 = { 0 }; d2.shallow = 0; @@ -1528,7 +1528,7 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h d3.shallow = 0; data d4 = { 0 }; d4.shallow = 0; - if (mixup >= 3) { + if (use_mixup >= 3) { char **paths_mix3 = get_random_paths(paths_stored, n, m); d3.X = load_image_augment_paths(paths_mix3, n, use_flip, min, max, w, h, angle, aspect, hue, saturation, exposure); d3.y = load_labels_paths(paths_mix3, n, labels, k, hierarchy, label_smooth_eps); @@ -1545,7 +1545,8 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h int i, j; for (i = 0; i < d2.X.rows; ++i) { - if (mixup == 4) mixup = rand_int(2, 3); // alternate CutMix and Mosaic + int mixup = use_mixup; + if (use_mixup == 4) mixup = rand_int(2, 3); // alternate CutMix and Mosaic // MixUp ----------------------------------- if (mixup == 1) { @@ -1637,7 +1638,7 @@ data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *h free_data(d2); - if (mixup == 3) { + if (use_mixup >= 3) { free_data(d3); free_data(d4); } diff --git a/src/data.h b/src/data.h index dc18a928..f675ae82 100644 --- a/src/data.h +++ b/src/data.h @@ -91,7 +91,7 @@ data load_data_detection(int n, char **paths, int m, int w, int h, int c, int bo data load_data_tag(char **paths, int n, int m, int k, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure); matrix load_image_augment_paths(char **paths, int n, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure); data load_data_super(char **paths, int n, int m, int w, int h, int scale); -data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int mixup, int use_blur, int show_imgs, float label_smooth_eps); +data load_data_augment(char **paths, int n, int m, char **labels, int k, tree *hierarchy, int use_flip, int min, int max, int w, int h, float angle, float aspect, float hue, float saturation, float exposure, int use_mixup, int use_blur, int show_imgs, float label_smooth_eps); data load_go(char *filename); box_label *read_boxes(char *filename, int *n);