aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/fixtures/WinogradLayerFixture.h
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/fixtures/WinogradLayerFixture.h')
-rw-r--r--tests/validation/fixtures/WinogradLayerFixture.h120
1 files changed, 99 insertions, 21 deletions
diff --git a/tests/validation/fixtures/WinogradLayerFixture.h b/tests/validation/fixtures/WinogradLayerFixture.h
index bfe1efce3b..9811c28008 100644
--- a/tests/validation/fixtures/WinogradLayerFixture.h
+++ b/tests/validation/fixtures/WinogradLayerFixture.h
@@ -48,14 +48,14 @@ namespace validation
using namespace arm_compute::misc::shape_calculator;
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
-class WinogradLayerValidationFixture : public framework::Fixture
+class WinogradConvolutionLayerValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info)
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, DataType data_type)
{
- _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info);
- _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info);
+ _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
+ _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, data_type);
}
protected:
@@ -79,13 +79,14 @@ protected:
}
}
- TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info)
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+ DataType data_type)
{
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, DataType::F32, 1);
- TensorType weights = create_tensor<TensorType>(weights_shape, DataType::F32, 1);
- TensorType bias = create_tensor<TensorType>(bias_shape, DataType::F32, 1);
- TensorType dst = create_tensor<TensorType>(output_shape, DataType::F32, 1);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1);
+ TensorType weights = create_tensor<TensorType>(weights_shape, data_type, 1);
+ TensorType bias = create_tensor<TensorType>(bias_shape, data_type, 1);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);
// Create and configure function
FunctionType conv;
@@ -111,20 +112,20 @@ protected:
fill(AccessorType(src), 0, -1.f, 1.f);
fill(AccessorType(weights), 1, -1.f, 1.f);
fill(AccessorType(bias), 2, -1.f, 1.f);
- fill(AccessorType(dst), 3, -1.f, 1.f);
- // Compute NEWinogradLayer function
+ // Compute Winograd Convolution function
conv.run();
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info)
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+ DataType data_type)
{
// Create reference
- SimpleTensor<T> src{ input_shape, DataType::F32, 1 };
- SimpleTensor<T> weights{ weights_shape, DataType::F32, 1 };
- SimpleTensor<T> bias{ bias_shape, DataType::F32, 1 };
+ SimpleTensor<T> src{ input_shape, data_type, 1 };
+ SimpleTensor<T> weights{ weights_shape, data_type, 1 };
+ SimpleTensor<T> bias{ bias_shape, data_type, 1 };
// Fill reference
fill(src, 0, -1.f, 1.f);
@@ -136,8 +137,6 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
- int _fractional_bits{};
- DataType _data_type{};
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
@@ -178,7 +177,6 @@ protected:
{
ARM_COMPUTE_UNUSED(is_nchw_format);
- // Create tensors
TensorType src = create_tensor<TensorType>(input_shape, data_type);
TensorType dst = create_tensor<TensorType>(output_shape, data_type);
@@ -261,8 +259,8 @@ protected:
ARM_COMPUTE_UNUSED(is_nchw_format);
// Create tensors
- TensorType src = create_tensor<TensorType>(input_shape, data_type);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type);
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);
// Create and configure function
FunctionType filter_transform;
@@ -288,7 +286,7 @@ protected:
SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, bool is_nchw_format, DataType data_type)
{
- ARM_COMPUTE_ERROR_ON(!is_nchw_format);
+ ARM_COMPUTE_UNUSED(is_nchw_format);
// Create reference
SimpleTensor<T> src{ input_shape, data_type, 1 };
@@ -302,6 +300,86 @@ protected:
TensorType _target{};
SimpleTensor<T> _reference{};
};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
+class WinogradOutputTransformValidationFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, Size2D kernel_dims, Size2D output_convolved_dims, Size2D num_tiles, DataLayout data_layout, DataType data_type)
+ {
+ TensorShape output_shape = compute_winograd_output_transform_shape(TensorInfo(input_shape, 1, data_type), output_convolved_dims, data_layout);
+
+ _target = compute_target(input_shape, output_shape, kernel_dims, output_convolved_dims, num_tiles, data_layout, data_type);
+ _reference = compute_reference(input_shape, output_shape, kernel_dims, output_convolved_dims, num_tiles, data_layout, data_type);
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i, float min, float max)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<> distribution(min, max);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Not supported");
+ library->fill_tensor_uniform(tensor, i);
+ break;
+ }
+ }
+ }
+
+ TensorType compute_target(const TensorShape &input_shape, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &output_convolved_dims, Size2D &num_tiles, DataLayout data_layout,
+ DataType data_type)
+ {
+ // Create tensors
+ TensorType src = create_tensor<TensorType>(input_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, 0, QuantizationInfo(), data_layout);
+
+ // Create and configure function
+ FunctionType output_transform;
+ output_transform.configure(&src, nullptr, &dst, kernel_dims, output_convolved_dims, num_tiles);
+
+ ARM_COMPUTE_EXPECT(src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Allocate tensors
+ src.allocator()->allocate();
+ dst.allocator()->allocate();
+
+ ARM_COMPUTE_EXPECT(!src.info()->is_resizable(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(!dst.info()->is_resizable(), framework::LogLevel::ERRORS);
+
+ // Fill tensors
+ fill(AccessorType(src), 0, -1.f, 1.f);
+
+ output_transform.run();
+
+ return dst;
+ }
+
+ SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &output_shape, const Size2D &kernel_dims, const Size2D &output_convolved_dims, Size2D &num_tiles,
+ DataLayout data_layout,
+ DataType data_type)
+ {
+ // Create reference
+ SimpleTensor<T> src{ input_shape, data_type, 1, 0, QuantizationInfo(), data_layout };
+
+ // Fill reference
+ fill(src, 0, -1.f, 1.f);
+
+ return reference::winograd_output_transform<T>(src, output_shape, kernel_dims, num_tiles);
+ }
+
+ TensorType _target{};
+ SimpleTensor<T> _reference{};
+};
} // namespace validation
} // namespace test
} // namespace arm_compute