Diffstat (limited to 'tests/validation/fixtures/ConvolutionLayerFixture.h')
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h | 295
1 file changed, 293 insertions(+), 2 deletions(-)
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index bffdc59758..d3804ee371 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -28,6 +28,7 @@
#include "arm_compute/core/Types.h"
#include "arm_compute/graph/Utils.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/graph/mutators/MutatorUtils.h"
#include "tests/AssetsLibrary.h"
#include "tests/Globals.h"
@@ -121,14 +122,14 @@ protected:
{
case DataType::QASYMM8:
{
- std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<uint32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
}
case DataType::QASYMM8_SIGNED:
{
- std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
+ std::pair<int, int> bounds = get_quantized_qasymm8_signed_bounds(tensor.quantization_info(), -1.0f, 1.0f);
std::uniform_int_distribution<int32_t> distribution(bounds.first, bounds.second);
library->fill(tensor, distribution, i);
break;
@@ -397,6 +398,296 @@ public:
quantization_info, QuantizationInfo(weights_scales), act_info);
}
};
+
+#ifdef ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+inline TensorInfo prepare_weights(const TensorInfo tensor_info, const arm_gemm::WeightFormat weight_format)
+{
+ const DataLayout data_layout = tensor_info.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+ const DataType data_type = tensor_info.data_type();
+ const TensorShape tensor_shape = tensor_info.tensor_shape();
+ const int N = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+ const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const int C = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+
+ const int interleave_by = arm_gemm::interleave_by(weight_format);
+ const int block_by = arm_gemm::block_by(weight_format);
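+ // For example (values assumed, not tied to a specific kernel): with
+ // interleave_by=8 and block_by=4, C=7 pads up to Ip=8 and N=17 pads
+ // up to Op=24.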
+ const int Ip = arm_gemm::roundup<unsigned int>(C, block_by); // C'=I'
+ const int Op = arm_gemm::roundup<unsigned int>(N, interleave_by); // O'=N'
+
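+ // TensorShape lists dimensions innermost first, so (Ip, W, H, Op) is the
+ // padded NHWC shape with channels innermost.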
+ const TensorShape TS(Ip, W, H, Op);
+ return TensorInfo(TS, 1 /*num_channels*/, data_type, data_layout);
+}
+
+template <typename ScalarType, typename AccessorType>
+inline void rearrange_data(const AccessorType src, AccessorType dst, const arm_gemm::WeightFormat weight_format)
+{
+ ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(weight_format), framework::LogLevel::ERRORS);
+ // Data Layout: OHWIo<interleave_by>i<block_by>
+ const int interleave_by = arm_gemm::interleave_by(weight_format);
+ const int block_by = arm_gemm::block_by(weight_format);
+ const TensorShape src_tensor_shape = src.shape();
+ const DataLayout data_layout = src.data_layout();
+ ARM_COMPUTE_EXPECT(data_layout == DataLayout::NHWC, framework::LogLevel::ERRORS);
+ const unsigned int O = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O
+ const unsigned int H = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)];
+ const unsigned int W = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)];
+ const unsigned int I = src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I
+ const unsigned int Ip = arm_gemm::roundup<unsigned int>(I, block_by); // C'=I'
+ const unsigned int Op = arm_gemm::roundup<unsigned int>(O, interleave_by); // N'=O'
+
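+ // The destination may hold more elements than the source because both the
+ // input and output channel counts are padded.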
+ ARM_COMPUTE_EXPECT_EQUAL(Op * H * W * Ip, (unsigned)dst.num_elements(), framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS);
+
+ const ScalarType *src_ptr = reinterpret_cast<const ScalarType *>(src.data());
+ ScalarType *dst_ptr = reinterpret_cast<ScalarType *>(dst.data());
+ for(unsigned i = 0; i < I; ++i)
+ for(unsigned w = 0; w < W; ++w)
+ for(unsigned h = 0; h < H; ++h)
+ for(unsigned o = 0; o < O; ++o)
+ {
+ ScalarType src_element;
+ switch(data_layout)
+ {
+ case DataLayout::NHWC:
+ {
+ src_element = src_ptr[o * H * W * I + h * W * I + w * I + i];
+ }
+ break;
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported memory layout.");
+ }
+ }
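+ // Decompose the (o, i) coordinates for OHWIo<interleave_by>i<block_by>:
+ // o = x5 * interleave_by + x1 and i = x2 * block_by + x0, with x4 = h and
+ // x3 = w; (x5, x4, x3, x2, x1, x0) index from outermost to innermost.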
+ const int x5 = o / interleave_by;
+ const int x4 = h;
+ const int x3 = w;
+ const int x2 = i / block_by;
+ const int x1 = o % interleave_by;
+ const int x0 = i % block_by;
+ unsigned dst_idx = x5 * H * W * Ip * interleave_by
+ + x4 * W * Ip * interleave_by
+ + x3 * Ip * interleave_by
+ + x2 * interleave_by * block_by
+ + x1 * block_by
+ + x0;
+ dst_ptr[dst_idx] = src_element;
+ }
+}
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureBaseClass : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, PadStrideInfo info, Size2D dilation, DataLayout data_layout,
+ const DataType data_type)
+ {
+ conv = std::make_unique<ConvolutionFunction>();
+ // prepare data
+ _data_layout = data_layout;
+ // Fixed format kernels for variable weights can work only with NHWC format.
+ ARM_COMPUTE_EXPECT_EQUAL(_data_layout, DataLayout::NHWC, framework::LogLevel::ERRORS);
+ _data_type = data_type;
+ // Run the target and the reference computations.
+ compute_target(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+ compute_reference(input_shape, weights_shape, bias_shape, output_shape, info, dilation);
+ }
+ void teardown()
+ {
+ _target.allocator()->free();
+ }
+
+protected:
+ template <typename U>
+ void fill(U &&tensor, int i)
+ {
+ switch(tensor.data_type())
+ {
+ case DataType::F16:
+ {
+ arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ case DataType::F32:
+ {
+ std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
+ library->fill(tensor, distribution, i);
+ break;
+ }
+ default:
+ library->fill_tensor_uniform(tensor, i);
+ }
+ }
+
+private:
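+ // Implemented by the subclasses to run the convolution through either the
+ // operator-level (ITensorPack) or the function-level interface.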
+ virtual void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation) = 0;
+
+ void compute_target(TensorShape input_shape, TensorShape weights_shape, const TensorShape &bias_shape, TensorShape output_shape, const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ // The dataset is always in NCHW format - we need to make C the
+ // innermost dimension because the fixed-format kernels work only
+ // with NHWC layout.
+ permute(input_shape, PermutationVector(2U, 0U, 1U));
+ permute(weights_shape, PermutationVector(2U, 0U, 1U));
+ permute(output_shape, PermutationVector(2U, 0U, 1U));
+ const auto src_tensor_info = TensorInfo(input_shape, 1, _data_type, _data_layout);
+ const auto weight_tensor_info = TensorInfo(weights_shape, 1, _data_type, _data_layout);
+ const auto bias_tensor_info = TensorInfo(bias_shape, 1, _data_type, _data_layout);
+ auto dst_tensor_info = TensorInfo(output_shape, 1, _data_type, _data_layout);
+
+ const int kernel_height = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::HEIGHT)];
+ const int kernel_width = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::WIDTH)];
+ const int num_kernels = weights_shape[get_data_layout_dimension_index(_data_layout, DataLayoutDimension::BATCHES)];
+
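+ // Query the library with WeightFormat::ANY so that has_opt_impl() can
+ // return, in _computed_weight_format, the fixed format required by the
+ // kernel it selects.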
+ const WeightsInfo query_weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, arm_gemm::WeightFormat::ANY);
+ const bool kernel_found = bool(ConvolutionFunction::has_opt_impl(_computed_weight_format, &src_tensor_info, &weight_tensor_info,
+ &bias_tensor_info, &dst_tensor_info, conv_info, query_weights_info));
+ // Make sure that the setup finds a fixed-format kernel, as requested by the test case.
+ ARM_COMPUTE_EXPECT(kernel_found, framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(arm_gemm::is_fixed_format(_computed_weight_format), framework::LogLevel::ERRORS);
+
+ const WeightsInfo weights_info(/*reshape_weights*/ false, kernel_width, kernel_height, num_kernels, false, _computed_weight_format);
+ configure_and_execute_kernel(src_tensor_info, weight_tensor_info, bias_tensor_info, dst_tensor_info, weights_info, conv_info,
+ dilation);
+ }
+ void compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, const PadStrideInfo &info,
+ const Size2D &dilation)
+ {
+ ARM_COMPUTE_UNUSED(input_shape, weights_shape, bias_shape, output_shape, info,
+ dilation);
+
+ // Create reference
+ SimpleTensor<ScalarType> src{ input_shape, _data_type };
+ SimpleTensor<ScalarType> weights{ weights_shape, _data_type };
+ SimpleTensor<ScalarType> bias{ bias_shape, _data_type };
+ fill(src, 0);
+ fill(bias, 1);
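+ // Seed 3 matches the weights used in the second run of the target, whose
+ // output is the one that gets validated.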
+ fill(weights, 3);
+ _reference = reference::convolution_layer<ScalarType>(src, weights, bias, output_shape, info, dilation, 1 /*num_groups*/);
+ }
+ DataLayout _data_layout{};
+ DataType _data_type{};
+
+protected:
+ std::unique_ptr<ConvolutionFunction> conv{};
+ arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+ TensorClass _target{};
+ SimpleTensor<ScalarType> _reference{};
+};
+
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixture : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+ void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ this->conv->configure(&src_tensor_info, &weight_tensor_info, &bias_tensor_info, &dst_tensor_info, conv_info, weights_info, dilation);
+
+ // Allocate input tensors
+ auto src = create_tensor<TensorClass>(src_tensor_info);
+ auto weights_original = create_tensor<TensorClass>(weight_tensor_info);
+ const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+ auto weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+ auto bias = create_tensor<TensorClass>(bias_tensor_info);
+ src.allocator()->allocate();
+ weights_original.allocator()->allocate();
+ weights_transformed.allocator()->allocate();
+ bias.allocator()->allocate();
+ // Allocate destination tensor
+ this->_target = create_tensor<TensorClass>(dst_tensor_info);
+ this->_target.allocator()->allocate();
+
+ // Fill the source and bias tensors, which stay unchanged between the two runs.
+ this->fill(AccessorType(src), 0);
+ this->fill(AccessorType(bias), 1);
+
+ // First run
+ this->fill(AccessorType(weights_original), 2);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ ITensorPack run_pack{ { TensorType::ACL_SRC_0, &src }, { TensorType::ACL_SRC_1, &weights_transformed }, { TensorType::ACL_SRC_2, &bias }, { TensorType::ACL_DST, &(this->_target) } };
+ this->conv->run(run_pack);
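+ // The pack stores pointers to the tensors, so refreshing the contents of
+ // weights_transformed in place is enough for the next run.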
+ // Second run, with new weights
+ this->fill(AccessorType(weights_original), 3);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run(run_pack);
+ src.allocator()->free();
+ weights_original.allocator()->free();
+ weights_transformed.allocator()->free();
+ bias.allocator()->free();
+ }
+};
+
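+// Same flow as VariableWeightsFixture, but exercised through the function-level
+// interface: configure() takes the tensors themselves and run() takes no pack.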
+template <typename ConvolutionFunction, typename TensorClass, typename AccessorType, typename ScalarType>
+class VariableWeightsFixtureNEInterface : public VariableWeightsFixtureBaseClass<ConvolutionFunction, TensorClass, AccessorType, ScalarType>
+{
+ void configure_and_execute_kernel(TensorInfo src_tensor_info, TensorInfo weight_tensor_info, TensorInfo bias_tensor_info, TensorInfo dst_tensor_info, const WeightsInfo weights_info,
+ const PadStrideInfo &conv_info,
+ const Size2D &dilation)
+ {
+ // Allocate input tensors
+ auto src = create_tensor<TensorClass>(src_tensor_info);
+ auto weights_original = create_tensor<TensorClass>(weight_tensor_info);
+ const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info, this->_computed_weight_format);
+ auto weights_transformed = create_tensor<TensorClass>(new_tensor_info);
+ auto bias = create_tensor<TensorClass>(bias_tensor_info);
+ src.allocator()->allocate();
+ weights_original.allocator()->allocate();
+ weights_transformed.allocator()->allocate();
+ bias.allocator()->allocate();
+ // Allocate destination tensor
+ this->_target = create_tensor<TensorClass>(dst_tensor_info);
+ this->_target.allocator()->allocate();
+ this->conv->configure(&src, &weights_transformed, &bias, &(this->_target), conv_info, weights_info, dilation);
+ // Fill the source and bias tensors, which stay unchanged between the two runs.
+ this->fill(AccessorType(src), 0);
+ this->fill(AccessorType(bias), 1);
+
+ // First run
+ this->fill(AccessorType(weights_original), 2);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run();
+ // Second run, with new weights
+ this->fill(AccessorType(weights_original), 3);
+ rearrange_data<ScalarType, AccessorType>(AccessorType(weights_original), AccessorType(weights_transformed), this->_computed_weight_format);
+ this->conv->run();
+ src.allocator()->free();
+ weights_original.allocator()->free();
+ weights_transformed.allocator()->free();
+ bias.allocator()->free();
+ }
+};
+
+template <typename ConvolutionClass>
+class HasOptImplFixture : public framework::Fixture
+{
+public:
+ template <typename...>
+ void setup(DataType data_type, arm_gemm::WeightFormat query_weight_format)
+ {
+ auto conv = std::make_unique<ConvolutionClass>();
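+ // Small arbitrary shapes: only the has_opt_impl() query is exercised;
+ // nothing is configured or run.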
+ const auto src_info = TensorInfo(TensorShape(1U, 5U, 2U), 1, data_type, DataLayout::NHWC);
+ const auto weight_info = TensorInfo(TensorShape(1U, 3U, 2U, 3U), 1, data_type, DataLayout::NHWC);
+ const auto bias_info = TensorInfo(TensorShape(3U), 1, data_type, DataLayout::NHWC);
+ auto dst_info = TensorInfo(TensorShape(1U, 7U, 3U), 1, data_type, DataLayout::NHWC);
+ const auto conv_info = PadStrideInfo(1, 1, 0, 0, 2, 2, DimensionRoundingType::FLOOR);
+ const WeightsInfo weights_info(false, 3U, 3U, 1U, false, query_weight_format);
+ _kernel_found = bool(ConvolutionClass::has_opt_impl(_computed_weight_format, &src_info, &weight_info,
+ &bias_info, &dst_info, conv_info, weights_info));
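+ // The query results are recorded for the test cases to assert on.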
+ }
+
+protected:
+ bool _kernel_found{ false };
+ arm_gemm::WeightFormat _computed_weight_format{ arm_gemm::WeightFormat::UNSPECIFIED };
+};
+#endif // ARM_COMPUTE_ENABLE_FIXED_FORMAT_KERNELS
+
} // namespace validation
} // namespace test
} // namespace arm_compute