Added [empty]/[silence] and [implicit] layers

This commit is contained in:
AlexeyAB 2021-05-11 22:59:21 +03:00
parent e2a128737b
commit 846c79b6d4
14 changed files with 3173 additions and 12 deletions

View File

@ -151,7 +151,7 @@ LDFLAGS+= -L/usr/local/zed/lib -lsl_zed
endif
endif
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o gaussian_yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o scale_channels_layer.o sam_layer.o
OBJ=image_opencv.o http_stream.o gemm.o utils.o dark_cuda.o convolutional_layer.o list.o image.o activations.o im2col.o col2im.o blas.o crop_layer.o dropout_layer.o maxpool_layer.o softmax_layer.o data.o matrix.o network.o connected_layer.o cost_layer.o parser.o option_list.o darknet.o detection_layer.o captcha.o route_layer.o writing.o box.o nightmare.o normalization_layer.o avgpool_layer.o coco.o dice.o yolo.o detector.o layer.o compare.o classifier.o local_layer.o swag.o shortcut_layer.o representation_layer.o activation_layer.o rnn_layer.o gru_layer.o rnn.o rnn_vid.o crnn_layer.o demo.o tag.o cifar.o go.o batchnorm_layer.o art.o region_layer.o reorg_layer.o reorg_old_layer.o super.o voxel.o tree.o yolo_layer.o gaussian_yolo_layer.o upsample_layer.o lstm_layer.o conv_lstm_layer.o scale_channels_layer.o sam_layer.o
ifeq ($(GPU), 1)
LDFLAGS+= -lstdc++
OBJ+=convolutional_kernels.o activation_kernels.o im2col_kernels.o col2im_kernels.o blas_kernels.o crop_layer_kernels.o dropout_layer_kernels.o maxpool_layer_kernels.o network_kernels.o avgpool_layer_kernels.o

View File

@ -223,6 +223,7 @@
<ClCompile Include="..\..\src\region_layer.c" />
<ClCompile Include="..\..\src\reorg_layer.c" />
<ClCompile Include="..\..\src\reorg_old_layer.c" />
<ClCompile Include="..\..\src\representation_layer.c" />
<ClCompile Include="..\..\src\rnn.c" />
<ClCompile Include="..\..\src\rnn_layer.c" />
<ClCompile Include="..\..\src\rnn_vid.c" />
@ -286,6 +287,7 @@
<ClInclude Include="..\..\src\region_layer.h" />
<ClInclude Include="..\..\src\reorg_layer.h" />
<ClInclude Include="..\..\src\reorg_old_layer.h" />
<ClInclude Include="..\..\src\representation_layer.h" />
<ClInclude Include="..\..\src\rnn_layer.h" />
<ClInclude Include="..\..\src\route_layer.h" />
<ClInclude Include="..\..\src\sam_layer.h" />

File diff suppressed because it is too large Load Diff

View File

