aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/BatchNormalizationLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/BatchNormalizationLayerFixture.h')
-rw-r--r--tests/validation/fixtures/BatchNormalizationLayerFixture.h54
1 files changed, 46 insertions, 8 deletions
diff --git a/tests/validation/fixtures/BatchNormalizationLayerFixture.h b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
index e02c619249..4a6ac1af7f 100644
--- a/tests/validation/fixtures/BatchNormalizationLayerFixture.h
+++ b/tests/validation/fixtures/BatchNormalizationLayerFixture.h
@@ -45,10 +45,12 @@ class BatchNormalizationLayerValidationFixedPointFixture : public framework::Fix
{
public:
template <typename...>
- void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt, int fractional_bits)
+ void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt, int fractional_bits)
{
_fractional_bits = fractional_bits;
_data_type = dt;
+ _use_beta = use_beta;
+ _use_gamma = use_gamma;
_target = compute_target(shape0, shape1, epsilon, act_info, dt, fractional_bits);
_reference = compute_reference(shape0, shape1, epsilon, act_info, dt, fractional_bits);
}
@@ -67,8 +69,24 @@ protected:
library->fill(src_tensor, distribution, 0);
library->fill(mean_tensor, distribution, 1);
library->fill(var_tensor, distribution_var, 0);
- library->fill(beta_tensor, distribution, 3);
- library->fill(gamma_tensor, distribution, 4);
+ if(_use_beta)
+ {
+ library->fill(beta_tensor, distribution, 3);
+ }
+ else
+ {
+ // Fill with default value 0.f
+ library->fill_tensor_value(beta_tensor, 0.f);
+ }
+ if(_use_gamma)
+ {
+ library->fill(gamma_tensor, distribution, 4);
+ }
+ else
+ {
+ // Fill with default value 1.f
+ library->fill_tensor_value(gamma_tensor, 1.f);
+ }
}
else
{
@@ -80,8 +98,24 @@ protected:
library->fill(src_tensor, distribution, 0);
library->fill(mean_tensor, distribution, 1);
library->fill(var_tensor, distribution_var, 0);
- library->fill(beta_tensor, distribution, 3);
- library->fill(gamma_tensor, distribution, 4);
+ if(_use_beta)
+ {
+ library->fill(beta_tensor, distribution, 3);
+ }
+ else
+ {
+ // Fill with default value 0
+ library->fill_tensor_value(beta_tensor, static_cast<T>(0));
+ }
+ if(_use_gamma)
+ {
+ library->fill(gamma_tensor, distribution, 4);
+ }
+ else
+ {
+ // Fill with default value 1
+ library->fill_tensor_value(gamma_tensor, static_cast<T>(1 << (_fractional_bits)));
+ }
}
}
@@ -97,7 +131,9 @@ protected:
// Create and configure function
FunctionType norm;
- norm.configure(&src, &dst, &mean, &var, &beta, &gamma, epsilon, act_info);
+ TensorType *beta_ptr = _use_beta ? &beta : nullptr;
+ TensorType *gamma_ptr = _use_gamma ? &gamma : nullptr;
+ norm.configure(&src, &dst, &mean, &var, beta_ptr, gamma_ptr, epsilon, act_info);
ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
@@ -149,6 +185,8 @@ protected:
SimpleTensor<T> _reference{};
int _fractional_bits{};
DataType _data_type{};
+ bool _use_beta{};
+ bool _use_gamma{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -156,9 +194,9 @@ class BatchNormalizationLayerValidationFixture : public BatchNormalizationLayerV
{
public:
template <typename...>
- void setup(TensorShape shape0, TensorShape shape1, float epsilon, ActivationLayerInfo act_info, DataType dt)
+ void setup(TensorShape shape0, TensorShape shape1, float epsilon, bool use_beta, bool use_gamma, ActivationLayerInfo act_info, DataType dt)
{
- BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, act_info, dt, 0);
+ BatchNormalizationLayerValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape0, shape1, epsilon, use_beta, use_gamma, act_info, dt, 0);
}
};
} // namespace validation