diff options
Diffstat (limited to 'tests/validation/fixtures/BatchNormalizationLayerFixture.h')
-rw-r--r-- | tests/validation/fixtures/BatchNormalizationLayerFixture.h | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h index 298c9ca411..e02c619249 100644 --- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h +++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -45,12 +45,12 @@ class BatchNormalizationLayerValidationFixedPointFixture : public framework::Fix { public: template <typename...> - void setup(TensorShape shape0, TensorShape shape1, float epsilon, DataType dt, int fractional_bits) + void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fractional_bits) { _fractional_bits = fractional_bits; _data_type = dt; - _target = compute_target(shape0, shape1, epsilon, dt, fractional_bits); - _reference = compute_reference(shape0, shape1, epsilon, dt, fractional_bits); + _target = compute_target(shape0, shape1, epsilon, act_info, dt, fractional_bits); + _reference = compute_reference(shape0, shape1, epsilon, act_info, dt, fractional_bits); } protected: @@ -85,7 +85,7 @@ protected: } } - TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, DataType dt, int fixed_point_position) + TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position) { // Create tensors TensorType src = create_tensor<TensorType>(shape0, dt, 1, fixed_point_position); @@ -97,7 +97,7 @@ protected: // Create and configure function FunctionType norm; - norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon); + norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon, act_info); ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS); ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS); @@ -130,7 +130,7 @@ protected: return dst; } - SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, DataType dt, int fixed_point_position) + SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fixed_point_position) { // Create reference SimpleTensor<T> ref_src{ shape0, dt, 1, fixed_point_position }; @@ -142,7 +142,7 @@ protected: // Fill reference fill(ref_src, ref_mean, ref_var, ref_beta, ref_gamma); - return reference::batch_normalization_layer(ref_src, ref_mean, ref_var, ref_beta, ref_gamma, epsilon, fixed_point_position); + return reference::batch_normalization_layer(ref_src, ref_mean, ref_var, ref_beta, ref_gamma, epsilon, act_info, fixed_point_position); } TensorType _target{}; @@ -156,9 +156,9 @@ class BatchNormalizationLayerValidationFixture : public BatchNormalizationLayerV { public: template <typename...> - void setup(TensorShape shape0, TensorShape shape1, float epsilon, DataType dt) + void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt) { - BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, dt, 0); + BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, act_info, dt, 0); } }; } // namespace validation |