Contrastive loss minor fix

This commit is contained in:
AlexeyAB 2020-08-02 17:37:03 +03:00
parent 6af4370c3f
commit f2eb30b52c
2 changed files with 13 additions and 11 deletions

View File

@ -638,8 +638,8 @@ float P_constrastive_f_det(size_t il, int *labels, float **z, unsigned int featu
}
}
float result = numerator / denominator;
if (denominator == 0) result = 1;
float result = 0.9999;
if (denominator != 0) result = numerator / denominator;
if (result > 1) result = 0.9999;
return result;
}
@ -669,8 +669,8 @@ float P_constrastive_f(size_t i, size_t l, int *labels, float **z, unsigned int
}
}
float result = numerator / denominator;
if (denominator == 0) result = 1;
float result = 0.9999;
if (denominator != 0) result = numerator / denominator;
if (result > 1) result = 0.9999;
return result;
}

View File

@ -405,11 +405,13 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
*/
const size_t contr_size = contrast_p_index;
if (l.detection) {
int k;
#pragma omp parallel for
for (k = 0; k < contrast_p_index; ++k) {
contrast_p[k].P = P_constrastive_f_det(k, l.labels, z, l.embedding_size, l.temperature, contrast_p, contrast_p_index);
for (k = 0; k < contr_size; ++k) {
contrast_p[k].P = P_constrastive_f_det(k, l.labels, z, l.embedding_size, l.temperature, contrast_p, contr_size);
}
}
else {
@ -439,7 +441,7 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
float P = -10;
if (l.detection) {
P = P_constrastive_f(z_index, z_index2, l.labels, z, l.embedding_size, l.temperature, contrast_p, contrast_p_index);
P = P_constrastive_f(z_index, z_index2, l.labels, z, l.embedding_size, l.temperature, contrast_p, contr_size);
}
else {
P = P_constrastive(z_index, z_index2, l.labels, step, z, l.embedding_size, l.temperature, l.cos_sim, l.exp_cos_sim);
@ -447,13 +449,13 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
}
int q;
for (q = 0; q < contrast_p_index; ++q)
for (q = 0; q < contr_size; ++q)
if (contrast_p[q].i == z_index && contrast_p[q].j == z_index2) {
contrast_p[q].P = P;
break;
}
//if (q == contrast_p_index) getchar();
//if (q == contr_size) getchar();
//if (P > 1 || P < -1) {
@ -488,10 +490,10 @@ void forward_contrastive_layer(contrastive_layer l, network_state state)
// detector
// positive
grad_contrastive_loss_positive_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contrast_p_index);
grad_contrastive_loss_positive_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contr_size);
// negative
grad_contrastive_loss_negative_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contrast_p_index);
grad_contrastive_loss_negative_f(z_index, l.labels, step, z, l.embedding_size, l.temperature, l.delta + delta_index, wh, contrast_p, contr_size);
}
else {
// classifier