Diffstat (limited to 'tests/validation')
-rw-r--r--  tests/validation/CMakeLists.txt                                  5
-rw-r--r--  tests/validation/NEON/ConvolutionLayer.cpp                     101
-rw-r--r--  tests/validation/NEON/GEMMLowp.cpp                              36
-rw-r--r--  tests/validation/fixtures/ConvolutionLayerFixture.h             27
-rw-r--r--  tests/validation/fixtures/DirectConvolution3DFixture.h           5
-rw-r--r--  tests/validation/fixtures/GEMMLowpFixture.h                     50
-rw-r--r--  tests/validation/reference/Conv3D.cpp                           24
-rw-r--r--  tests/validation/reference/Conv3D.h                             10
-rw-r--r--  tests/validation/runtime/experimental/operators/CpuGemm.cpp   143
9 files changed, 359 insertions(+), 42 deletions(-)
diff --git a/tests/validation/CMakeLists.txt b/tests/validation/CMakeLists.txt
index 448e96c4f9..b71787db60 100644
--- a/tests/validation/CMakeLists.txt
+++ b/tests/validation/CMakeLists.txt
@@ -1,4 +1,4 @@
-# Copyright (c) 2023 Arm Limited.
+# Copyright (c) 2024 Arm Limited.
#
# SPDX-License-Identifier: MIT
#
@@ -142,5 +142,6 @@ if(ENABLE_NEON)
NEON/UNIT/DynamicTensor.cpp
NEON/UNIT/TensorAllocator.cpp
NEON/UNIT/MemoryManager.cpp
- NEON/UNIT/RuntimeContext.cpp)
+ NEON/UNIT/RuntimeContext.cpp
+ runtime/experimental/operators/CpuGemm.cpp)
endif()
diff --git a/tests/validation/NEON/ConvolutionLayer.cpp b/tests/validation/NEON/ConvolutionLayer.cpp
index d739d4e1a4..7eada81ce5 100644
--- a/tests/validation/NEON/ConvolutionLayer.cpp
+++ b/tests/validation/NEON/ConvolutionLayer.cpp
@@ -147,6 +147,45 @@ const auto QuantizationData = make("QuantizationInfo",
TEST_SUITE(NEON)
TEST_SUITE(ConvolutionLayer)
+DATA_TEST_CASE(SupportedTypes, framework::DatasetMode::ALL, zip(
+ make("DataType", {
+ DataType::F32,
+ DataType::QASYMM8,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED
+ }),
+ make("WeightsDataType", {
+ DataType::F32,
+ DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED,
+ DataType::QASYMM8
+ }),
+ make("Expected",
+ {
+ true,
+ true,
+ true,
+ false
+ })),
+data_type_const, weights_data_type_const, expected_const)
+{
+ TensorInfo input_info = TensorInfo(TensorShape(3U, 3U, 1U), 1, data_type_const);
+ TensorInfo weights_info = TensorInfo(TensorShape(2U, 2U, 1U, 1U), 1, weights_data_type_const);
+ TensorInfo output_info = TensorInfo(TensorShape(2U, 2U, 1U), 1, data_type_const);
+
+ input_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+ weights_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+ output_info.set_quantization_info(arm_compute::QuantizationInfo(1, 0));
+
+ Status status = NEConvolutionLayer::validate(
+ &input_info,
+ &weights_info,
+ nullptr,
+ &output_info,
+ PadStrideInfo());
+
+ ARM_COMPUTE_EXPECT(bool(status) == expected_const, framework::LogLevel::ERRORS);
+}
// *INDENT-OFF*
// clang-format off
@@ -257,7 +296,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -303,7 +342,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -580,7 +619,7 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEWinogradConvolutionLayerFixture<float>, frame
/// It's enough to run the activations for a single weight/input combination and data type because
/// the activation function is called on top of the winograd output as a separate operator
-/// TODO: Enable after COMPMID-6573 is resolved
+/// TODO(COMPMID-6573): Enable after COMPMID-6573 is resolved
FIXTURE_DATA_TEST_CASE(RunActivations, NEWinogradConvolutionLayerFixture<float>, framework::DatasetMode::DISABLED,
combine(
make("Input", TensorShape(3U, 3U, 32U)),
@@ -1119,7 +1158,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1160,7 +1199,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1251,12 +1290,14 @@ FIXTURE_DATA_TEST_CASE(RunVeryLarge, NEGEMMConvolutionLayerFixture<float>, frame
TEST_SUITE_END() // FP32
TEST_SUITE_END() // Float
-// TODO: COMPMID-6596 Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
+// TODO(COMPMID-6573): Extend quantized tests with at least one suite where the weight is padded (the legacy case, see floating point's RunPaddedWeights)
template <typename T>
using NEGEMMConvolutionLayerQuantizedFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T>;
template <typename T>
using NEGEMMConvolutionLayerQuantizedMixedDataLayoutFixture = ConvolutionValidationQuantizedFixture<Tensor, Accessor, NEConvolutionLayer, T, true>;
+using NEGEMMConvolutionLayerQuantizedMixedSignFixture = ConvolutionValidationQuantizedMixedTypeFixture<Tensor, Accessor, NEConvolutionLayer, uint8_t, int8_t>;
+
template <typename T>
using NEGEMMConvolutionLayerQuantizedPerChannelFixture = ConvolutionValidationQuantizedPerChannelFixture<Tensor, Accessor, NEConvolutionLayer, T, int8_t>;
@@ -1332,6 +1373,50 @@ FIXTURE_DATA_TEST_CASE(RunMixedDataLayout, NEGEMMConvolutionLayerQuantizedFixtur
}
TEST_SUITE_END() // QASYMM8_SIGNED
+TEST_SUITE(QASYMM8_MIXED)
+FIXTURE_DATA_TEST_CASE(
+ RunSmall,
+ NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+ framework::DatasetMode::ALL,
+ combine(combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
+ framework::dataset::make("ReshapeWeights", {true})),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC})),
+ framework::dataset::make("QuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+framework::dataset::make("WeightQuantizationInfoIfActivationEnabled",
+{QuantizationInfo(2.f / 255.f, 10)})),
+QuantizedActivationFunctionsDataset))
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+FIXTURE_DATA_TEST_CASE(
+ RunMixedDataLayout,
+ NEGEMMConvolutionLayerQuantizedMixedSignFixture,
+ framework::DatasetMode::ALL,
+ combine(
+ framework::dataset::make("Input", TensorShape(23U, 27U, 5U)),
+ framework::dataset::make("Weights", TensorShape(3U, 3U, 5U, 2U)),
+ framework::dataset::make("Bias", TensorShape(2U)),
+ framework::dataset::make("Output", TensorShape(11U, 25U, 2U)),
+ framework::dataset::make("PadStrideInfo", PadStrideInfo(2, 1, 0, 0)),
+ framework::dataset::make("Dilation", Size2D(1, 1)),
+ framework::dataset::make("ReshapeWeights", {true}),
+ framework::dataset::make("DataType", DataType::QASYMM8),
+ framework::dataset::make("WeightsDataType", DataType::QASYMM8_SIGNED),
+ framework::dataset::make("DataLayout", {DataLayout::NCHW, DataLayout::NHWC}),
+ framework::dataset::make("QuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+ framework::dataset::make("WeightQuantizationInfoIfActivationEnabled", {QuantizationInfo(2.f / 255.f, 10)}),
+ QuantizedActivationFunctionsDataset)
+ )
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+TEST_SUITE_END() // QASYMM8_MIXED
+
TEST_SUITE(QSYMM8_PER_CHANNEL)
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMConvolutionLayerQuantizedPerChannelFixture<uint8_t>, framework::DatasetMode::ALL,
combine(combine(combine(combine(combine(combine(datasets::SmallConvolutionLayerDataset(),
@@ -1436,7 +1521,7 @@ TEST_CASE(MemoryInjection, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
@@ -1476,7 +1561,7 @@ TEST_CASE(MultipleExecutionWithConfigure, framework::DatasetMode::ALL)
auto result_1 = run_conv();
for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
{
- ARM_COMPUTE_EXPECT(((float *)result_0.buffer())[i] == ((float *)result_1.buffer())[i], framework::LogLevel::ERRORS);
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
}
}
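
The SupportedTypes case added above is a validate-only check: it never allocates or configures anything. A minimal standalone sketch of that pattern (not part of the patch; assumes the usual ACL runtime headers), showing the mixed-sign combination the patch enables, QASYMM8 input with QASYMM8_SIGNED weights:

    #include "arm_compute/runtime/NEON/functions/NEConvolutionLayer.h"

    using namespace arm_compute;

    // Returns true when the data-type combination is supported; no tensors are allocated.
    bool mixed_sign_conv_supported()
    {
        TensorInfo input(TensorShape(3U, 3U, 1U), 1, DataType::QASYMM8);
        TensorInfo weights(TensorShape(2U, 2U, 1U, 1U), 1, DataType::QASYMM8_SIGNED);
        TensorInfo output(TensorShape(2U, 2U, 1U), 1, DataType::QASYMM8);

        input.set_quantization_info(QuantizationInfo(1, 0));
        weights.set_quantization_info(QuantizationInfo(1, 0));
        output.set_quantization_info(QuantizationInfo(1, 0));

        // validate() returns a Status that converts to bool, so it can gate configure().
        return bool(NEConvolutionLayer::validate(&input, &weights, nullptr, &output, PadStrideInfo()));
    }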
diff --git a/tests/validation/NEON/GEMMLowp.cpp b/tests/validation/NEON/GEMMLowp.cpp
index d25f43a330..61202ee2b7 100644
--- a/tests/validation/NEON/GEMMLowp.cpp
+++ b/tests/validation/NEON/GEMMLowp.cpp
@@ -141,20 +141,23 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
TensorInfo(TensorShape(20U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(21U, 13U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)), // Invalid dimensions
TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
}),
make("InputBInfo",{ TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(33U, 21U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
}),
make("OutputInfo",{ TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(33U, 13U), 1, DataType::S32),
TensorInfo(TensorShape(8U, 11U), 1, DataType::S32),
TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::S32),
}),
- make("Expected", { true, false, false, false, true })),
+ make("Expected", { true, false, false, false, true, false })),
a_info, b_info, output_info, expected)
{
// Lock tensors
@@ -359,10 +362,39 @@ TEST_SUITE_END() // DynamicQuantization
#ifdef __aarch64__
// Dequant tests involve returning F32 from the MatrixMultiplyCore kernels and are only implemented in aarch64
TEST_SUITE(Dequant)
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(
+ make("InputAInfo", {
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)),
+ TensorInfo(TensorShape(16U, 32U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/255, 10)), // Invalid types
+ }),
+ make("InputBInfo",{
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f/256, 10)),
+ TensorInfo(TensorShape(64U, 16U), 1, DataType::QASYMM8, QuantizationInfo(1.f/256, 10)),
+ }),
+ make("OutputInfo",{
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ TensorInfo(TensorShape(64U, 32U), 1, DataType::F32),
+ }),
+ make("Expected", { true, true, false })),
+ a_info, b_info, output_info, expected)
+{
+ // Lock tensors
+ Status status = NEGEMMLowpMatrixMultiplyCore::validate(&a_info.clone()->set_is_resizable(false),
+ &b_info.clone()->set_is_resizable(false),
+ nullptr,
+ &output_info.clone()->set_is_resizable(false));
+ ARM_COMPUTE_EXPECT(bool(status) == expected, framework::LogLevel::ERRORS);
+}
+
constexpr AbsoluteTolerance<float> tolerance_dequantized(0.01f);
FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::ALL,
combine(
datasets::SmallGEMMLowpDataset(),
+ make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+ make("DataTypeB", DataType::QASYMM8_SIGNED),
make("accumulate", {true, false})
))
{
@@ -373,6 +405,8 @@ FIXTURE_DATA_TEST_CASE(RunSmall, NEGEMMLowpDequantizedMatrixMultiplyValidationFi
FIXTURE_DATA_TEST_CASE(RunLarge, NEGEMMLowpDequantizedMatrixMultiplyValidationFixture, framework::DatasetMode::NIGHTLY,
combine(
datasets::LargeGEMMLowpDataset(),
+ make("DataTypeA", {DataType::QASYMM8_SIGNED, DataType::QASYMM8}),
+ make("DataTypeB", DataType::QASYMM8_SIGNED),
make("accumulate", {false})
))
{
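
The new Dequant Validate rows encode which signedness pairings NEGEMMLowpMatrixMultiplyCore accepts when dequantizing to F32: u8×s8 and s8×s8 pass, s8×u8 is rejected. A sketch of the passing mixed case (not part of the patch; header path taken from the ACL runtime API):

    #include "arm_compute/runtime/NEON/functions/NEGEMMLowpMatrixMultiplyCore.h"

    using namespace arm_compute;

    // Mirrors the first Dequant/Validate row: QASYMM8 LHS, QASYMM8_SIGNED RHS, F32 output.
    Status validate_u8_s8_dequant()
    {
        const TensorInfo a(TensorShape(16U, 32U), 1, DataType::QASYMM8, QuantizationInfo(1.f / 255, 10));
        const TensorInfo b(TensorShape(64U, 16U), 1, DataType::QASYMM8_SIGNED, QuantizationInfo(1.f / 256, 10));
        const TensorInfo dst(TensorShape(64U, 32U), 1, DataType::F32);

        // A null bias matches the nullptr passed in the test above.
        return NEGEMMLowpMatrixMultiplyCore::validate(&a, &b, nullptr, &dst);
    }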
diff --git a/tests/validation/fixtures/ConvolutionLayerFixture.h b/tests/validation/fixtures/ConvolutionLayerFixture.h
index 0622e5e6f0..939ac032cd 100644
--- a/tests/validation/fixtures/ConvolutionLayerFixture.h
+++ b/tests/validation/fixtures/ConvolutionLayerFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2023 Arm Limited.
+ * Copyright (c) 2017-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -480,6 +480,31 @@ public:
};
template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
+class ConvolutionValidationQuantizedMixedTypeFixture
+ : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
+{
+public:
+ void setup(TensorShape input_shape,
+ TensorShape weights_shape,
+ TensorShape bias_shape,
+ TensorShape output_shape,
+ PadStrideInfo info,
+ Size2D dilation,
+ bool reshape_weights,
+ DataType data_type,
+ DataType weights_data_type,
+ DataLayout data_layout,
+ QuantizationInfo quantization_info,
+ QuantizationInfo weight_quantization_info,
+ ActivationLayerInfo act_info)
+ {
+ ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>::setup(
+ input_shape, weights_shape, bias_shape, output_shape, info, dilation, reshape_weights, data_type,
+ weights_data_type, data_layout, quantization_info, weight_quantization_info, act_info);
+ }
+};
+
+template <typename TensorType, typename AccessorType, typename FunctionType, typename T, typename TW>
class ConvolutionValidationQuantizedPerChannelFixture : public ConvolutionValidationGenericFixture<TensorType, AccessorType, FunctionType, T, TW>
{
public:
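
The fixture itself only forwards to the generic setup; its purpose is to split the element type T (inputs/outputs) from TW (weights) at the type level. Usage sketch, restating the alias the patch adds in ConvolutionLayer.cpp with the template roles spelled out:

    // T = uint8_t carries DataType::QASYMM8 inputs/outputs; TW = int8_t carries
    // DataType::QASYMM8_SIGNED weights.
    using NEGEMMConvolutionLayerQuantizedMixedSignFixture =
        ConvolutionValidationQuantizedMixedTypeFixture<Tensor, Accessor, NEConvolutionLayer,
                                                       uint8_t /* T */, int8_t /* TW */>;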
diff --git a/tests/validation/fixtures/DirectConvolution3DFixture.h b/tests/validation/fixtures/DirectConvolution3DFixture.h
index e80ad2f54f..e27a41a23b 100644
--- a/tests/validation/fixtures/DirectConvolution3DFixture.h
+++ b/tests/validation/fixtures/DirectConvolution3DFixture.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -46,6 +46,7 @@ class DirectConvolution3DValidationGenericFixture : public framework::Fixture
{
public:
using TBias = typename std::conditional < std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value, int32_t, T >::type;
+ using TAcc = typename std::conditional < std::is_integral<T>::value, int32_t, float >::type;
void setup(const TensorShape &input_shape, int stride_x, int stride_y, int stride_z, int pad_x, int pad_y, int pad_z, unsigned int kernel_width, int kernel_height, int kernel_depth,
unsigned int num_kernels, bool has_bias, const ActivationLayerInfo &act_info, const DataType &data_type, const DataLayout &data_layout,
@@ -150,7 +151,7 @@ protected:
fill(bias, 2);
}
- return reference::activation_layer(reference::conv3d<T, TBias>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
+ return reference::activation_layer(reference::conv3d<T, TBias, TAcc>(src, weights, bias, dst, conv3d_info), conv3d_info.act_info);
}
TensorType _target{};
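
TAcc picks the accumulator type at compile time: int32_t for the integral (quantized) element types, float otherwise. A self-contained sketch of the same std::conditional selection (illustrative alias name, not from the patch):

    #include <cstdint>
    #include <type_traits>

    template <typename T>
    using AccumulatorOf = typename std::conditional<std::is_integral<T>::value, int32_t, float>::type;

    static_assert(std::is_same<AccumulatorOf<uint8_t>, int32_t>::value, "quantized: accumulate in int32_t");
    static_assert(std::is_same<AccumulatorOf<float>, float>::value, "floating point: accumulate in float");

ACL's half is a class type, so std::is_integral rejects it and fp16 accumulates in float, matching the conv3d<half, half, float> instantiation in Conv3D.cpp below.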
diff --git a/tests/validation/fixtures/GEMMLowpFixture.h b/tests/validation/fixtures/GEMMLowpFixture.h
index aa4eedb75d..7931d8467d 100644
--- a/tests/validation/fixtures/GEMMLowpFixture.h
+++ b/tests/validation/fixtures/GEMMLowpFixture.h
@@ -97,8 +97,7 @@ TensorType compute_gemmlowp_target(const TensorShape &shape_a, const TensorShape
bool accumulate = false, bool dynamic_qinfo = false, DataType data_type_output = DataType::UNKNOWN)
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
- // If unknown, set to sensible defaults
+ // If unknown, set to sensible defaults
if (data_type_output == DataType::UNKNOWN) {
data_type_output = output_stage.type == GEMMLowpOutputStageType::NONE ? DataType::S32 : data_type_a;
}
@@ -185,7 +184,6 @@ SimpleTensor<int32_t> compute_gemmlowp_reference(const TensorShape &shape_a, con
DataType data_type_a = DataType::QASYMM8, DataType data_type_b = DataType::QASYMM8, const TensorFillInfo& finfo = TensorFillInfo())
{
ARM_COMPUTE_ASSERT(is_data_type_quantized_asymmetric(data_type_a));
- ARM_COMPUTE_ASSERT(data_type_a == data_type_b);
TensorShape shape_a_to_use = shape_a;
if(reinterpret_input_as_3d)
{
@@ -472,29 +470,59 @@ template <typename TensorType, typename AccessorType, typename FunctionType, boo
class GEMMLowpDequantizedMatrixMultiplyValidationFixture : public framework::Fixture
{
public:
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, bool accumulate)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape shape_output, int32_t a_offset, int32_t b_offset, DataType data_type_a, DataType data_type_b, bool accumulate)
{
const bool dynamic_qinfo = false;
const auto a_qinfo = QuantizationInfo(1.0f / 255, a_offset);
const auto b_qinfo = QuantizationInfo(5.0f / 255, b_offset);
TensorFillInfo finfo;
- _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
- _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, finfo, accumulate, dynamic_qinfo);
+ _target = compute_target(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo,
+ accumulate, dynamic_qinfo);
+ _reference = compute_reference(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b,
+ finfo, accumulate, dynamic_qinfo);
}
protected:
- TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
+ TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, const bool accumulate, const bool dynamic_qinfo)
{
const auto output_qinfo = QuantizationInfo();
- return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
+ return compute_gemmlowp_target<TensorType, AccessorType, FunctionType, reinterpret_input_as_3d, reinterpret_output_as_3d, int32_t, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo, output_qinfo, data_type_a, data_type_b, GEMMLowpOutputStageInfo(), false, finfo, accumulate, dynamic_qinfo, DataType::F32);
}
- SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
+ SimpleTensor<float> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &shape_output, const QuantizationInfo& a_qinfo, const QuantizationInfo& b_qinfo, DataType data_type_a, DataType data_type_b, const TensorFillInfo& finfo, bool accumulate, const bool dynamic_qinfo)
{
QuantizationInfo s32_ref_output_quant_info = QuantizationInfo(a_qinfo.uniform().scale * b_qinfo.uniform().scale, 0, dynamic_qinfo);
- SimpleTensor<int32_t> s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(shape_a, shape_b, shape_output, a_qinfo, b_qinfo,
- DataType::QASYMM8_SIGNED, DataType::QASYMM8_SIGNED, finfo);
+ SimpleTensor<int32_t> s32_ref_output;
+ if (data_type_a == DataType::QASYMM8)
+ {
+ if (data_type_b == DataType::QASYMM8)
+ {
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, uint8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, uint8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_a != DataType::QASYMM8_SIGNED);
+ if (data_type_b == DataType::QASYMM8)
+ {
+ ARM_COMPUTE_ERROR("QASYMM8_SIGNED input with QASYMM8 weights not supported");
+ }
+ else
+ {
+ ARM_COMPUTE_ERROR_ON(data_type_b != DataType::QASYMM8_SIGNED);
+ s32_ref_output = compute_gemmlowp_reference<reinterpret_input_as_3d, int8_t, int8_t, false, false, run_twice>(
+ shape_a, shape_b, shape_output, a_qinfo, b_qinfo, data_type_a, data_type_b, finfo);
+ }
+ }
+
s32_ref_output.quantization_info(s32_ref_output_quant_info);
SimpleTensor<float> f32_ref_output(s32_ref_output.shape(), DataType::F32);
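
The nested branches above are a runtime-to-compile-time dispatch: each supported (data_type_a, data_type_b) pair selects a distinct instantiation of compute_gemmlowp_reference. As a table (comment form, restating the code):

    // data_type_a      data_type_b      reference instantiation
    // QASYMM8          QASYMM8          <..., uint8_t, uint8_t, ...>
    // QASYMM8          QASYMM8_SIGNED   <..., uint8_t, int8_t,  ...>
    // QASYMM8_SIGNED   QASYMM8          ARM_COMPUTE_ERROR (unsupported pairing)
    // QASYMM8_SIGNED   QASYMM8_SIGNED   <..., int8_t,  int8_t,  ...>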
diff --git a/tests/validation/reference/Conv3D.cpp b/tests/validation/reference/Conv3D.cpp
index e4010a507a..38472a9aec 100644
--- a/tests/validation/reference/Conv3D.cpp
+++ b/tests/validation/reference/Conv3D.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021, 2023 Arm Limited.
+ * Copyright (c) 2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -58,7 +58,7 @@ inline bool is_valid_pixel(int i, int min, int max)
}
// Evaluate the weights against an element in a given tensor.
-template < typename T, typename TB, typename std::enable_if < validation::is_floating_point<T>::value &&validation::is_floating_point<TB>::value, int >::type = 0 >
+template < typename T, typename TB, typename TACC, typename std::enable_if < validation::is_floating_point<T>::value &&validation::is_floating_point<TB>::value, int >::type = 0 >
T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
{
@@ -73,7 +73,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const unsigned int src_height = src.shape()[height_dim];
const unsigned int src_depth = src.shape()[depth_dim];
- T total(0);
+ TACC total(0);
for(unsigned int weight_d = 0; weight_d < weights_depth; ++weight_d)
{
const int idx_z = z_start + dilation.depth * weight_d;
@@ -112,10 +112,10 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const TB *b_ptr = bias.data();
TB bias_value = b_ptr[ch_out];
- return total + bias_value;
+ return static_cast<T>(total) + bias_value;
}
-template < typename T, typename TB, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
+template < typename T, typename TB, typename TACC, ARM_COMPUTE_REQUIRES_TA(std::is_same<T, uint8_t>::value || std::is_same<T, int8_t>::value) >
T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, const Size3D &dilation, int batch,
int z_start, int y_start, int x_start, int ch_out, UniformQuantizationInfo oq_info)
{
@@ -143,7 +143,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
const float multiplier = input_scale * weights_scale / output_scale;
arm_compute::quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
- int32_t total(0);
+ TACC total(0);
for(unsigned int weight_d = 0; weight_d < weights_depth; ++weight_d)
{
const int idx_z = z_start + dilation.depth * weight_d;
@@ -189,7 +189,7 @@ T calculate_conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, c
}
} // namespace
-template <typename T, typename TB>
+template <typename T, typename TB, typename TACC = T>
SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst, const Conv3dInfo &conv3d_info)
{
// Compute reference
@@ -237,7 +237,7 @@ SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weight
T *out_ptr = dst.data();
const int out_offset = coord2index(dst.shape(), Coordinates{ ch_out, x_out, y_out, z_out, batch });
- out_ptr[out_offset] = calculate_conv3d<T, TB>(src, weights, bias, conv3d_info.dilation, batch, z_start, y_start, x_start, ch_out, dst.quantization_info().uniform());
+ out_ptr[out_offset] = calculate_conv3d<T, TB, TACC>(src, weights, bias, conv3d_info.dilation, batch, z_start, y_start, x_start, ch_out, dst.quantization_info().uniform());
}
}
}
@@ -246,13 +246,13 @@ SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weight
return dst;
}
-template SimpleTensor<float> conv3d(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, SimpleTensor<float> &dst,
+template SimpleTensor<float> conv3d<float, float, float>(const SimpleTensor<float> &src, const SimpleTensor<float> &weights, const SimpleTensor<float> &bias, SimpleTensor<float> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<half> conv3d(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, SimpleTensor<half> &dst,
+template SimpleTensor<half> conv3d<half, half, float>(const SimpleTensor<half> &src, const SimpleTensor<half> &weights, const SimpleTensor<half> &bias, SimpleTensor<half> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<uint8_t> conv3d(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<uint8_t> &dst,
+template SimpleTensor<uint8_t> conv3d<uint8_t, int32_t, int32_t>(const SimpleTensor<uint8_t> &src, const SimpleTensor<uint8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<uint8_t> &dst,
const Conv3dInfo &conv3d_info);
-template SimpleTensor<int8_t> conv3d(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<int8_t> &dst,
+template SimpleTensor<int8_t> conv3d<int8_t, int32_t, int32_t>(const SimpleTensor<int8_t> &src, const SimpleTensor<int8_t> &weights, const SimpleTensor<int32_t> &bias, SimpleTensor<int8_t> &dst,
const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation
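
The explicit instantiations pin an accumulator wider than the element type: float for half, int32_t for the quantized types. A self-contained sketch of why the reduction is held in TACC and narrowed only once at the end (illustrative helper, not ACL API):

    #include <cstdint>

    // Accumulate n products in TACC, then narrow back to T a single time.
    // dot<int8_t, int32_t> matches the quantized instantiations above (TACC = int32_t):
    // accumulating int8 products directly in int8_t would overflow after a couple of taps.
    template <typename T, typename TACC>
    T dot(const T *x, const T *w, int n)
    {
        TACC total(0);
        for (int i = 0; i < n; ++i)
        {
            total += static_cast<TACC>(x[i]) * static_cast<TACC>(w[i]);
        }
        return static_cast<T>(total);
    }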
diff --git a/tests/validation/reference/Conv3D.h b/tests/validation/reference/Conv3D.h
index e3674f4bfb..a440b15d55 100644
--- a/tests/validation/reference/Conv3D.h
+++ b/tests/validation/reference/Conv3D.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021 Arm Limited.
+ * Copyright (c) 2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_TEST_CONV3D_LAYER_H
-#define ARM_COMPUTE_TEST_CONV3D_LAYER_H
+#ifndef ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H
+#define ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H
#include "Utils.h"
#include "arm_compute/runtime/FunctionDescriptors.h"
@@ -37,11 +37,11 @@ namespace validation
{
namespace reference
{
-template <typename T, typename TB>
+template <typename T, typename TB, typename TACC>
SimpleTensor<T> conv3d(const SimpleTensor<T> &src, const SimpleTensor<T> &weights, const SimpleTensor<TB> &bias, SimpleTensor<T> &dst,
const Conv3dInfo &conv3d_info);
} // namespace reference
} // namespace validation
} // namespace test
} // namespace arm_compute
-#endif /* ARM_COMPUTE_TEST_CONV3D_LAYER_H */
+#endif // ACL_TESTS_VALIDATION_REFERENCE_CONV3D_H
diff --git a/tests/validation/runtime/experimental/operators/CpuGemm.cpp b/tests/validation/runtime/experimental/operators/CpuGemm.cpp
new file mode 100644
index 0000000000..c6df429a4d
--- /dev/null
+++ b/tests/validation/runtime/experimental/operators/CpuGemm.cpp
@@ -0,0 +1,143 @@
+/*
+ * Copyright (c) 2024 Arm Limited.
+ *
+ * SPDX-License-Identifier: MIT
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a copy
+ * of this software and associated documentation files (the "Software"), to
+ * deal in the Software without restriction, including without limitation the
+ * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
+ * sell copies of the Software, and to permit persons to whom the Software is
+ * furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice shall be included in all
+ * copies or substantial portions of the Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ * SOFTWARE.
+ */
+#include "arm_compute/runtime/experimental/operators/CpuGemm.h"
+#include "src/core/helpers/MemoryHelpers.h"
+#include "tests/NEON/Accessor.h"
+#include "tests/framework/Macros.h"
+#include "tests/framework/datasets/Datasets.h"
+#include "tests/validation/fixtures/GEMMFixture.h"
+
+/*
+ * Tests for arm_compute::experimental::ops::CpuGemm, which is a shallow wrapper for
+ * arm_compute::cpu::CpuGemm. Any future testing of cpu::CpuGemm functionality should
+ * be added in tests/NEON/GEMM.cpp, as long as ops::CpuGemm remains a shallow wrapper.
+*/
+
+namespace arm_compute
+{
+namespace test
+{
+namespace validation
+{
+using framework::dataset::make;
+
+namespace
+{
+/** CNN data types */
+const auto CNNDataTypes = make("DataType",
+{
+ DataType::F32,
+});
+} // namespace
+
+TEST_SUITE(NEON)
+TEST_SUITE(OPERATORS)
+
+TEST_SUITE(CPUGEMM)
+/** Test case for memory injection in @ref arm_compute::experimental::ops::CpuGemm.
+ *
+ * Configure the operator once and inject memory at run-time in multiple executions.
+ *
+ * Checks performed in order:
+ * - Both runs compute the same output
+ */
+TEST_CASE(OpsCpuGemmMemoryInjection, framework::DatasetMode::ALL)
+{
+ auto gemm = std::make_unique<arm_compute::experimental::ops::CpuGemm>();
+ const auto lhs_info = TensorInfo(TensorShape(3U, 3U), 1, DataType::F32);
+ const auto rhs_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto c_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ auto dst_info = TensorInfo(TensorShape(4U, 3U), 1, DataType::F32);
+ const auto gemm_info = GEMMInfo{};
+ gemm->configure(&lhs_info, &rhs_info, &c_info, &dst_info, 1.f, 1.f, gemm_info);
+
+ // tensors are newly created on every call of this lambda function
+ auto lhs = create_tensor<Tensor>(lhs_info);
+ auto rhs = create_tensor<Tensor>(rhs_info);
+ auto c = create_tensor<Tensor>(c_info);
+ lhs.allocator()->allocate();
+ rhs.allocator()->allocate();
+ c.allocator()->allocate();
+
+ ITensorPack run_pack{ { TensorType::ACL_SRC_0, &lhs }, { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
+ ITensorPack prep_pack{ { TensorType::ACL_SRC_1, &rhs }, { TensorType::ACL_SRC_2, &c } };
+
+ auto mg = MemoryGroup{};
+ auto ws = manage_workspace<Tensor>(gemm->workspace(), mg, run_pack, prep_pack);
+
+ auto run_conv = [&]() -> Tensor
+ {
+ auto dst = create_tensor<Tensor>(dst_info);
+ dst.allocator()->allocate();
+ run_pack.add_tensor(TensorType::ACL_DST, &dst);
+
+ library->fill_tensor_value(Accessor(lhs), 1.f);
+ library->fill_tensor_value(Accessor(rhs), 2.f);
+ library->fill_tensor_value(Accessor(c), 3.f);
+ // This operator is configured once and captured by this lambda.
+ gemm->prepare(prep_pack);
+ gemm->run(run_pack);
+ return dst;
+ };
+ auto result_0 = run_conv();
+ auto result_1 = run_conv();
+ for(size_t i = 0; i < result_0.info()->tensor_shape().total_size(); ++i)
+ {
+ ARM_COMPUTE_EXPECT(reinterpret_cast<float *>(result_0.buffer())[i] == reinterpret_cast<float *>(result_1.buffer())[i], framework::LogLevel::ERRORS);
+ }
+}
+
+DATA_TEST_CASE(OpsCpuGemmValidateAccumulate, framework::DatasetMode::ALL, combine(
+ zip(make("In0",{ TensorShape(21U, 13U) }),
+ make("In1", { TensorShape(33U, 21U) }),
+ make("Dst", { TensorShape(33U, 13U) })),
+ zip(
+ make("alpha", { 1.0, 100.0, 1.0, 1.0 }),
+ make("beta", { 0.0, 0.0, 1.0, 1.0 }),
+ make("is_c_null", { false, false, false, true }),
+ make("Expected", { true, false, false, true }))),
+ shape_a, shape_b, shape_dst, alpha, beta, is_c_null, expected)
+{
+ /* Accumulation test for GEMM kernels */
+ // Create tensors
+ TensorInfo in_a(shape_a, 1, DataType::F32);
+ TensorInfo in_b(shape_b, 1, DataType::F32);
+ TensorInfo in_c(shape_dst, 1, DataType::F32);
+ TensorInfo dst(shape_dst, 1, DataType::F32);
+
+ GEMMInfo gemm_info = GEMMInfo();
+ gemm_info.set_accumulate(true);
+
+ // Validate accumulation
+ arm_compute::experimental::ops::CpuGemm gemm;
+ Status status = gemm.validate(&in_a, &in_b, (is_c_null ? nullptr : &in_c), &dst, alpha, beta, gemm_info);
+ ARM_COMPUTE_EXPECT((expected == bool(status)), framework::LogLevel::ERRORS);
+}
+
+TEST_SUITE_END() // CPUGEMM
+TEST_SUITE_END() // OPERATORS
+TEST_SUITE_END() // NEON
+} // namespace validation
+} // namespace test
+} // namespace arm_compute
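
For reference, the workflow the CpuGemm tests exercise, condensed into one function (a sketch assuming the same headers and arm_compute namespace as the test file; weights rhs and bias c are treated as constant inputs):

    void run_gemm_once(Tensor &lhs, Tensor &rhs, Tensor &c, Tensor &dst)
    {
        experimental::ops::CpuGemm gemm;
        gemm.configure(lhs.info(), rhs.info(), c.info(), dst.info(), 1.f, 1.f, GEMMInfo{});

        ITensorPack run_pack{{TensorType::ACL_SRC_0, &lhs},
                             {TensorType::ACL_SRC_1, &rhs},
                             {TensorType::ACL_SRC_2, &c},
                             {TensorType::ACL_DST, &dst}};
        ITensorPack prep_pack{{TensorType::ACL_SRC_1, &rhs}, {TensorType::ACL_SRC_2, &c}};

        // Auxiliary memory is owned by the caller, not the operator.
        MemoryGroup mg;
        auto workspace = manage_workspace<Tensor>(gemm.workspace(), mg, run_pack, prep_pack);

        gemm.prepare(prep_pack); // one-time packing/reshaping of constant inputs
        gemm.run(run_pack);      // repeatable; buffers are injected per call
    }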