From 36a75dafdbe6d6a3a6f50bd075fe01f5b7dace38 Mon Sep 17 00:00:00 2001
From: Renato Arantes
Date: Fri, 26 Jan 2024 17:31:18 +0000
Subject: [ONCPUML-1451] Add matmul kernel to enable bf16 to bf16 operations
 via PyTorch® autocast() function
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The full range of tests will be added under the [MLINFSW-482] epic, since
the required reordering kernels are not yet implemented in ACL.

Co-Authored-By: David Mansell
Change-Id: I820d316295a1ec94fdc89c37e4144a268f914c36
Signed-off-by: Renato Arantes
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11169
Tested-by: Arm Jenkins
Reviewed-by: Gunes Bayir
Comments-Addressed: Arm Jenkins
Benchmark: Arm Jenkins
---
 tests/validation/fixtures/MatMulFixture.h | 383 +++++++++++++++++++++++++-----
 1 file changed, 326 insertions(+), 57 deletions(-)

(limited to 'tests/validation/fixtures/MatMulFixture.h')

diff --git a/tests/validation/fixtures/MatMulFixture.h b/tests/validation/fixtures/MatMulFixture.h
index 2e79612a37..ffd12e56d0 100644
--- a/tests/validation/fixtures/MatMulFixture.h
+++ b/tests/validation/fixtures/MatMulFixture.h
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2023 Arm Limited.
+ * Copyright (c) 2023-2024 Arm Limited.
  *
  * SPDX-License-Identifier: MIT
  *
@@ -27,15 +27,17 @@
 #include "arm_compute/core/Types.h"
 #include "arm_compute/core/Utils.h"
 #include "arm_compute/core/utils/quantization/AsymmHelpers.h"
+
 #include "src/core/utils/quantization/AsymmHelpers.h"
 #include "tests/framework/Asserts.h" // Required for ARM_COMPUTE_ASSERT
 #include "tests/framework/Fixture.h"
-#include "tests/validation/Validation.h"
 #include "tests/validation/reference/ActivationLayer.h"
 #include "tests/validation/reference/GEMM.h"
 #include "tests/validation/reference/GEMMLowp.h"
 #include "tests/validation/reference/Permute.h"
 #include "tests/validation/reference/ReshapeLayer.h"
+#include "tests/validation/Validation.h"
+
 #include
 #include
 #include
@@ -50,32 +52,50 @@
 template
 void fill(U &&tensor, int i, float lo = -1.f, float hi = 1.f)
 {
-    switch(tensor.data_type())
+    switch (tensor.data_type())
     {
+        case DataType::BFLOAT16:
+        {
+            arm_compute::utils::uniform_real_distribution_16bit distribution{float(lo), float(hi)};
+            library->fill(tensor, distribution, i);
+            break;
+        }
         case DataType::F16:
         {
-            arm_compute::utils::uniform_real_distribution_16bit distribution{ float(lo), float(hi) };
+            arm_compute::utils::uniform_real_distribution_16bit distribution{float(lo), float(hi)};
             library->fill(tensor, distribution, i);
             break;
         }
@@ -98,8 +118,18 @@ protected:
         }
     }

-    TensorType compute_target(const TensorShape &shape_a, const TensorShape &shape_b, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type,
-                              ActivationLayerInfo act_info, int num_extra_runs, const Settings &settings, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+    virtual TensorType compute_target(const TensorShape &shape_a,
+                                      const TensorShape &shape_b,
+                                      const TensorShape &output_shape,
+                                      bool transpose_a,
+                                      bool transpose_b,
+                                      DataType data_type,
+                                      ActivationLayerInfo act_info,
+                                      int num_extra_runs,
+                                      const Settings &settings,
+                                      QuantizationInfo a_qinfo,
+                                      QuantizationInfo b_qinfo,
+                                      QuantizationInfo o_qinfo)
     {
         // 1. Create Classes and configure function
         // ----------------------------------------------------
@@ -137,7 +167,7 @@ protected:
         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());

         // For multiple runs.
- for(int i = 0; i < num_extra_runs; i++) + for (int i = 0; i < num_extra_runs; i++) { // Stress dynamic tensors by running multiple times. // -------------------------------------------------------- @@ -164,7 +194,12 @@ protected: template typename std::enable_if < !std::is_integral::value, SimpleTensor>::type - compute_reference_gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta, const QuantizationInfo &o_qinfo) + compute_reference_gemm(const SimpleTensor &a, + const SimpleTensor &b, + const SimpleTensor &c, + float alpha, + float beta, + const QuantizationInfo &o_qinfo) { ARM_COMPUTE_UNUSED(o_qinfo); @@ -173,7 +208,12 @@ protected: template typename std::enable_if::value, SimpleTensor>::type - compute_reference_gemm(const SimpleTensor &a, const SimpleTensor &b, const SimpleTensor &c, float alpha, float beta, const QuantizationInfo &o_qinfo) + compute_reference_gemm(const SimpleTensor &a, + const SimpleTensor &b, + const SimpleTensor &c, + float alpha, + float beta, + const QuantizationInfo &o_qinfo) { ARM_COMPUTE_UNUSED(alpha, beta); @@ -186,23 +226,30 @@ protected: int32_t output_multiplier = 0; int32_t output_shift = 0; quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift); - std::vector output_multipliers{ output_multiplier }; - std::vector output_shifts{ output_shift }; + std::vector output_multipliers{output_multiplier}; + std::vector output_shifts{output_shift}; //The lhs and rhs offsets are negated here to keep the reference aligned with the function implementation where the lhs and rhs offsets are also negated. - const auto tmp = reference::gemmlowp_matrix_multiply_core( - a, b, c.shape(), -aq.offset, -bq.offset); + const auto tmp = reference::gemmlowp_matrix_multiply_core(a, b, c.shape(), -aq.offset, -bq.offset); auto output = reference::gemmlowp_quantize_down_scale_by_fixedpoint( - tmp, output_multipliers, output_shifts, oq.offset, - std::numeric_limits::lowest(), std::numeric_limits::max()); + tmp, output_multipliers, output_shifts, oq.offset, std::numeric_limits::lowest(), + std::numeric_limits::max()); output.quantization_info(o_qinfo); return output; } - SimpleTensor compute_reference(const TensorShape &a_shape, const TensorShape &b_shape, const TensorShape &output_shape, bool transpose_a, bool transpose_b, DataType data_type, - ActivationLayerInfo act_info, QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) + SimpleTensor compute_reference(const TensorShape &a_shape, + const TensorShape &b_shape, + const TensorShape &output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info, + QuantizationInfo a_qinfo, + QuantizationInfo b_qinfo, + QuantizationInfo o_qinfo) { // We collapse dimensions > 2 onto dimension 2, i.e. 
4D+ tensors will look like 3D // This is necessary unless we choose to extend gemm reference for 4D+ tensors @@ -211,9 +258,9 @@ protected: TensorShape b_shape_collapsed = b_shape.collapsed_from(Window::DimZ); // Create reference - SimpleTensor a{ a_shape_collapsed, data_type, 1, a_qinfo }; - SimpleTensor b{ b_shape_collapsed, data_type, 1, b_qinfo }; - SimpleTensor c{ output_shape_collapsed, data_type, 1 }; + SimpleTensor a{a_shape_collapsed, data_type, 1, a_qinfo}; + SimpleTensor b{b_shape_collapsed, data_type, 1, b_qinfo}; + SimpleTensor c{output_shape_collapsed, data_type, 1}; // Fill reference fill(a, 2); @@ -234,16 +281,16 @@ protected: b_transposed_shape.set(1, b.shape().x()); // Define transposed tensors - SimpleTensor a_transposed{ a_transposed_shape, data_type }; - SimpleTensor b_transposed{ b_transposed_shape, data_type }; + SimpleTensor a_transposed{a_transposed_shape, data_type}; + SimpleTensor b_transposed{b_transposed_shape, data_type}; // pretranspose a if necessary - if(transpose_a) + if (transpose_a) { a_transposed = reference::permute(a, PermutationVector(1U, 0U)); } // pretranspose b if necessary - if(transpose_b) + if (transpose_b) { b_transposed = reference::permute(b, PermutationVector(1U, 0U)); } @@ -251,12 +298,13 @@ protected: // Setting beta to 0 will effectively disable C for the // computation of the reference: alpha * A * B + 0 * C // Use transposed tensors if boolean enabled else use original tensors - auto result = compute_reference_gemm((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, 1.0f, 0.f, o_qinfo); + auto result = compute_reference_gemm((transpose_a) ? a_transposed : a, (transpose_b) ? b_transposed : b, c, + 1.0f, 0.f, o_qinfo); result = reference::activation_layer(result, act_info, o_qinfo); // We reshape the gemm output back if the tensor is high dimensional - if(output_shape_collapsed != output_shape) + if (output_shape_collapsed != output_shape) { result = reference::reshape_layer(result, output_shape); } @@ -268,72 +316,293 @@ protected: SimpleTensor _reference{}; }; +/// TODO: (ONCPUML-1451) The current state of this fixture is interim and a longer-term testing method will be implemented later. +/// @note: Currently we support only a 2x2 test due to the lack of reorder ref. implementation. +template +class MatMulFixedFormatFixture + : public MatMulGenericValidationFixture +{ +public: + TensorType compute_target(const TensorShape &shape_a, + const TensorShape &shape_b, + const TensorShape &output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info, + int num_extra_runs, + const Settings &settings, + QuantizationInfo a_qinfo, + QuantizationInfo b_qinfo, + QuantizationInfo o_qinfo) override + { + // 1. 
Create Classes and configure function + // ---------------------------------------------------- + // Create tensors + // Configure relevant classes and matmul function + TensorType a = create_tensor(shape_a, data_type, 1, a_qinfo); + TensorType b = create_tensor(shape_b, data_type, 1, b_qinfo); + TensorType dst = create_tensor(output_shape, data_type, 1, o_qinfo); + + const auto weight_tensor_info = TensorInfo(*b.info()); + const TensorInfo new_tensor_info = prepare_weights(weight_tensor_info); + TensorType weights_transformed = create_tensor(new_tensor_info); + + // Configure MatMulInfo class + MatMulInfo mm_info; + mm_info.adj_lhs(transpose_a).adj_rhs(transpose_b); + + // Ensure values are dynamic + a.info()->set_are_values_constant(false); + b.info()->set_are_values_constant(false); + weights_transformed.info()->set_are_values_constant(false); + + FunctionType matmul; + + // Configure operator + matmul.configure(&a, &weights_transformed, &dst, mm_info, settings, act_info); + + // Assertions + ARM_COMPUTE_ASSERT(a.info()->is_resizable()); + ARM_COMPUTE_ASSERT(b.info()->is_resizable()); + ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); + ARM_COMPUTE_ASSERT(weights_transformed.info()->is_resizable()); + + // Allocate tensors + a.allocator()->allocate(); + b.allocator()->allocate(); + dst.allocator()->allocate(); + weights_transformed.allocator()->allocate(); + + ARM_COMPUTE_ASSERT(!a.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!b.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); + ARM_COMPUTE_ASSERT(!weights_transformed.info()->is_resizable()); + + // For multiple runs. + for (int i = 0; i < num_extra_runs; i++) + { + // Stress dynamic tensors by running multiple times. + // -------------------------------------------------------- + // Fill tensors with new seed + // Run function + const int seed_offset = num_extra_runs * 100; + this->fill(AccessorType(a), seed_offset); + this->fill(AccessorType(b), seed_offset + 1); + + matmul.run(); + } + + // 2. 
Final Run for reference comparison + // -------------------------------------------------------- + // Re-fill tensors same seed as reference run + // Compute MatMul operation + this->fill(AccessorType(a), 2); + this->fill(AccessorType(b), 3); + + rearrange_data(AccessorType(b), AccessorType(weights_transformed)); + + matmul.run(); + + return dst; + } + + void setup(TensorShape shape_a, + TensorShape shape_b, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info, + int num_extra_runs, + Settings settings, + QuantizationInfo a_qinfo, + QuantizationInfo b_qinfo, + QuantizationInfo o_qinfo) + { + if (CPUInfo::get().has_bf16()) + { + MatMulGenericValidationFixture::setup( + shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, settings, + a_qinfo, b_qinfo, o_qinfo); + } + } + +private: + TensorInfo prepare_weights(const TensorInfo tensor_info) + { + const DataLayout data_layout = tensor_info.data_layout(); + ARM_COMPUTE_EXPECT(data_layout == DataLayout::NCHW, framework::LogLevel::ERRORS); + const DataType data_type = tensor_info.data_type(); + const TensorShape tensor_shape = tensor_info.tensor_shape(); + const int H = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const int W = tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + ARM_COMPUTE_EXPECT(H <= 2 && W <= 2, framework::LogLevel::ERRORS); + + arm_compute::Strides strides_in_bytes = tensor_info.strides_in_bytes(); + strides_in_bytes.set(1, 32); + strides_in_bytes.set(2, 32); + + const size_t offset_first_element_in_bytes = tensor_info.offset_first_element_in_bytes(); + const size_t total_size_in_bytes = 32; + + const TensorShape TS(H, W); + + TensorInfo new_tensor_info = tensor_info; + new_tensor_info.init(TS, tensor_info.num_channels(), data_type, strides_in_bytes, offset_first_element_in_bytes, + total_size_in_bytes); + + return new_tensor_info; + } + + void rearrange_data(const AccessorType src, AccessorType dst) + { + const TensorShape src_tensor_shape = src.shape(); + const DataLayout data_layout = src.data_layout(); + ARM_COMPUTE_EXPECT(data_layout == DataLayout::NCHW, framework::LogLevel::ERRORS); + const unsigned int O = + src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::BATCHES)]; // N=O + const unsigned int H = + src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::HEIGHT)]; + const unsigned int W = + src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::WIDTH)]; + const unsigned int I = + src_tensor_shape[get_data_layout_dimension_index(data_layout, DataLayoutDimension::CHANNEL)]; // C=I + ARM_COMPUTE_EXPECT(H <= 2 && W <= 2, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(I == 1 && O == 1, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(src.num_elements() <= dst.num_elements(), framework::LogLevel::ERRORS); + + const T *src_ptr = reinterpret_cast(src.data()); + T *dst_ptr = reinterpret_cast(dst.data()); + + // rearrange indexes for 2x2 input and weight + int dst_idx[] = {0, 4, 1, 5}; + for (int i = 0; i < 4; i++) + { + dst_ptr[dst_idx[i]] = src_ptr[i]; + } + } +}; + template -class MatMulValidationFixture : public MatMulGenericValidationFixture +class MatMulValidationFixture + : public MatMulGenericValidationFixture { public: - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, 
DataType data_type) + void setup(TensorShape shape_a, + TensorShape shape_b, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type) { - MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0, - Settings()); + MatMulGenericValidationFixture::setup( + shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, ActivationLayerInfo(), 0, Settings()); } }; template -class MatMulValidationWithDynamicTensorsFixture : public MatMulGenericValidationFixture +class MatMulValidationWithDynamicTensorsFixture + : public MatMulGenericValidationFixture { public: - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs) + void setup(TensorShape shape_a, + TensorShape shape_b, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info, + int num_extra_runs) { - MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings()); + MatMulGenericValidationFixture::setup( + shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings()); } }; template -class QuantizedMatMulValidationFixture : public MatMulGenericValidationFixture +class QuantizedMatMulValidationFixture + : public MatMulGenericValidationFixture { public: - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info, int num_extra_runs, - QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo) + void setup(TensorShape shape_a, + TensorShape shape_b, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info, + int num_extra_runs, + QuantizationInfo a_qinfo, + QuantizationInfo b_qinfo, + QuantizationInfo o_qinfo) { - MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), - a_qinfo, b_qinfo, o_qinfo); + MatMulGenericValidationFixture::setup( + shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(), + a_qinfo, b_qinfo, o_qinfo); } }; template -class MatMulValidationWithActivationFixture : public MatMulGenericValidationFixture +class MatMulValidationWithActivationFixture + : public MatMulGenericValidationFixture { public: - void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo act_info) + void setup(TensorShape shape_a, + TensorShape shape_b, + TensorShape output_shape, + bool transpose_a, + bool transpose_b, + DataType data_type, + ActivationLayerInfo act_info) { - MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings()); + MatMulGenericValidationFixture::setup( + shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings()); } }; template -class MatMulValidationWithActivationAlphaBetaFixture : public MatMulGenericValidationFixture +class MatMulValidationWithActivationAlphaBetaFixture + : public MatMulGenericValidationFixture { public: - void setup(TensorShape shape_a, TensorShape shape_b, 
TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
-               float alpha_beta)
+    void setup(TensorShape shape_a,
+               TensorShape shape_b,
+               TensorShape output_shape,
+               bool transpose_a,
+               bool transpose_b,
+               DataType data_type,
+               ActivationLayerInfo::ActivationFunction function,
+               float alpha_beta)
     {
         ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
-        MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
+        MatMulGenericValidationFixture::setup(
+            shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, 0, Settings());
     }
 };

 template
-class QuantizedMatMulValidationWithActivationFixture : public MatMulGenericValidationFixture
+class QuantizedMatMulValidationWithActivationFixture
+    : public MatMulGenericValidationFixture
 {
 public:
-    void setup(TensorShape shape_a, TensorShape shape_b, TensorShape output_shape, bool transpose_a, bool transpose_b, DataType data_type, ActivationLayerInfo::ActivationFunction function,
-               float alpha_beta, int num_extra_runs,
-               QuantizationInfo a_qinfo, QuantizationInfo b_qinfo, QuantizationInfo o_qinfo)
+    void setup(TensorShape shape_a,
+               TensorShape shape_b,
+               TensorShape output_shape,
+               bool transpose_a,
+               bool transpose_b,
+               DataType data_type,
+               ActivationLayerInfo::ActivationFunction function,
+               float alpha_beta,
+               int num_extra_runs,
+               QuantizationInfo a_qinfo,
+               QuantizationInfo b_qinfo,
+               QuantizationInfo o_qinfo)
     {
         ActivationLayerInfo act_info(function, alpha_beta, alpha_beta);
-        MatMulGenericValidationFixture::setup(shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
-                                              a_qinfo, b_qinfo, o_qinfo);
+        MatMulGenericValidationFixture::setup(
+            shape_a, shape_b, output_shape, transpose_a, transpose_b, data_type, act_info, num_extra_runs, Settings(),
+            a_qinfo, b_qinfo, o_qinfo);
     }
 };
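
Note on the new fixed-format path: the standalone sketch below is not part of the patch; it replays the 2x2 weight rearrangement that rearrange_data() performs, using plain float instead of bfloat16 so it compiles without ComputeLibrary. Reading the {0, 4, 1, 5} index map together with the 32-byte strides set in prepare_weights() suggests that each source column is stored contiguously, four elements apart, inside a zero-padded block; that interpretation is an inference from the code above, not something the patch states explicitly.

    // sketch.cpp: illustrative only, not ComputeLibrary code.
    #include <array>
    #include <cstdio>

    int main()
    {
        // 2x2 weight matrix in row-major order: {w00, w01, w10, w11}.
        const std::array<float, 4> src = {1.f, 2.f, 3.f, 4.f};

        // Mirrors the 32-byte block prepare_weights() reserves
        // (16 bf16 elements there; 8 floats here), zero-initialised.
        std::array<float, 8> dst = {};

        // Same index map as rearrange_data() in the fixture above:
        // each column of the 2x2 weight lands contiguously, with a
        // four-element stride between columns.
        const int dst_idx[] = {0, 4, 1, 5};
        for (int i = 0; i < 4; i++)
        {
            dst[dst_idx[i]] = src[i];
        }

        for (float v : dst)
        {
            std::printf("%g ", v); // prints: 1 3 0 0 2 4 0 0
        }
        std::printf("\n");
        return 0;
    }

Running it prints "1 3 0 0 2 4 0 0": column 0 of the source (1, 3) followed by padding, then column 1 (2, 4), which appears to be the blocked layout the fixture feeds to the kernel through weights_transformed.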