Contrastive loss minor fix

This commit is contained in:
AlexeyAB 2020-08-02 17:37:03 +03:00
parent 6af4370c3f
commit f2eb30b52c
2 changed files with 13 additions and 11 deletions

View File

@ -638,8 +638,8 @@ float P_constrastive_f_det(size_t il, int *labels, float **z, unsigned int featu
} }
} }
float result = numerator / denominator; float result = 0.9999;
if (denominator == 0) result = 1; if (denominator != 0) result = numerator / denominator;
if (result > 1) result = 0.9999; if (result > 1) result = 0.9999;
return result; return result;
} }
@ -669,8 +669,8 @@ float P_constrastive_f(size_t i, size_t l, int *labels, float **z, unsigned int
} }
} }
float result = numerator / denominator; float result = 0.9999;
if (denominator == 0) result = 1; if (denominator != 0) result = numerator / denominator;
if (result > 1) result = 0.9999; if (result > 1) result = 0.9999;
return result; return result;
} }

View File

@ -405,11 +405,13 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
*/ */
const size_t contr_size = contrast_p_index;
if (l.detection) { if (l.detection) {
int k; int k;
#pragma omp parallel for #pragma omp parallel for
for (k = 0; k < contrast_p_index; ++k) { for (k = 0; k < contr_size; ++k) {
contrast_p[k].P = P_constrastive_f_det(k, l.labels, z, l.embedding_size, l.temperature, contrast_p, contrast_p_index); contrast_p[k].P = P_constrastive_f_det(k, l.labels, z, l.embedding_size, l.temperature, contrast_p, contr_size);
} }
} }
else { else {
@ -439,7 +441,7 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
float P = -10; float P = -10;
if (l.detection) { if (l.detection) {
P = P_constrastive_f(z_index, z_index2, l.labels, z, l.embedding_size, l.temperature, contrast_p, contrast_p_index); P = P_constrastive_f(z_index, z_index2, l.labels, z, l.embedding_size, l.temperature, contrast_p, contr_size);
} }
else { else {
P = P_constrastive(z_index, z_index2, l.labels, step, z, l.embedding_size, l.temperature, l.cos_sim, l.exp_cos_sim); P = P_constrastive(z_index, z_index2, l.labels, step, z, l.embedding_size, l.temperature, l.cos_sim, l.exp_cos_sim);
@ -447,13 +449,13 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
} }
int q; int q;
for (q = 0; q < contrast_p_index; ++q) for (q = 0; q < contr_size; ++q)
if (contrast_p[q].i == z_index && contrast_p[q].j == z_index2) { if (contrast_p[q].i == z_index && contrast_p[q].j == z_index2) {
contrast_p[q].P = P; contrast_p[q].P = P;
break; break;
} }
//if (q == contrast_p_index) getchar(); //if (q == contr_size) getchar();
//if (P > 1 || P < -1) { //if (P > 1 || P < -1) {
@ -488,10 +490,10 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
// detector // detector
// positive // positive
grad_contrastive_loss_positive_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contrast_p_index); grad_contrastive_loss_positive_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contr_size);
// negative // negative
grad_contrastive_loss_negative_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contrast_p_index); grad_contrastive_loss_negative_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contr_size);
} }
else { else {
// classifier // classifier