aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/NormalizationLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/NormalizationLayerFixture.h')
-rw-r--r--tests/validation/fixtures/NormalizationLayerFixture.h37
1 files changed, 13 insertions, 24 deletions
diff --git a/tests/validation/fixtures/NormalizationLayerFixture.h b/tests/validation/fixtures/NormalizationLayerFixture.h
index e7d83c7735..f4f9c64944 100644
--- a/tests/validation/fixtures/NormalizationLayerFixture.h
+++ b/tests/validation/fixtures/NormalizationLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017 ARM Limited.
+ * Copyright (c) 2017-2018 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -43,40 +43,30 @@ namespace test
namespace validation
{
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class NormalizationValidationFixedPointFixture : public framework::Fixture
+class NormalizationValidationGenericFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape, NormType norm_type, int norm_size, float beta, bool is_scaled, DataType data_type, int fractional_bits)
+ void setup(TensorShape shape, NormType norm_type, int norm_size, float beta, bool is_scaled, DataType data_type)
{
- _fractional_bits = fractional_bits;
NormalizationLayerInfo info(norm_type, norm_size, 5, beta, 1.f, is_scaled);
- _target = compute_target(shape, info, data_type, fractional_bits);
- _reference = compute_reference(shape, info, data_type, fractional_bits);
+ _target = compute_target(shape, info, data_type);
+ _reference = compute_reference(shape, info, data_type);
}
protected:
template <typename U>
void fill(U &&tensor)
{
- if(_fractional_bits == 0)
- {
- library->fill_tensor_uniform(tensor, 0);
- }
- else
- {
- const int one_fixed = 1 << _fractional_bits;
- std::uniform_int_distribution<> distribution(-one_fixed, one_fixed);
- library->fill(tensor, distribution, 0);
- }
+ library->fill_tensor_uniform(tensor, 0);
}
- TensorType compute_target(const TensorShape &shape, NormalizationLayerInfo info, DataType data_type, int fixed_point_position = 0)
+ TensorType compute_target(const TensorShape &shape, NormalizationLayerInfo info, DataType data_type)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(shape, data_type, 1, fixed_point_position);
- TensorType dst = create_tensor<TensorType>(shape, data_type, 1, fixed_point_position);
+ TensorType src = create_tensor<TensorType>(shape, data_type, 1);
+ TensorType dst = create_tensor<TensorType>(shape, data_type, 1);
// Create and configure function
FunctionType norm_layer;
@@ -101,10 +91,10 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape, NormalizationLayerInfo info, DataType data_type, int fixed_point_position = 0)
+ SimpleTensor<T> compute_reference(const TensorShape &shape, NormalizationLayerInfo info, DataType data_type)
{
// Create reference
- SimpleTensor<T> src{ shape, data_type, 1, fixed_point_position };
+ SimpleTensor<T> src{ shape, data_type, 1 };
// Fill reference
fill(src);
@@ -114,17 +104,16 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
- int _fractional_bits{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class NormalizationValidationFixture : public NormalizationValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>
+class NormalizationValidationFixture : public NormalizationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
{
public:
template <typename...>
void setup(TensorShape shape, NormType norm_type, int norm_size, float beta, bool is_scaled, DataType data_type)
{
- NormalizationValidationFixedPointFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, norm_type, norm_size, beta, is_scaled, data_type, 0);
+ NormalizationValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, norm_type, norm_size, beta, is_scaled, data_type);
}
};
} // namespace validation