Diffstat (limited to 'tests/validation/fixtures')
-rw-r--r-- | tests/validation/fixtures/BatchNormalizationLayerFixture.h | 54
1 file changed, 46 insertions, 8 deletions
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
index e02c619249..4a6ac1af7f 100644
--- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
@@ -45,10 +45,12 @@ class BatchNormalizationLayerValidationFixedPointFixture : public framework::Fix
 {
 public:
     template <typename...>
-    void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fractional_bits)
+    void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt, int fractional_bits)
     {
         _fractional_bits = fractional_bits;
         _data_type       = dt;
+        _use_beta        = use_beta;
+        _use_gamma       = use_gamma;
         _target          = compute_target(shape0, shape1, epsilon, act_info, dt, fractional_bits);
         _reference       = compute_reference(shape0, shape1, epsilon, act_info, dt, fractional_bits);
     }
@@ -67,8 +69,24 @@ protected:
             library->fill(src_tensor, distribution, 0);
             library->fill(mean_tensor, distribution, 1);
             library->fill(var_tensor, distribution_var, 0);
-            library->fill(beta_tensor, distribution, 3);
-            library->fill(gamma_tensor, distribution, 4);
+            if(_use_beta)
+            {
+                library->fill(beta_tensor, distribution, 3);
+            }
+            else
+            {
+                // Fill with default value 0.f
+                library->fill_tensor_value(beta_tensor, 0.f);
+            }
+            if(_use_gamma)
+            {
+                library->fill(gamma_tensor, distribution, 4);
+            }
+            else
+            {
+                // Fill with default value 1.f
+                library->fill_tensor_value(gamma_tensor, 1.f);
+            }
         }
         else
         {
@@ -80,8 +98,24 @@ protected:
             library->fill(src_tensor, distribution, 0);
             library->fill(mean_tensor, distribution, 1);
             library->fill(var_tensor, distribution_var, 0);
-            library->fill(beta_tensor, distribution, 3);
-            library->fill(gamma_tensor, distribution, 4);
+            if(_use_beta)
+            {
+                library->fill(beta_tensor, distribution, 3);
+            }
+            else
+            {
+                // Fill with default value 0
+                library->fill_tensor_value(beta_tensor, static_cast<T>(0));
+            }
+            if(_use_gamma)
+            {
+                library->fill(gamma_tensor, distribution, 4);
+            }
+            else
+            {
+                // Fill with default value 1
+                library->fill_tensor_value(gamma_tensor, static_cast<T>(1 << (_fractional_bits)));
+            }
         }
     }
 
@@ -97,7 +131,9 @@ protected:
 
         // Create and configure function
         FunctionType norm;
-        norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon, act_info);
+        TensorType *beta_ptr  = _use_beta ? &beta : nullptr;
+        TensorType *gamma_ptr = _use_gamma ? &gamma : nullptr;
+        norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon, act_info);
 
         ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
         ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -149,6 +185,8 @@ protected:
     SimpleTensor<T> _reference{};
     int             _fractional_bits{};
     DataType        _data_type{};
+    bool            _use_beta{};
+    bool            _use_gamma{};
 };
 
 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -156,9 +194,9 @@ class BatchNormalizationLayerValidationFixture : public BatchNormalizationLayerV
 {
 public:
     template <typename...>
-    void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt)
+    void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt)
     {
-        BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, act_info, dt, 0);
+        BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, use_beta, use_gamma, act_info, dt, 0);
     }
 };
 } // namespace validation
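For context, the defaults the patch fills in are the identity parameters of batch normalization: the layer computes y = gamma * (x - mean) / sqrt(var + epsilon) + beta, so beta = 0 and gamma = 1 leave the normalized value unchanged, and in the fixed-point path the value 1.0 is encoded as 1 << fractional_bits. Below is a minimal standalone sketch of both points; the batch_norm helper and the chosen constants are illustrative only, not part of the patch or the library API.

#include <cmath>
#include <cstdint>
#include <iostream>

// Scalar sketch of the batch normalization formula the fixture validates:
// y = gamma * (x - mean) / sqrt(var + epsilon) + beta
float batch_norm(float x, float mean, float var, float epsilon, float beta, float gamma)
{
    return gamma * (x - mean) / std::sqrt(var + epsilon) + beta;
}

int main()
{
    const float x = 2.f, mean = 1.f, var = 4.f, epsilon = 1e-5f;

    // With the defaults the patch fills in (beta = 0.f, gamma = 1.f), the
    // affine part is an identity and only the normalization remains: ~0.5 here.
    std::cout << batch_norm(x, mean, var, epsilon, 0.f, 1.f) << '\n';

    // In the fixed-point path, 1.0 is represented as 1 << fractional_bits,
    // which is why the gamma default there is static_cast<T>(1 << _fractional_bits).
    const int    fractional_bits = 5; // illustrative choice
    const int8_t one_qs8         = static_cast<int8_t>(1 << fractional_bits);
    std::cout << static_cast<int>(one_qs8) << '\n'; // prints 32, i.e. 1.0 in Q2.5
    return 0;
}

Passing nullptr for an omitted beta or gamma in configure() then exercises the same identity behavior on the function side, which is what the fixture's new use_beta/use_gamma flags are testing.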