diff options
Diffstat (limited to 'tests/validation/NEON/ConvolutionLayer.cpp')
-rw-r--r-- | tests/validation/NEON/ConvolutionLayer.cpp | 97 |
1 files changed, 94 insertions, 3 deletions
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp index 80615c5d57..112188fdfa 100644 --- a/tests/validation/NEON/ConvolutionLayer.cpp +++ b/tests/validation/NEON/ConvolutionLayer.cpp @@ -23,6 +23,7 @@ */ #include "arm_compute/core/Types.h" #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h" +#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h" #include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h" #include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h" #include "arm_compute/runtime/Tensor.h" @@ -45,6 +46,20 @@ namespace test { namespace validation { +namespace detail +{ +template <> +void configure_conv_function<NEGEMMConv2d, Tensor>(NEGEMMConv2d &func, + Tensor *src, const Tensor *weights, const Tensor *bias, Tensor *dst, + const PadStrideInfo &info, const WeightsInfo &weights_info, + const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups) +{ + ARM_COMPUTE_UNUSED(weights_info); + + Conv2dInfo conv_info(info, dilation, act_info, false, num_groups); + func.configure(src, weights, bias, dst, conv_info); +} +} // namespace detail namespace { const RelativeTolerance<float> rel_tolerance_f32(0.01f); /**< Relative tolerance for FP32 types */ @@ -368,7 +383,7 @@ TEST_SUITE_END() // WinogradLayer TEST_SUITE(GEMMConvolutionLayer) template <typename T> -using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T>; +using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T>; TEST_SUITE(Float) #if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) @@ -413,10 +428,10 @@ TEST_SUITE_END() // FP32 TEST_SUITE_END() // Float template <typename T> -using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T>; +using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>; template <typename T> -using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T, int8_t>; +using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEConvolutionLayer, T, int8_t>; const auto QuantizedActivationFunctionsDataset = framework::dataset::make("ActivationInfo", { @@ -480,6 +495,82 @@ TEST_SUITE_END() // QSYMM8_PER_CHANNEL TEST_SUITE_END() // Quantized TEST_SUITE_END() // GEMMConvolutionLayer + +TEST_SUITE(DirectGEMMConv2d) +template <typename T> +using NEDirectGEMMConv2dLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConv2d, T>; + +TEST_SUITE(Float) +TEST_SUITE(FP32) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::F32)), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + ActivationFunctionsDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32)); +} +TEST_SUITE_END() // FP32 +TEST_SUITE_END() // Float + +template <typename T> +using NEDirectGEMMConv2dLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEGEMMConv2d, T>; + +template <typename T> +using NEDirectGEMMConv2dLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEGEMMConv2d, T, int8_t>; + +const auto QuantizedActivationFunctionsDataset = framework::dataset::make("ActivationInfo", +{ + ActivationLayerInfo(), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU), + ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f) +}); +TEST_SUITE(Quantized) +TEST_SUITE(QASYMM8) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8)), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })), + QuantizedActivationFunctionsDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QASYMM8 + +TEST_SUITE(QASYMM8_SIGNED) +FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })), + QuantizedActivationFunctionsDataset)) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QASYMM8_SIGNED + +TEST_SUITE(QSYMM8_PER_CHANNEL) +FIXTURE_DATA_TEST_CASE(RunSmallSigned, NEDirectGEMMConv2dLayerQuantizedPerChannelFixture<int8_t>, framework::DatasetMode::ALL, + combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(), + framework::dataset::make("ReshapeWeights", { true })), + framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })), + framework::dataset::make("DataLayout", { DataLayout::NHWC })), + QuantizationData), + QuantizedActivationFunctionsDataset), + framework::dataset::make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL }))) +{ + // Validate output + validate(Accessor(_target), _reference, tolerance_qasymm8); +} +TEST_SUITE_END() // QSYMM8_PER_CHANNEL +TEST_SUITE_END() // Quantized + +TEST_SUITE_END() // DirectGEMMConv2d + TEST_SUITE_END() // NEON } // namespace validation } // namespace test |