diff options
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r-- | tests/validation/TensorOperations.h | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h index 27c50cf6d2..9e6f5cf5d1 100644 --- a/tests/validation/TensorOperations.h +++ b/tests/validation/TensorOperations.h @@ -974,17 +974,17 @@ void batch_normalization_layer(const Tensor<T> &in, Tensor<T> &out, const Tensor for(int l = 0; l < cols; ++l) { const int pos = l + k * cols + i * rows * cols + r * cols * rows * depth; - fixed_point_arithmetic::fixed_point<T> in_qs8(in[pos], fixed_point_position, true); - fixed_point_arithmetic::fixed_point<T> var_qs8(var[i], fixed_point_position, true); - fixed_point_arithmetic::fixed_point<T> mean_qs8(mean[i], fixed_point_position, true); - fixed_point_arithmetic::fixed_point<T> beta_qs8(beta[i], fixed_point_position, true); - fixed_point_arithmetic::fixed_point<T> gamma_qs8(gamma[i], fixed_point_position, true); - fixed_point_arithmetic::fixed_point<T> epsilon_qs8(epsilon, fixed_point_position); - - auto denominator = fixed_point_arithmetic::inv_sqrt(var_qs8 + epsilon_qs8); - auto numerator = in_qs8 - mean_qs8; + fixed_point_arithmetic::fixed_point<T> in_qs(in[pos], fixed_point_position, true); + fixed_point_arithmetic::fixed_point<T> var_qs(var[i], fixed_point_position, true); + fixed_point_arithmetic::fixed_point<T> mean_qs(mean[i], fixed_point_position, true); + fixed_point_arithmetic::fixed_point<T> beta_qs(beta[i], fixed_point_position, true); + fixed_point_arithmetic::fixed_point<T> gamma_qs(gamma[i], fixed_point_position, true); + fixed_point_arithmetic::fixed_point<T> epsilon_qs(epsilon, fixed_point_position); + + auto denominator = fixed_point_arithmetic::inv_sqrt(var_qs + epsilon_qs); + auto numerator = in_qs - mean_qs; auto x_bar = numerator * denominator; - x_bar = beta_qs8 + x_bar * gamma_qs8; + x_bar = beta_qs + x_bar * gamma_qs; out[pos] = x_bar.raw(); } } |