path: root/tests/validation/fixtures/MatMulKernelFixture.h
Diffstat (limited to 'tests/validation/fixtures/MatMulKernelFixture.h')
-rw-r--r--  tests/validation/fixtures/MatMulKernelFixture.h  130
1 file changed, 101 insertions, 29 deletions
diff --git a/tests/validation/fixtures/MatMulKernelFixture.h b/tests/validation/fixtures/MatMulKernelFixture.h
index 10e2a0659a..7d0b1a40a9 100644
--- a/tests/validation/fixtures/MatMulKernelFixture.h
+++ b/tests/validation/fixtures/MatMulKernelFixture.h
@@ -25,11 +25,15 @@
#define ACL_TESTS_VALIDATION_FIXTURES_MATMULKERNELFIXTURE
#include "arm_compute/core/KernelDescriptors.h"
-#include "src/gpu/cl/kernels/ClMatMulNativeKernel.h"
+#include "arm_compute/core/Utils.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+
#include "tests/CL/CLAccessor.h"
#include "tests/CL/Helper.h"
#include "tests/framework/Fixture.h"
+#include "tests/validation/Helpers.h"
#include "tests/validation/reference/GEMM.h"
+#include "tests/validation/reference/GEMMLowp.h"
#include "tests/validation/reference/Permute.h"
#include "tests/validation/reference/ReshapeLayer.h"
@@ -43,14 +47,43 @@ namespace validation
{
using namespace arm_compute::opencl::kernels;
-template <typename T>
+template <typename T, typename KernelType>
class MatMulKernelValidationFixture : public framework::Fixture
{
public:
template <typename...>
- void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, const int M0, const int N0, const int K0, bool export_rhs_to_cl_image, DataType data_type)
+ void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool pretranspose_a, bool pretranspose_b, int M0, int N0, int K0, bool export_rhs_to_cl_image, DataType data_type)
{
// For brevity, the input shapes are assumed to be not-transposed for both Lhs and Rhs matrices.
+ QuantizationInfo lhs_q_info;
+ QuantizationInfo rhs_q_info;
+ QuantizationInfo dst_q_info;
+
+ if(is_data_type_quantized(data_type))
+ {
+ const int32_t t_max = static_cast<int32_t>(std::numeric_limits<T>::max());
+ const int32_t t_min = static_cast<int32_t>(std::numeric_limits<T>::min());
+
+ std::mt19937 generator(library->seed());
+ std::uniform_real_distribution<float> distribution_float(-5.0f, 3.0f);
+ std::uniform_int_distribution<int32_t> distribution_t(t_min, t_max);
+
+ const float scale_lhs = pow(2, distribution_float(generator)); // [2^-5, 2^3]
+ const float scale_rhs = pow(2, distribution_float(generator)); // [2^-5, 2^3]
+
+ const int32_t offset_lhs = distribution_t(generator);
+ const int32_t offset_rhs = distribution_t(generator);
+
+ lhs_q_info = QuantizationInfo(scale_lhs, offset_lhs);
+ rhs_q_info = QuantizationInfo(scale_rhs, offset_rhs);
+
+ const int m = shape_a.y();
+ const int n = shape_b.x();
+ const int k = shape_a.x();
+
+ dst_q_info = calculate_mat_mul_dst_q_info(lhs_q_info, rhs_q_info, m, n, k, data_type);
+ }
+
if(pretranspose_a)
{
permute(shape_a, PermutationVector(1U, 0U));
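
Illustrative sketch (not part of the diff): a minimal, self-contained reproduction of the scale/offset sampling added in the hunk above, assuming the QASYMM8_SIGNED (int8_t) case and a fixed seed. The fixture itself seeds from library->seed() and derives dst_q_info through calculate_mat_mul_dst_q_info, whose internals are not shown here.

#include <cmath>
#include <cstdint>
#include <iostream>
#include <limits>
#include <random>

int main()
{
    std::mt19937 generator(42); // the fixture uses library->seed()

    // Scales are powers of two in [2^-5, 2^3]; offsets may land anywhere in the type's value range.
    std::uniform_real_distribution<float>  exponent(-5.0f, 3.0f);
    std::uniform_int_distribution<int32_t> offset(std::numeric_limits<int8_t>::min(),
                                                  std::numeric_limits<int8_t>::max());

    const float   scale_lhs  = std::pow(2.0f, exponent(generator));
    const float   scale_rhs  = std::pow(2.0f, exponent(generator));
    const int32_t offset_lhs = offset(generator);
    const int32_t offset_rhs = offset(generator);

    std::cout << "lhs: scale=" << scale_lhs << ", offset=" << offset_lhs << "\n"
              << "rhs: scale=" << scale_rhs << ", offset=" << offset_rhs << "\n";
    return 0;
}
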
@@ -65,8 +98,8 @@ public:
if(!export_rhs_to_cl_image || _device_supports_export_to_cl_image)
{
- _target = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type);
- _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type);
+ _target = compute_target(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, M0, N0, K0, export_rhs_to_cl_image, data_type, lhs_q_info, rhs_q_info, dst_q_info);
+ _reference = compute_reference(shape_a, shape_b, output_shape, pretranspose_a, pretranspose_b, data_type, lhs_q_info, rhs_q_info, dst_q_info);
}
}
@@ -93,23 +126,29 @@ protected:
}
}
+ template <typename U, typename D>
+ void fill_constant(U &&tensor, D value)
+ {
+ library->fill_tensor_value(tensor, value);
+ }
+
CLTensor compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, const int M0, const int N0, const int K0,
- bool export_rhs_to_cl_image, DataType data_type)
+ bool export_rhs_to_cl_image, DataType data_type, const QuantizationInfo &lhs_q_info, const QuantizationInfo &rhs_q_info, const QuantizationInfo &dst_q_info)
{
- // Create tensors
- CLTensor a = create_tensor<CLTensor>(shape_a, data_type, 1);
- CLTensor b = create_tensor<CLTensor>(shape_b, data_type, 1);
- CLTensor dst = create_tensor<CLTensor>(output_shape, data_type, 1);
-
- CLSynthetizeOperator<ClMatMulNativeKernel> matMul{};
- MatMulKernelInfo matmul_info;
- matmul_info.adj_lhs = pretranspose_a;
- matmul_info.adj_rhs = pretranspose_b;
- matmul_info.m0 = M0;
- matmul_info.n0 = N0;
- matmul_info.k0 = K0;
+ CLSynthetizeOperator<KernelType> matMul{};
+ MatMulKernelInfo matmul_info;
+ matmul_info.adj_lhs = pretranspose_a;
+ matmul_info.adj_rhs = pretranspose_b;
+ matmul_info.m0 = M0;
+ matmul_info.n0 = N0;
+ matmul_info.k0 = K0;
matmul_info.export_rhs_to_cl_image = export_rhs_to_cl_image;
+ // Create tensors
+ CLTensor a = create_tensor<CLTensor>(shape_a, data_type, 1, lhs_q_info);
+ CLTensor b = create_tensor<CLTensor>(shape_b, data_type, 1, rhs_q_info);
+ CLTensor dst = create_tensor<CLTensor>(output_shape, data_type, 1, dst_q_info);
+
matMul.configure(a.info(), b.info(), dst.info(), matmul_info);
ARM_COMPUTE_ASSERT(a.info()->is_resizable());
ARM_COMPUTE_ASSERT(b.info()->is_resizable());
@@ -138,18 +177,19 @@ protected:
return dst;
}
- SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type)
+ SimpleTensor<T> compute_reference(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool pretranspose_a, bool pretranspose_b, DataType data_type,
+ const QuantizationInfo &lhs_q_info, const QuantizationInfo &rhs_q_info, const QuantizationInfo &dst_q_info)
{
// We collapse dimensions > 3 onto dimension 3, i.e. 5D+ tensors will look like 4D
// This is necessary unless we choose to extend gemm reference for 5D+ tensors
- TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimW);
- TensorShape shape_a_collapsed = shape_a.collapsed_from(Window::DimW);
- TensorShape shape_b_collapsed = shape_b.collapsed_from(Window::DimW);
+ TensorShape output_shape_collapsed = output_shape.collapsed_from(Window::DimZ);
+ TensorShape shape_a_collapsed = shape_a.collapsed_from(Window::DimZ);
+ TensorShape shape_b_collapsed = shape_b.collapsed_from(Window::DimZ);
// Create reference
- SimpleTensor<T> a{ shape_a_collapsed, data_type, 1 };
- SimpleTensor<T> b{ shape_b_collapsed, data_type, 1 };
- SimpleTensor<T> c{ output_shape_collapsed, data_type, 1 };
+ SimpleTensor<T> a{ shape_a_collapsed, data_type, 1, lhs_q_info };
+ SimpleTensor<T> b{ shape_b_collapsed, data_type, 1, rhs_q_info };
+ SimpleTensor<T> c{ output_shape_collapsed, data_type, 1, dst_q_info };
// Fill reference
fill(a, 0);
@@ -185,10 +225,8 @@ protected:
b_transposed = reference::permute<T>(b, PermutationVector(1U, 0U));
}
- // Setting beta to 0 will effectively disable C for the
- // computation of the reference: alpha * A * B + 0 * C
// Use transposed tensors if boolean enabled else use original tensors
- SimpleTensor<T> result = reference::gemm<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c, 1.0f, 0.f);
+ SimpleTensor<T> result = gemm_reference<T>((pretranspose_a) ? a_transposed : a, (pretranspose_b) ? b_transposed : b, c);
// We reshape the gemm output back if the tensor is high dimensional
if(output_shape_collapsed != output_shape)
@@ -199,9 +237,43 @@ protected:
return result;
}
+ template <typename U = T>
+ typename std::enable_if < std::is_same<U, float>::value || std::is_same<U, half>::value, SimpleTensor<U >>::type gemm_reference(SimpleTensor<U> &a, SimpleTensor<U> &b, SimpleTensor<U> &c)
+ {
+ // Setting beta to 0 will effectively disable C for the
+ // computation of the reference: alpha * A * B + 0 * C
+ return reference::gemm<U>(a, b, c, 1.0f, 0.f);
+ }
+
+ template <typename U = T>
+ typename std::enable_if < std::is_same<U, int8_t>::value || std::is_same<U, uint8_t>::value, SimpleTensor<U >>::type gemm_reference(SimpleTensor<U> &a, SimpleTensor<U> &b, SimpleTensor<U> &c)
+ {
+ const UniformQuantizationInfo aq = a.quantization_info().uniform();
+ const UniformQuantizationInfo bq = b.quantization_info().uniform();
+ const UniformQuantizationInfo cq = c.quantization_info().uniform();
+
+ const SimpleTensor<int32_t> result = reference::gemmlowp_matrix_multiply_core<int32_t, U, U>(a, b, c.shape(), -aq.offset, -bq.offset);
+
+ std::vector<int32_t> gemmlowp_multipliers{ 1 };
+ std::vector<int32_t> gemmlowp_shifts{ 1 };
+ const int gemmlowp_offset = cq.offset;
+ const float scale = aq.scale * bq.scale / cq.scale;
+
+ quantization::calculate_quantized_multiplier(scale, &gemmlowp_multipliers[0], &gemmlowp_shifts[0]);
+ constexpr int32_t gemmlowp_min_bound = std::numeric_limits<int32_t>::min();
+ constexpr int32_t gemmlowp_max_bound = std::numeric_limits<int32_t>::max();
+
+ SimpleTensor<int> bias{ c.shape(), DataType::S32 };
+ fill_constant(bias, static_cast<int32_t>(0));
+
+ const SimpleTensor<U> final_result = reference::gemmlowp_quantize_down_scale_by_fixedpoint<int32_t, U>(result, bias,
+ gemmlowp_multipliers, gemmlowp_shifts, gemmlowp_offset, gemmlowp_min_bound, gemmlowp_max_bound);
+ return final_result;
+ }
+
CLTensor _target{};
SimpleTensor<T> _reference{};
- bool _device_supports_export_to_cl_image { true };
+ bool _device_supports_export_to_cl_image{ true };
};
} // namespace validation
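
Illustrative sketch (not part of the diff): a floating-point approximation of the requantization that gemmlowp_quantize_down_scale_by_fixedpoint performs in the quantized gemm_reference overload, assuming the int32 accumulator already has the lhs/rhs offsets folded in (as reference::gemmlowp_matrix_multiply_core does via -aq.offset and -bq.offset). The real path turns the effective scale into an integer multiplier/shift pair with quantization::calculate_quantized_multiplier; here the scale is applied directly in float for clarity.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int8_t requantize(int32_t acc, float lhs_scale, float rhs_scale, float dst_scale, int32_t dst_offset)
{
    const float   effective_scale = lhs_scale * rhs_scale / dst_scale;   // 'scale' in the fixture
    const int32_t rescaled        = static_cast<int32_t>(std::lround(acc * effective_scale));
    const int32_t shifted         = rescaled + dst_offset;               // gemmlowp_offset = cq.offset
    return static_cast<int8_t>(std::clamp<int32_t>(shifted, -128, 127)); // saturate to the int8_t output range
}

int main()
{
    // Example: acc = 1000, lhs/rhs scales 0.25 and 0.5, dst scale 2.0, dst offset 10
    // effective scale = 0.0625 -> 1000 * 0.0625 = 62.5 -> 63, + 10 = 73
    std::cout << static_cast<int>(requantize(1000, 0.25f, 0.5f, 2.0f, 10)) << "\n";
    return 0;
}

The multiplier/shift form used by the actual reference avoids floating-point arithmetic; the sketch above only mirrors the arithmetic it approximates.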