@ -25,6 +25,8 @@ import math
import random
import os
print("Run: darknet_images.py or:\n")
print("python.exe darknet_video.py --data_file cfg/coco.data --config_file cfg/yolov4.cfg --weights yolov4.weights --input test.mp4 \n")
class BOX(Structure):
_fields_ = [("x", c_float),

View File

@ -14,6 +14,6 @@ rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\Scripts\pip install sci
rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\Scripts\pip install scipy
rem C:\Users\Alex\AppData\Local\Programs\Python\Python36\Scripts\pip install opencv-python
C:\Users\Alex\AppData\Local\Programs\Python\Python36\python.exe darknet.py
C:\Users\Alex\AppData\Local\Programs\Python\Python36\python.exe darknet_images.py
pause

View File

@ -8,7 +8,10 @@ rem darknet.exe partial cfg/tiny-yolo-voc.cfg tiny-yolo-voc.weights tiny-yolo-vo
darknet.exe partial cfg/yolov4-tiny.cfg yolov4-tiny.weights yolov4-tiny.conv.29 29
darknet.exe partial cfg/yolov4-sam-mish.cfg cfg/yolov4-sam-mish.weights cfg/yolov4-sam-mish.conv.137 137
rem darknet.exe partial cfg/yolov4-sam-mish.cfg cfg/yolov4-sam-mish.weights cfg/yolov4-sam-mish.conv.137 137
rem darknet.exe partial cfg/yolov4-sam-mish.cfg cfg/yolov4-sam-mish.weights cfg/yolov4-sam-mish.conv.105 105
pause

View File

@ -52,7 +52,7 @@
</PropertyGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.props" />
<ImportGroup Label="ExtensionSettings">
<Import Project="$(VCTargetsPath)\BuildCustomizations\CUDA 10.1.props" />
<Import Project="$(VCTargetsPath)\BuildCustomizations\CUDA 11.1.props" />
</ImportGroup>
<ImportGroup Label="PropertySheets" Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">
<Import Project="$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props" Condition="exists('$(UserRootDir)\Microsoft.Cpp.$(Platform).user.props')" Label="LocalAppDataPlatform" />
@ -155,7 +155,7 @@
</Link>
<CudaCompile>
<TargetMachinePlatform>64</TargetMachinePlatform>
<CodeGeneration>compute_30,sm_30;compute_75,sm_75</CodeGeneration>
<CodeGeneration>compute_35,sm_35;compute_75,sm_75</CodeGeneration>
</CudaCompile>
</ItemDefinitionGroup>
<ItemGroup>
@ -225,6 +225,7 @@
<ClCompile Include="..\..\src\region_layer.c" />
<ClCompile Include="..\..\src\reorg_layer.c" />
<ClCompile Include="..\..\src\reorg_old_layer.c" />
<ClCompile Include="..\..\src\representation_layer.c" />
<ClCompile Include="..\..\src\rnn.c" />
<ClCompile Include="..\..\src\rnn_layer.c" />
<ClCompile Include="..\..\src\rnn_vid.c" />
@ -290,6 +291,7 @@
<ClInclude Include="..\..\src\region_layer.h" />
<ClInclude Include="..\..\src\reorg_layer.h" />
<ClInclude Include="..\..\src\reorg_old_layer.h" />
<ClInclude Include="..\..\src\representation_layer.h" />
<ClInclude Include="..\..\src\rnn_layer.h" />
<ClInclude Include="..\..\src\route_layer.h" />
<ClInclude Include="..\..\src\sam_layer.h" />
@ -306,6 +308,6 @@
</ItemGroup>
<Import Project="$(VCTargetsPath)\Microsoft.Cpp.targets" />
<ImportGroup Label="ExtensionTargets">
<Import Project="$(VCTargetsPath)\BuildCustomizations\CUDA 10.1.targets" />
<Import Project="$(VCTargetsPath)\BuildCustomizations\CUDA 11.1.targets" />
</ImportGroup>
</Project>

File diff suppressed because it is too large Load Diff

View File

@ -192,7 +192,8 @@ typedef enum {
L2NORM,
EMPTY,
BLANK,
CONTRASTIVE
CONTRASTIVE,
IMPLICIT
} LAYER_TYPE;
// layer.h

View File

@ -174,6 +174,9 @@ void mult_inverse_array_gpu(const float *src_gpu, float *dst_gpu, int size, floa
void P_constrastive_f_det_gpu(int *labels, unsigned int feature_size, float temperature, contrastive_params *contrast_p, const int contrast_p_size);
void coord_conv_gpu(float *dst, int size, int w, int h, int chan, int b, int type);
void forward_implicit_gpu(int batch, int nweights, float *weight_gpu, float *output_gpu);
void backward_implicit_gpu(int batch, int nweights, float *weight_updates_gpu, float *delta_gpu);
#endif // GPU
#ifdef __cplusplus
}

View File

@ -2435,4 +2435,37 @@ extern "C" void coord_conv_gpu(float *dst, int size, int w, int h, int chan, int
coord_conv_kernel << <num_blocks, block_size, 0, get_cuda_stream() >> > (dst, w, h, chan, b, type);
CHECK_CUDA(cudaPeekAtLastError());
}
}
__global__ void forward_implicit_kernel(int size, int batch, int nweights, float *weight_gpu, float *output_gpu)
{
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
output_gpu[id] = weight_gpu[id % nweights];
}
extern "C" void forward_implicit_gpu(int batch, int nweights, float *weight_gpu, float *output_gpu)
{
int size = batch * nweights;
forward_implicit_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, batch, nweights, weight_gpu, output_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}
__global__ void backward_implicit_kernel(int size, int batch, int nweights, float *weight_updates_gpu, float *delta_gpu)
{
const int id = (blockIdx.x + blockIdx.y*gridDim.x) * blockDim.x + threadIdx.x;
if (id >= size) return;
weight_updates_gpu[id % nweights] += delta_gpu[id];
}
extern "C" void backward_implicit_gpu(int batch, int nweights, float *weight_updates_gpu, float *delta_gpu)
{
int size = batch * nweights;
backward_implicit_kernel << <cuda_gridsize(size), BLOCK, 0, get_cuda_stream() >> > (size, batch, nweights, weight_updates_gpu, delta_gpu);
CHECK_CUDA(cudaPeekAtLastError());
}

