about | summary | refs | log | tree | commit | diff
path: root/tests
diff options
context:
space:
mode:
author: Viet-Hoa Do <viet-hoa.do@arm.com> 2023-04-11 17:16:27 +0100
committer: Viet-Hoa Do <viet-hoa.do@arm.com> 2023-04-19 08:40:45 +0000
commit: 9c7c2d2d23693877867bb3284c577b33cfbff471 (patch)
tree: f470a88b23498c1b5d13c5f9578caaf9d0599b74 /tests
parent: 9d0c4deb760efc2ca07e5e0b8218995201ad8a1f (diff)
download: ComputeLibrary-9c7c2d2d23693877867bb3284c577b33cfbff471.tar.gz
Add quantized support for CPU MatMul
Resolves: COMPMID-5899
Signed-off-by: Viet-Hoa Do <viet-hoa.do@arm.com>
Change-Id: I89d96e292c3492ba9b1900a3e5683f9dcd11dfc6
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/9440
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Reviewed-by: Gunes Bayir <gunes.bayir@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'tests')
-rw-r--r-- tests/SimpleTensor.h 18
-rw-r--r-- tests/datasets/SmallMatMulDataset.h 11
-rw-r--r-- tests/validation/NEON/MatMul.cpp 149
-rw-r--r-- tests/validation/fixtures/MatMulFixture.h 102
4 files changed, 250 insertions, 30 deletions
diff --git a/tests/SimpleTensor.h b/tests/SimpleTensor.h
index c1bd7f87b5..9ea171d492 100644
--- a/tests/SimpleTensor.h
+++ b/tests/SimpleTensor.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2020 Arm Limited.
+ * Copyright (c) 2017-2020, 2023 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -173,6 +173,15 @@ public:
*/
QuantizationInfo quantization_info() const override;
+ /** Set the quantization information of the tensor.
+ *
+ * This function does not have any effect on the raw quantized data of the tensor.
+ * It simply changes the quantization information, hence changes the dequantized values.
+ *
+ * @return A reference to the current object.
+ */
+ SimpleTensor<T> &quantization_info(const QuantizationInfo &qinfo);
+
/** Constant pointer to the underlying buffer.
*
* @return a constant pointer to the data.
@@ -335,6 +344,13 @@ QuantizationInfo SimpleTensor<T>::quantization_info() const
}
template <typename T>
+SimpleTensor<T> &SimpleTensor<T>::quantization_info(const QuantizationInfo &qinfo)
+{
+ _quantization_info = qinfo;
+ return *this;
+}
+
+template <typename T>
size_t SimpleTensor<T>::size() const
{
const size_t size = std::accumulate(_shape.cbegin(), _shape.cend(), 1, std::multiplies<size_t>());
diff --git a/tests/datasets/SmallMatMulDataset.h b/tests/datasets/SmallMatMulDataset.h
index 52ef01da7b..bb4cdad54b 100644
--- a/tests/datasets/SmallMatMulDataset.h
+++ b/tests/datasets/SmallMatMulDataset.h
@@ -47,6 +47,17 @@ public:
}
};
+class SmallerMatMulDataset final : public MatMulDataset
+{
+public:
+ SmallerMatMulDataset()
+ {
+ add_config(TensorShape(9U, 6U), TensorShape(5U, 9U), TensorShape(5U, 6U));
+ add_config(TensorShape(8U, 4U, 2U), TensorShape(16U, 8U, 2U), TensorShape(16U, 4U, 2U));
+ add_config(TensorShape(32U, 2U), TensorShape(17U, 32U), TensorShape(17U, 2U));
+ }
+};
+
class TinyMatMulDataset final : public MatMulDataset
{
public:
diff --git a/tests/validation/NEON/MatMul.cpp b/tests/validation/NEON/MatMul.cpp
index 3bfbc16e71..1a23697092 100644
--- a/tests/validation/NEON/MatMul.cpp
+++ b/tests/validation/NEON/MatMul.cpp
@@ -43,8 +43,10 @@ namespace validation
TEST_SUITE(NEON)
TEST_SUITE(MatMul)
-constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
-const AbsoluteTolerance<half> tolerance_fp16(half(0.1f));
+constexpr AbsoluteTolerance<float> tolerance_fp32(0.001f); /**< Tolerance value for comparing reference's output against implementation's output for FP32 data types */
+const AbsoluteTolerance<half> tolerance_fp16(half(0.1f));
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8(0);
+constexpr AbsoluteTolerance<uint8_t> tolerance_qasymm8_signed(0);
// clang-format off
// *INDENT-OFF*
@@ -57,6 +59,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(9U, 6U), 1, DataType::F32),
TensorInfo(TensorShape(9U, 6U , 12U) , 1 , DataType::F32),
TensorInfo(TensorShape(9U, 6U , 12U) , 1 , DataType::F32), // Tensors are not dynamic
+ TensorInfo(TensorShape(9U, 6U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(9U, 6U), 1, DataType::QASYMM8_SIGNED),
+ TensorInfo(TensorShape(9U, 6U), 1, DataType::QASYMM8_SIGNED), // Mismatching data type
}),
framework::dataset::make("InputBInfo",{ TensorInfo(TensorShape(5U, 9U), 1, DataType::QASYMM8),
TensorInfo(TensorShape(5U, 9U), 1, DataType::S32),
@@ -65,6 +70,9 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(5U, 9U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 9U, 12U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 9U, 12U), 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 9U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(5U, 9U), 1, DataType::QASYMM8_SIGNED),
+ TensorInfo(TensorShape(5U, 9U), 1, DataType::QASYMM8_SIGNED),
})),
framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(5U, 6U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 6U), 1, DataType::S32),
@@ -73,9 +81,12 @@ DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(zip(
TensorInfo(TensorShape(5U, 6U), 1, DataType::F32),
TensorInfo(TensorShape(5U, 6U, 12U) , 1, DataType::F32),
TensorInfo(TensorShape(5U, 6U, 12U) , 1, DataType::F32),
+ TensorInfo(TensorShape(5U, 6U), 1, DataType::QASYMM8),
+ TensorInfo(TensorShape(5U, 6U), 1, DataType::QASYMM8_SIGNED),
+ TensorInfo(TensorShape(5U, 6U), 1, DataType::QASYMM8),
})),
- framework::dataset::make( "TensorIsConst", {false, false, false, false, false , false, true} )),
- framework::dataset::make("Expected", { false, false, false, false, true, true, false })),
+ framework::dataset::make( "TensorIsConst", {false, false, false, false, false , false, true, false, false, false} )),
+ framework::dataset::make("Expected", { false, false, false, false, true, true, false, true, true, false })),
a_info, b_info, output_info, are_tensors_const, expected)
{
TensorInfo a{a_info};
@@ -103,6 +114,9 @@ using NEMatMulFastMathFixture = MatMulGenericValidationFixture<Tensor, Accessor,
template <typename T>
using NEMatMulDynamicTensorsFixture = MatMulValidationWithDynamicTensorsFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
+template <typename T>
+using NEQuantizedMatMulFixture = QuantizedMatMulValidationFixture<Tensor, Accessor, NEMatMul, CpuMatMulSettings, T>;
+
TEST_SUITE(Float)
TEST_SUITE(FP32)
FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(datasets::SmallMatMulDataset(),
@@ -149,13 +163,18 @@ TEST_SUITE_END() // FP32
/* Note : MatMul BF16 is enabled by specifying FP32 datatype and enabling the fast math setting */
constexpr AbsoluteTolerance<float> tolerance_bf16(0.001f);
TEST_SUITE(BF16)
-FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFastMathFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(datasets::SmallMatMulDataset(),
- framework::dataset::make("TransposeA", { false, true })),
- framework::dataset::make("TransposeB", { false, true })),
- framework::dataset::make("DataType", DataType::F32)),
- framework::dataset::make("ActivationInfo", { ActivationLayerInfo() })),
- framework::dataset::make("RunTimes", { 0 })),
- framework::dataset::make("Settings", { CpuMatMulSettings().fast_math(true) })))
+FIXTURE_DATA_TEST_CASE(RunSmall, NEMatMulFastMathFixture<float>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::F32)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo() })),
+ framework::dataset::make("RunTimes", { 0 })),
+ framework::dataset::make("Settings", { CpuMatMulSettings().fast_math(true) })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo() })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo() })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo() }))
+)
{
// Validate output
validate(Accessor(_target), _reference, tolerance_bf16);
@@ -198,6 +217,114 @@ TEST_SUITE_END() // FP16
TEST_SUITE_END() // Float
+TEST_SUITE(Quantized)
+
+TEST_SUITE(QASYMM8)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 50, 1) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 30, -1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 2) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallerMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 50, 1) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 30, -1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 2) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizedMatMulFixture<uint8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::LargeMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 100, 1) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 200, -1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 2) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8);
+}
+
+TEST_SUITE_END() // QASYMM8
+
+TEST_SUITE(QASYMM8_SIGNED)
+
+FIXTURE_DATA_TEST_CASE(RunSmall, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::PRECOMMIT, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 40, -2) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 50, 1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 1) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+
+FIXTURE_DATA_TEST_CASE(RunSmallExtraActivation, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::SmallerMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::BOUNDED_RELU), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LU_BOUNDED_RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 40, -2) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 50, 1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 1) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+
+FIXTURE_DATA_TEST_CASE(RunLarge, NEQuantizedMatMulFixture<int8_t>, framework::DatasetMode::NIGHTLY, combine(combine(combine(combine(combine(combine(combine(combine(
+ datasets::LargeMatMulDataset(),
+ framework::dataset::make("TransposeA", { false, true })),
+ framework::dataset::make("TransposeB", { false, true })),
+ framework::dataset::make("DataType", DataType::QASYMM8_SIGNED)),
+ framework::dataset::make("ActivationInfo", { ActivationLayerInfo(), ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::RELU) })),
+ framework::dataset::make("NumberOfExtraRuns", { 0, 1 })),
+ framework::dataset::make("LhsQInfo", { QuantizationInfo(1.f / 150, -2) })),
+ framework::dataset::make("RhsQInfo", { QuantizationInfo(1.f / 250, 1) })),
+ framework::dataset::make("OutQInfo", { QuantizationInfo(1.f, 1) }))
+)
+{
+ // Validate output
+ validate(Accessor(_target), _reference, tolerance_qasymm8_signed);
+}
+
+TEST_SUITE_END() // QASYMM8_SIGNED
+
+TEST_SUITE_END() // Quantized
+
TEST_SUITE_END() // MatMul
TEST_SUITE_END() // NEON
} // namespace validation
diff --git a/tests/validation/fixtures/MatMulFixture.h b/tests/validation/fixtures/MatMulFixture.h
index bb4a1cd7be..f8f038af3f 100644
--- a/tests/validation/fixtures/MatMulFixture.h
+++ b/tests/validation/fixtures/MatMulFixture.h
@@ -25,12 +25,17 @@
#define TESTS_VALIDATION_FIXTURES_MATMULFIXTURE
#include "arm_compute/core/Types.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "tests/framework/Fixture.h"
#include "tests/validation/reference/ActivationLayer.h"
#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/ReshapeLayer.h"
+#include <limits>
#include <random>
+#include <type_traits>
namespace arm_compute
{
@@ -44,7 +49,7 @@ class MatMulGenericValidationFixture : public framework::Fixture
public:
template <typename...>
void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs,
- Settings settings)
+ Settings settings, QuantizationInfo a_qinfo = QuantizationInfo(), QuantizationInfo b_qinfo = QuantizationInfo(), QuantizationInfo o_qinfo = QuantizationInfo())
{
// For brevity, the input shapes are assumed to be not-transposed for both a and b matrices.
if(transpose_a)
@@ -56,8 +61,8 @@ public:
permute(shape_b, PermutationVector(1U, 0U));
}
- _target = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings);
- _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info);
+ _target = compute_target(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings, a_qinfo, b_qinfo, o_qinfo);
+ _reference = compute_reference(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, a_qinfo, b_qinfo, o_qinfo);
}
protected:
@@ -78,23 +83,29 @@ protected:
library->fill(tensor, distribution, i);
break;
}
- default:
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
{
library->fill_tensor_uniform(tensor, i);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported data type.");
}
}
}
TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
- ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings)
+ ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
{
// 1. Create Classes and configure function
// ----------------------------------------------------
// Create tensors
// Configure relevant classes and matmul function
- TensorType a = create_tensor<TensorType>(shape_a, data_type, 1);
- TensorType b = create_tensor<TensorType>(shape_b, data_type, 1);
- TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1);
+ TensorType a = create_tensor<TensorType>(shape_a, data_type, 1, a_qinfo);
+ TensorType b = create_tensor<TensorType>(shape_b, data_type, 1, b_qinfo);
+ TensorType dst = create_tensor<TensorType>(output_shape, data_type, 1, o_qinfo);
FunctionType matmul;
@@ -149,18 +160,61 @@ protected:
return dst;
}
+ template <typename TT>
+ typename std::enable_if<!std::is_integral<TT>::value, SimpleTensor<TT>>::type
+ compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const ActivationLayerInfo &act_info, const QuantizationInfo &o_qinfo)
+ {
+ ARM_COMPUTE_UNUSED(act_info, o_qinfo);
+
+ return reference::gemm(a, b, c, alpha, beta);
+ }
+
+ template <typename TT>
+ typename std::enable_if<std::is_integral<TT>::value, SimpleTensor<TT>>::type
+ compute_reference_gemm(const SimpleTensor<TT> &a, const SimpleTensor<TT> &b, const SimpleTensor<TT> &c, float alpha, float beta, const ActivationLayerInfo &act_info, const QuantizationInfo &o_qinfo)
+ {
+ ARM_COMPUTE_UNUSED(alpha, beta);
+
+ const auto aq = a.quantization_info().uniform();
+ const auto bq = b.quantization_info().uniform();
+ const auto oq = o_qinfo.uniform();
+
+ const auto multiplier = aq.scale * bq.scale / oq.scale;
+
+ int32_t output_multiplier = 0;
+ int32_t output_shift = 0;
+ quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift);
+ std::vector<int32_t> output_multipliers{ output_multiplier };
+ std::vector<int32_t> output_shifts{ output_shift };
+
+ PixelValue output_min{};
+ PixelValue output_max{};
+ std::tie(output_min, output_max) = quantization::get_quantized_asymmetric_output_min_max(
+ o_qinfo, act_info, a.data_type());
+
+ const auto tmp = reference::gemmlowp_matrix_multiply_core<int32_t>(
+ a, b, c.shape(), aq.offset, bq.offset);
+
+ auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, TT>(
+ tmp, output_multipliers, output_shifts, oq.offset,
+ output_min.get<int32_t>(), output_max.get<int32_t>());
+ output.quantization_info(o_qinfo);
+
+ return output;
+ }
+
SimpleTensor<T> compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
- ActivationLayerInfo act_info)
+ ActivationLayerInfo act_info, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
{
- // We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D
- // This is necessary unless we choose to extend gemm reference for 5D+ tensors
- TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimW);
- TensorShape a_shape_collapsed = a_shape.collapsed_from(Window::DimW);
- TensorShape b_shape_collapsed = b_shape.collapsed_from(Window::DimW);
+ // We collapse dimensions > 2 onto dimension 2, i.e. 4D+ tensors will look like 3D
+ // This is necessary unless we choose to extend gemm reference for 4D+ tensors
+ TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimZ);
+ TensorShape a_shape_collapsed = a_shape.collapsed_from(Window::DimZ);
+ TensorShape b_shape_collapsed = b_shape.collapsed_from(Window::DimZ);
// Create reference
- SimpleTensor<T> a{ a_shape_collapsed, data_type, 1 };
- SimpleTensor<T> b{ b_shape_collapsed, data_type, 1 };
+ SimpleTensor<T> a{ a_shape_collapsed, data_type, 1, a_qinfo };
+ SimpleTensor<T> b{ b_shape_collapsed, data_type, 1, b_qinfo };
SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };
// Fill reference
@@ -199,8 +253,9 @@ protected:
// Setting beta to 0 will effectively disable C for the
// computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- SimpleTensor<T> result = reference::gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f);
- result = reference::activation_layer<T>(result, act_info, QuantizationInfo());
+ auto result = compute_reference_gemm<T>((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f, act_info, o_qinfo);
+
+ result = reference::activation_layer<T>(result, act_info, o_qinfo);
// We reshape the gemm output back if the tensor is high dimensional
if(output_shape_collapsed != output_shape)
@@ -249,6 +304,17 @@ public:
}
};
+template <typename TensorType, typename AccessorType, typename FunctionType, typename Settings, typename T>
+class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>
+{
+public:
+ template <typename...>
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+ {
+ MatMulGenericValidationFixture<TensorType, AccessorType, FunctionType, Settings, T>::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), a_qinfo, b_qinfo, o_qinfo);
+ }
+};
+
} // namespace validation
} // namespace test
} // namespace arm_compute