aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/TensorOperations.h
diff options
context:
space:
mode:
authorMichalis Spyrou <michalis.spyrou@arm.com>2017-06-26 14:18:47 +0100
committerAnthony Barbier <anthony.barbier@arm.com>2018-09-17 14:16:42 +0100
commit172e57028ef14f2f8d6c56edc53c5c85f97e07cd (patch)
treeb3fe8c05902f07fb2381cf6dfd893654c8ccb63f /tests/validation/TensorOperations.h
parent579c0498e161215be1a36080b0b454e5198a992a (diff)
downloadComputeLibrary-172e57028ef14f2f8d6c56edc53c5c85f97e07cd.tar.gz
COMPMID-425 Port CLBatchnormalization to support QS8/QS16
Change-Id: I46c93305f377666ea0915ff789b7dfdfff596087 Reviewed-on: http://mpd-gerrit.cambridge.arm.com/78862 Reviewed-by: Anthony Barbier <anthony.barbier@arm.com> Tested-by: Kaizen <jeremy.johnson+kaizengerrit@arm.com>
Diffstat (limited to 'tests/validation/TensorOperations.h')
-rw-r--r--tests/validation/TensorOperations.h20
1 files changed, 10 insertions, 10 deletions
diff --git a/tests/validation/TensorOperations.h b/tests/validation/TensorOperations.h
index 27c50cf6d2..9e6f5cf5d1 100644
--- a/tests/validation/TensorOperations.h
+++ b/tests/validation/TensorOperations.h
@@ -974,17 +974,17 @@ void batch_normalization_layer(const Tensor<T> &in, Tensor<T> &out, const Tensor
for(int l = 0; l < cols; ++l)
{
const int pos = l + k * cols + i * rows * cols + r * cols * rows * depth;
- fixed_point_arithmetic::fixed_point<T> in_qs8(in[pos], fixed_point_position, true);
- fixed_point_arithmetic::fixed_point<T> var_qs8(var[i], fixed_point_position, true);
- fixed_point_arithmetic::fixed_point<T> mean_qs8(mean[i], fixed_point_position, true);
- fixed_point_arithmetic::fixed_point<T> beta_qs8(beta[i], fixed_point_position, true);
- fixed_point_arithmetic::fixed_point<T> gamma_qs8(gamma[i], fixed_point_position, true);
- fixed_point_arithmetic::fixed_point<T> epsilon_qs8(epsilon, fixed_point_position);
-
- auto denominator = fixed_point_arithmetic::inv_sqrt(var_qs8 + epsilon_qs8);
- auto numerator = in_qs8 - mean_qs8;
+ fixed_point_arithmetic::fixed_point<T> in_qs(in[pos], fixed_point_position, true);
+ fixed_point_arithmetic::fixed_point<T> var_qs(var[i], fixed_point_position, true);
+ fixed_point_arithmetic::fixed_point<T> mean_qs(mean[i], fixed_point_position, true);
+ fixed_point_arithmetic::fixed_point<T> beta_qs(beta[i], fixed_point_position, true);
+ fixed_point_arithmetic::fixed_point<T> gamma_qs(gamma[i], fixed_point_position, true);
+ fixed_point_arithmetic::fixed_point<T> epsilon_qs(epsilon, fixed_point_position);
+
+ auto denominator = fixed_point_arithmetic::inv_sqrt(var_qs + epsilon_qs);
+ auto numerator = in_qs - mean_qs;
auto x_bar = numerator * denominator;
- x_bar = beta_qs8 + x_bar * gamma_qs8;
+ x_bar = beta_qs + x_bar * gamma_qs;
out[pos] = x_bar.raw();
}
}