View File

@ -39,6 +39,11 @@
#include "version.h"
#include "yolo_layer.h"
#include "gaussian_yolo_layer.h"
#include "representation_layer.h"
void empty_func(dropout_layer l, network_state state) {
//l.output_gpu = state.input;
}
typedef struct{
char *type;
@ -90,7 +95,9 @@ LAYER_TYPE string_to_layer_type(char * type)
if (strcmp(type, "[contrastive]") == 0) return CONTRASTIVE;
if (strcmp(type, "[route]")==0) return ROUTE;
if (strcmp(type, "[upsample]") == 0) return UPSAMPLE;
if (strcmp(type, "[empty]") == 0) return EMPTY;
if (strcmp(type, "[empty]") == 0
|| strcmp(type, "[silence]") == 0) return EMPTY;
if (strcmp(type, "[implicit]") == 0) return IMPLICIT;
return BLANK;
}
@ -1036,6 +1043,17 @@ layer parse_sam(list *options, size_params params, network net)
return s;
}
layer parse_implicit(list *options, size_params params, network net)
{
float mean_init = option_find_float(options, "mean", 0.0);
float std_init = option_find_float(options, "std", 0.2);
int filters = option_find_int(options, "filters", 128);
int atoms = option_find_int_quiet(options, "atoms", 1);
layer s = make_implicit_layer(params.batch, params.index, mean_init, std_init, filters, atoms);
return s;
}
layer parse_activation(list *options, size_params params)
{
@ -1480,6 +1498,8 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
net.layers[count - 1].use_bin_output = 0;
net.layers[l.index].use_bin_output = 0;
net.layers[l.index].keep_delta_gpu = 1;
} else if (lt == IMPLICIT) {
l = parse_implicit(options, params, net);
}else if(lt == DROPOUT){
l = parse_dropout(options, params);
l.output = net.layers[count-1].output;
@ -1492,16 +1512,25 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
}
else if (lt == EMPTY) {
layer empty_layer = {(LAYER_TYPE)0};
empty_layer.out_w = params.w;
empty_layer.out_h = params.h;
empty_layer.out_c = params.c;
l = empty_layer;
l.type = EMPTY;
l.w = l.out_w = params.w;
l.h = l.out_h = params.h;
l.c = l.out_c = params.c;
l.batch = params.batch;
l.inputs = l.outputs = params.inputs;
l.output = net.layers[count - 1].output;
l.delta = net.layers[count - 1].delta;
l.forward = empty_func;
l.backward = empty_func;
#ifdef GPU
l.output_gpu = net.layers[count - 1].output_gpu;
l.delta_gpu = net.layers[count - 1].delta_gpu;
l.keep_delta_gpu = 1;
l.forward_gpu = empty_func;
l.backward_gpu = empty_func;
#endif
fprintf(stderr, "empty \n");
}else{
fprintf(stderr, "Type not recognized: %s\n", s->type);
}
@ -1604,6 +1633,7 @@ network parse_network_cfg_custom(char *filename, int batch, int time_steps)
l.dontloadscales = option_find_int_quiet(options, "dontloadscales", 0);
l.learning_rate_scale = option_find_float_quiet(options, "learning_rate", 1);
option_unused(options);
net.layers[count] = l;
if (l.workspace_size > workspace_size) workspace_size = l.workspace_size;
if (l.inputs > max_inputs) max_inputs = l.inputs;
@ -1810,6 +1840,24 @@ void save_shortcut_weights(layer l, FILE *fp)
fwrite(l.weights, sizeof(float), num, fp);
}
void save_implicit_weights(layer l, FILE *fp)
{
#ifdef GPU
if (gpu_index >= 0) {
pull_implicit_layer(l);
//printf("\n pull_implicit_layer \n");
}
#endif
int i;
//if(l.weight_updates) for (i = 0; i < l.nweights; ++i) printf(" %f, ", l.weight_updates[i]);
//printf(" l.nweights = %d - update \n", l.nweights);
//for (i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
//printf(" l.nweights = %d \n\n", l.nweights);
int num = l.nweights;
fwrite(l.weights, sizeof(float), num, fp);
}
void save_convolutional_weights(layer l, FILE *fp)
{
if(l.binary){
@ -1921,6 +1969,8 @@ void save_weights_upto(network net, char *filename, int cutoff, int save_ema)
}
} if (l.type == SHORTCUT && l.nweights > 0) {
save_shortcut_weights(l, fp);
} if (l.type == IMPLICIT) {
save_implicit_weights(l, fp);
} if(l.type == CONNECTED){
save_connected_weights(l, fp);
} if(l.type == BATCHNORM){
@ -2131,6 +2181,21 @@ void load_shortcut_weights(layer l, FILE *fp)
#endif
}
void load_implicit_weights(layer l, FILE *fp)
{
int num = l.nweights;
int read_bytes;
read_bytes = fread(l.weights, sizeof(float), num, fp);
if (read_bytes > 0 && read_bytes < num) printf("\n Warning: Unexpected end of wights-file! l.weights - l.index = %d \n", l.index);
//for (int i = 0; i < l.nweights; ++i) printf(" %f, ", l.weights[i]);
//printf(" read_bytes = %d \n\n", read_bytes);
#ifdef GPU
if (gpu_index >= 0) {
push_implicit_layer(l);
}
#endif
}
void load_weights_upto(network *net, char *filename, int cutoff)
{
#ifdef GPU
@ -2175,6 +2240,9 @@ void load_weights_upto(network *net, char *filename, int cutoff)
if (l.type == SHORTCUT && l.nweights > 0) {
load_shortcut_weights(l, fp);
}
if (l.type == IMPLICIT) {
load_implicit_weights(l, fp);
}
if(l.type == CONNECTED){
load_connected_weights(l, fp, transpose);
}

160
src/representation_layer.c Normal file
View File

@ -0,0 +1,160 @@
#include "representation_layer.h"
#include "utils.h"
#include "dark_cuda.h"
#include "blas.h"
#include <stdio.h>
#include <assert.h>
layer make_implicit_layer(int batch, int index, float mean_init, float std_init, int filters, int atoms)
{
fprintf(stderr,"implicit Layer: %d x %d \t mean=%.2f, std=%.2f \n", filters, atoms, mean_init, std_init);
layer l = { (LAYER_TYPE)0 };
l.type = IMPLICIT;
l.batch = batch;
l.w = 1;
l.h = 1;
l.c = 1;
l.out_w = 1;
l.out_h = atoms;
l.out_c = filters;
l.outputs = l.out_w*l.out_h*l.out_c;
l.inputs = 1;
l.index = index;
l.nweights = l.out_w * l.out_h * l.out_c;
l.weight_updates = (float*)xcalloc(l.nweights, sizeof(float));
l.weights = (float*)xcalloc(l.nweights, sizeof(float));
int i;
for (i = 0; i < l.nweights; ++i) l.weights[i] = mean_init + rand_uniform(-std_init, std_init);
l.delta = (float*)xcalloc(l.outputs * batch, sizeof(float));
l.output = (float*)xcalloc(l.outputs * batch, sizeof(float));
l.forward = forward_implicit_layer;
l.backward = backward_implicit_layer;
l.update = update_implicit_layer;
#ifdef GPU
l.forward_gpu = forward_implicit_layer_gpu;
l.backward_gpu = backward_implicit_layer_gpu;
l.update_gpu = update_implicit_layer_gpu;
l.delta_gpu = cuda_make_array(l.delta, l.outputs*batch);
l.output_gpu = cuda_make_array(l.output, l.outputs*batch);
l.weight_updates_gpu = cuda_make_array(l.weight_updates, l.nweights);
l.weights_gpu = cuda_make_array(l.weights, l.nweights);
#endif
return l;
}
void resize_implicit_layer(layer *l, int w, int h)
{
}
void forward_implicit_layer(const layer l, network_state state)
{
int i;
#pragma omp parallel for
for (i = 0; i < l.nweights * l.batch; ++i) {
l.output[i] = l.weights[i % l.nweights];
}
}
void backward_implicit_layer(const layer l, network_state state)
{
int i;
#pragma omp parallel for
for (i = 0; i < l.nweights * l.batch; ++i) {
l.weight_updates[i % l.nweights] += l.delta[i];
}
}
void update_implicit_layer(layer l, int batch, float learning_rate_init, float momentum, float decay)
{
float learning_rate = learning_rate_init*l.learning_rate_scale;
//float momentum = a.momentum;
//float decay = a.decay;
//int batch = a.batch;
axpy_cpu(l.nweights, -decay*batch, l.weights, 1, l.weight_updates, 1);
axpy_cpu(l.nweights, learning_rate / batch, l.weight_updates, 1, l.weights, 1);
scal_cpu(l.nweights, momentum, l.weight_updates, 1);
}
#ifdef GPU
void forward_implicit_layer_gpu(const layer l, network_state state)
{
forward_implicit_gpu(l.batch, l.nweights, l.weights_gpu, l.output_gpu);
}
void backward_implicit_layer_gpu(const layer l, network_state state)
{
backward_implicit_gpu(l.batch, l.nweights, l.weight_updates_gpu, l.delta_gpu);
}
void update_implicit_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale)
{
// Loss scale for Mixed-Precision on Tensor-Cores
float learning_rate = learning_rate_init*l.learning_rate_scale / loss_scale;
//float momentum = a.momentum;
//float decay = a.decay;
//int batch = a.batch;
reset_nan_and_inf(l.weight_updates_gpu, l.nweights);
fix_nan_and_inf(l.weights_gpu, l.nweights);
if (l.adam) {
//adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, a.B1, a.B2, a.eps, decay, learning_rate, l.nweights, batch, a.t);
adam_update_gpu(l.weights_gpu, l.weight_updates_gpu, l.m_gpu, l.v_gpu, l.B1, l.B2, l.eps, decay, learning_rate, l.nweights, batch, l.t);
}
else {
//axpy_ongpu(l.nweights, -decay*batch, l.weights_gpu, 1, l.weight_updates_gpu, 1);
//axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
//scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
axpy_ongpu(l.nweights, -decay*batch*loss_scale, l.weights_gpu, 1, l.weight_updates_gpu, 1);
axpy_ongpu(l.nweights, learning_rate / batch, l.weight_updates_gpu, 1, l.weights_gpu, 1);
scal_ongpu(l.nweights, momentum, l.weight_updates_gpu, 1);
}
if (l.clip) {
constrain_ongpu(l.nweights, l.clip, l.weights_gpu, 1);
}
}
void pull_implicit_layer(layer l)
{
cuda_pull_array_async(l.weights_gpu, l.weights, l.nweights);
cuda_pull_array_async(l.weight_updates_gpu, l.weight_updates, l.nweights);
if (l.adam) {
cuda_pull_array_async(l.m_gpu, l.m, l.nweights);
cuda_pull_array_async(l.v_gpu, l.v, l.nweights);
}
CHECK_CUDA(cudaPeekAtLastError());
cudaStreamSynchronize(get_cuda_stream());
}
void push_implicit_layer(layer l)
{
cuda_push_array(l.weights_gpu, l.weights, l.nweights);
if (l.train) {
cuda_push_array(l.weight_updates_gpu, l.weight_updates, l.nweights);
}
if (l.adam) {
cuda_push_array(l.m_gpu, l.m, l.nweights);
cuda_push_array(l.v_gpu, l.v, l.nweights);
}
CHECK_CUDA(cudaPeekAtLastError());
}
#endif

View File

@ -0,0 +1,29 @@
#ifndef REPRESENTATION_LAYER_H
#define REPRESENTATION_LAYER_H
#include "layer.h"
#include "network.h"
#ifdef __cplusplus
extern "C" {
#endif
layer make_implicit_layer(int batch, int index, float mean_init, float std_init, int filters, int atoms);
void forward_implicit_layer(const layer l, network_state state);
void backward_implicit_layer(const layer l, network_state state);
void update_implicit_layer(layer l, int batch, float learning_rate_init, float momentum, float decay);
void resize_implicit_layer(layer *l, int w, int h);
#ifdef GPU
void forward_implicit_layer_gpu(const layer l, network_state state);
void backward_implicit_layer_gpu(const layer l, network_state state);
void update_implicit_layer_gpu(layer l, int batch, float learning_rate_init, float momentum, float decay, float loss_scale);
void pull_implicit_layer(layer l);
void push_implicit_layer(layer l);
#endif
#ifdef __cplusplus
}
#endif
#endif // REPRESENTATION_LAYER_H