aboutsummaryrefslogtreecommitdiff
path: root/tests/validation/NEON/ConvolutionLayer.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'tests/validation/NEON/ConvolutionLayer.cpp')
-rw-r--r--tests/validation/NEON/ConvolutionLayer.cpp97
1 files changed, 94 insertions, 3 deletions
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index 80615c5d57..112188fdfa 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -23,6 +23,7 @@
*/
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"
+#include "arm_compute/runtime/NEON/functions/NEGEMMConv2d.h"
#include "arm_compute/runtime/NEON/functions/NEGEMMConvolutionLayer.h"
#include "arm_compute/runtime/NEON/functions/NEWinogradConvolutionLayer.h"
#include "arm_compute/runtime/Tensor.h"
@@ -45,6 +46,20 @@ namespace test
{
namespace validation
{
+namespace detail
+{
+template <>
+void configure_conv_function<NEGEMMConv2d, Tensor>(NEGEMMConv2d &func,
+ Tensor *src, const Tensor *weights, const Tensor *bias, Tensor *dst,
+ const PadStrideInfo &info, const WeightsInfo &weights_info,
+ const Size2D &dilation, const ActivationLayerInfo &act_info, unsigned int num_groups)
+{
+ ARM_COMPUTE_UNUSED(weights_info);
+
+ Conv2dInfo conv_info(info, dilation, act_info, false, num_groups);
+ func.configure(src, weights, bias, dst, conv_info);
+}
+} // namespace detail
namespace
{
const RelativeTolerance<float> rel_tolerance_f32(0.01f); /**< Relative tolerance for FP32 types */
@@ -368,7 +383,7 @@ TEST_SUITE_END() // WinogradLayer
TEST_SUITE(GEMMConvolutionLayer)
template <typename T>
-using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T>;
+using NEGEMMConvolutionLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEConvolutionLayer, T>;
TEST_SUITE(Float)
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
@@ -413,10 +428,10 @@ TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
template <typename T>
-using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T>;
+using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
template <typename T>
-using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEGEMMConvolutionLayer, T, int8_t>;
+using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEConvolutionLayer, T, int8_t>;
const auto QuantizedActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
{
@@ -480,6 +495,82 @@ TEST_SUITE_END() // QSYMM8_PER_CHANNEL
TEST_SUITE_END() // Quantized
TEST_SUITE_END() // GEMMConvolutionLayer
+
+TEST_SUITE(DirectGEMMConv2d)
+template <typename T>
+using NEDirectGEMMConv2dLayerFixture = ConvolutionValidationFixture<Tensor, Accessor, NEGEMMConv2d, T>;
+
+TEST_SUITE(Float)
+TEST_SUITE(FP32)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerFixture<float>, framework::DatasetMode::ALL, combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ ActivationFunctionsDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, rel_tolerance_f32, 0.f, float(abs_tolerance_f32));
+}
+TEST_SUITE_END() // FP32
+TEST_SUITE_END() // Float
+
+template <typename T>
+using NEDirectGEMMConv2dLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEGEMMConv2d, T>;
+
+template <typename T>
+using NEDirectGEMMConv2dLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEGEMMConv2d, T, int8_t>;
+
+const auto QuantizedActivationFunctionsDataset = framework::dataset::make("ActivationInfo",
+{
+ ActivationLayerInfo(),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU),
+ ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU, 6.f)
+});
+TEST_SUITE(Quantized)
+TEST_SUITE(QASYMM8)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<uint8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(2.f / 255.f, 10) })),
+ QuantizedActivationFunctionsDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8
+
+TEST_SUITE(QASYMM8_SIGNED)
+FIXTURE_DATA_TEST_CASE(RunSmall, NEDirectGEMMConv2dLayerQuantizedFixture<int8_t>, framework::DatasetMode::ALL, combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ framework::dataset::make("QuantizationInfo", { QuantizationInfo(0.01f, -10) })),
+ QuantizedActivationFunctionsDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8_SIGNED
+
+TEST_SUITE(QSYMM8_PER_CHANNEL)
+FIXTURE_DATA_TEST_CASE(RunSmallSigned, NEDirectGEMMConv2dLayerQuantizedPerChannelFixture<int8_t>, framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", { true })),
+ framework::dataset::make("DataType", { DataType::QASYMM8_SIGNED })),
+ framework::dataset::make("DataLayout", { DataLayout::NHWC })),
+ QuantizationData),
+ QuantizedActivationFunctionsDataset),
+ framework::dataset::make("WeightsDataType", { DataType::QSYMM8_PER_CHANNEL })))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QSYMM8_PER_CHANNEL
+TEST_SUITE_END() // Quantized
+
+TEST_SUITE_END() // DirectGEMMConv2d
+
TEST_SUITE_END() // NEON
} // namespace validation
} // namespace test