Diffstat (limited to 'src/cpu')
 src/cpu/kernels/CpuActivationKernel.cpp                 |  18
 src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp     |   6
 src/cpu/kernels/CpuKernelSelectionTypes.h               |   2
 src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h |  16
 src/cpu/kernels/assembly/arm_gemm.hpp                   |  12
 src/cpu/kernels/assembly/convolution_parameters.hpp     |  10
 src/cpu/kernels/assembly/gemm_common.hpp                |  18
 src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp   |   6
 src/cpu/operators/CpuConv2d.h                           |  13
 src/cpu/operators/CpuGemmConv2d.h                       |   3
 src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp     |  75
 src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h       |   2
 src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  | 270
 13 files changed, 255 insertions(+), 196 deletions(-)
diff --git a/src/cpu/kernels/CpuActivationKernel.cpp b/src/cpu/kernels/CpuActivationKernel.cpp
index 7cfa39b286..4253027231 100644
--- a/src/cpu/kernels/CpuActivationKernel.cpp
+++ b/src/cpu/kernels/CpuActivationKernel.cpp
@@ -43,6 +43,13 @@ namespace kernels
{
namespace
{
+
+bool is_fp16_lut_supported(ActivationLayerInfo::ActivationFunction func)
+{
+ return func == ActivationLayerInfo::ActivationFunction::LOGISTIC ||
+ func == ActivationLayerInfo::ActivationFunction::TANH;
+}
+
static const std::vector<CpuActivationKernel::ActivationKernel> available_kernels = {
#ifdef ARM_COMPUTE_ENABLE_SVE
{"sve2_q8_activation_lut",
@@ -85,10 +92,7 @@ static const std::vector<CpuActivationKernel::ActivationKernel> available_kernel
REGISTER_QSYMM16_SVE2(arm_compute::cpu::sve2_qsymm16_activation)},
{"sve_fp16_activation_lut",
[](const ActivationDataTypeISASelectorData &data)
- {
- return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve &&
- data.f == ActivationLayerInfo::ActivationFunction::LOGISTIC;
- },
+ { return data.dt == DataType::F16 && data.isa.fp16 && data.isa.sve && is_fp16_lut_supported(data.f); },
REGISTER_FP16_SVE(arm_compute::cpu::sve_fp16_activation_lut)},
{"sve_fp16_activation",
[](const ActivationDataTypeISASelectorData &data)
@@ -299,10 +303,10 @@ void CpuActivationKernel::configure(const ITensorInfo *src, ITensorInfo *dst, Ac
activation_info.setLookupTable256(tmp_lut);
}
- if (src->data_type() == DataType::F16 &&
- activation_info.activation() == ActivationLayerInfo::ActivationFunction::LOGISTIC)
+ if (std::string(uk->name) == "sve_fp16_activation_lut")
{
- const LUTInfo info = {activation_info.activation(), src->data_type(), src->quantization_info()};
+ const LUTInfo info = {activation_info.activation(), activation_info.a(), activation_info.b(), src->data_type(),
+ src->quantization_info().uniform()};
activation_info.setLookupTable65536((lut_manager.get_lut_table(info)));
}
#endif // __aarch64__
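Note on the hunks above: widening the FP16 LUT path from LOGISTIC-only to the functions accepted by is_fp16_lut_supported() means the cached 65536-entry table must be keyed on the activation parameters as well, since TANH is parameterised as a * tanh(b * x) while LOGISTIC ignores a and b. A minimal sketch of such a cache key (assumed struct and field names, not ACL's LUTInfo):

// Illustrative FP16 LUT cache key (assumed names, not ACL's LUTInfo).
// a and b distinguish TANH tables; the real key also carries the data type
// and uniform quantization info, as the new LUTInfo constructor call shows.
struct Fp16LutKey
{
    int   activation; // e.g. LOGISTIC or TANH
    float a;          // TANH computes a * tanh(b * x); LOGISTIC ignores a and b
    float b;

    bool operator==(const Fp16LutKey &other) const
    {
        return activation == other.activation && a == other.a && b == other.b;
    }
};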
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
index a3ed2cd171..87340e566e 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021 Arm Limited.
+ * Copyright (c) 2017-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -684,6 +684,10 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED &&
+ src1->data_type() == DataType::QASYMM8,
+ "QASYMM8_SIGNED input with QASYMM8 weights not supported");
+
TensorShape in0_shape = src0->tensor_shape();
TensorShape in1_shape = src1->tensor_shape();
TensorShape out_shape = dst->tensor_shape();
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 7c1e4772a6..03a474de53 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -105,7 +105,7 @@ struct SoftmaxKernelDataTypeISASelectorData
cpuinfo::CpuIsaInfo isa;
bool is_log;
int axis;
- unsigned long sme2_vector_length;
+ uint64_t sme2_vector_length;
};
// Selector pointer types
diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
index 6e8f32ef47..72fafca1bb 100644
--- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2022 Arm Limited.
+ * Copyright (c) 2018-2022, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,8 +21,8 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
-#ifndef ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
-#define ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
@@ -52,7 +52,7 @@ namespace kernel
*
*
*/
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
class CpuGemmAssemblyWrapperKernel final : public INEKernel
{
public:
@@ -101,7 +101,7 @@ public:
* @param[in] kernel Pointer to an assembly kernel implementation.
* @param[in] kernel_name_tag Tag to be attacehd to the kernel's name.
*/
- void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
+ void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag)
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
_kernel = kernel;
@@ -131,10 +131,10 @@ public:
}
private:
- arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
- std::string _name;
+ arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel;
+ std::string _name;
};
} // namespace kernel
} // namespace cpu
} // namespace arm_compute
-#endif /* ARM_COMPUTE_ASSEMBLY_GEMM_KERNEL_WRAPPER_KERNEL_H */
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_CPUGEMMASSEMBLYWRAPPERKERNEL_H
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index 941fed0ba8..cbc8be416e 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -277,8 +277,8 @@ struct Nothing
{
};
-template <typename Top, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
+template <typename Tlop, typename Trop, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
@@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
template <typename Top, typename Tret, class OutputStage = Nothing>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
-template <typename Top, typename Tret, class OutputStage = Nothing>
+template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
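The hunks above add a separate weight type to the arm_gemm templates, so the left-hand operand and the B matrix no longer have to share a type. A toy three-type reference GEMM (not the arm_gemm API) shows the shape of the new instantiations, e.g. uint8_t activations with int8_t weights accumulating into int32_t:

#include <cstddef>
#include <cstdint>

// Toy reference GEMM with the same <operand, weight, result> split as the new
// arm_gemm templates; row-major layouts assumed, illustration only.
template <typename To, typename Tw, typename Tr>
void reference_gemm(const To *a, const Tw *b, Tr *c, size_t m, size_t n, size_t k)
{
    for (size_t i = 0; i < m; ++i)
    {
        for (size_t j = 0; j < n; ++j)
        {
            Tr acc = 0;
            for (size_t p = 0; p < k; ++p)
            {
                acc += static_cast<Tr>(a[i * k + p]) * static_cast<Tr>(b[p * n + j]);
            }
            c[i * n + j] = acc;
        }
    }
}

// reference_gemm<uint8_t, int8_t, int32_t>(...) mirrors the mixed-sign
// <TypeInput, TypeWeight, TypeOutput> combinations enabled by this patch.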
diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp
index 0c1ae58902..09b73ca409 100644
--- a/src/cpu/kernels/assembly/convolution_parameters.hpp
+++ b/src/cpu/kernels/assembly/convolution_parameters.hpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2018-2021 Arm Limited.
+ * Copyright (c) 2018-2021, 2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP
+#define ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP
+
#pragma once
#include <cstdint>
@@ -57,9 +61,13 @@ struct ConvolutionParameters
int64_t output_stride_w;
int64_t output_stride_h;
// output_channels not included as they do not affect the input.
+ int64_t dilation_w;
+ int64_t dilation_h;
int64_t padding_top;
int64_t padding_left;
float padding_value;
};
} // namespace arm_gemm
+
+#endif // ACL_SRC_CPU_KERNELS_ASSEMBLY_CONVOLUTION_PARAMETERS_HPP
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index 45d1e43274..f693021fcb 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -189,7 +189,7 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template <typename To, typename Tr>
+template <typename To, typename Tw, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
@@ -197,7 +197,7 @@ protected:
int _lda = 0;
int _A_batch_stride = 0;
int _A_multi_stride = 0;
- const To *_Bptr = nullptr;
+ const Tw *_Bptr = nullptr;
int _ldb = 0;
int _B_multi_stride = 0;
Tr *_Cptr = nullptr;
@@ -214,7 +214,7 @@ public:
const int lda,
const int A_batch_stride,
const int A_multi_stride,
- const To *B,
+ const Tw *B,
const int ldb,
/* batches share B */ const int B_multi_stride,
Tr *C,
@@ -254,7 +254,7 @@ public:
const void *bias,
/* no row or batch stride needed */ const int bias_multi_stride) override
{
- set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
static_cast<const Tr *>(bias), bias_multi_stride);
}
@@ -262,17 +262,17 @@ public:
/*** "Pretransposed" interface ***/
/* Compute col sums over all columns */
- virtual void requantize_bias(void *, const To *, const int, const int){};
+ virtual void requantize_bias(void *, const Tw *, const int, const int){};
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
+ virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void pretranspose_B_array_generic(
void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
+ pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
@@ -280,7 +280,7 @@ public:
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
virtual void pretranspose_B_array_part(
- void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+ void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
@@ -293,7 +293,7 @@ public:
size_t start,
size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
+ pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
diff --git a/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp
index 60fda511e3..6a93be0618 100644
--- a/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp
+++ b/src/cpu/kernels/gemm_matrix_mul/generic/neon/fp16.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2022-2023 Arm Limited.
+ * Copyright (c) 2022-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -81,7 +81,7 @@ void vector_matrix_multiply_f16(
// window_end_x is computed above which may cause out-of-bound writes to the dst.
for (; x < (window_end_x - window_step_x); x += window_step_x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
@@ -176,7 +176,7 @@ void vector_matrix_multiply_f16(
for (; x < window_end_x; ++x)
{
- if (x > width_matrix_b)
+ if (x >= width_matrix_b)
{
return;
}
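The two one-character changes above fix an off-by-one: valid column indices run from 0 to width_matrix_b - 1, so the loop must return as soon as x reaches width_matrix_b, not only once it exceeds it. A self-contained illustration of the corrected bound (not ACL code):

#include <cassert>

// Valid columns are 0 .. width_matrix_b - 1, so x == width_matrix_b is out of
// bounds; the old "x > width_matrix_b" check wrongly let that value through.
inline bool column_in_bounds(int x, int width_matrix_b)
{
    return x < width_matrix_b; // i.e. return early when x >= width_matrix_b
}

int main()
{
    assert(column_in_bounds(7, 8));
    assert(!column_in_bounds(8, 8));
    return 0;
}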
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h
index 71b9e15dc1..0012ff6609 100644
--- a/src/cpu/operators/CpuConv2d.h
+++ b/src/cpu/operators/CpuConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
+#define ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
+
#include "arm_compute/function_info/ActivationLayerInfo.h"
#include "src/core/common/Macros.h"
@@ -81,6 +85,7 @@ public:
* |F16 |F16 |F16 |F16 |
* |F32 |F32 |F32 |F32 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
@@ -89,7 +94,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] biases Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type.
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
@@ -135,7 +140,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p src.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
@@ -167,3 +172,5 @@ private:
};
} // namespace cpu
} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index 48a0d11107..ae5023a71a 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -76,6 +76,7 @@ public:
* |F32 |F32 |F32 |F32 |
* |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index f3396fbb5c..1dbe3d8a31 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -128,24 +128,31 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- // It is not needed if the datatype is symmetric, because there is no offset
- bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+ const ITensorInfo *a_to_use = a;
- _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
+ // Initialize assembly kernel meta-data
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const ITensorInfo *a_to_use = a;
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info)))
+ {
+ _flip_signedness = true;
+ }
+
+ _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
// Convert to QASYMM8 -> QASYMM8_SIGNED and back
if (_flip_signedness)
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
-
- _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
_convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>();
_convert_to_signed_asymm->configure(a_to_use, &_signed_a);
a_to_use = &_signed_a;
@@ -166,6 +173,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
matrix_a = &_signed_a;
}
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+
// If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
{
@@ -173,8 +185,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32);
}
- // Initialize assembly kernel meta-data
- const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
if (!(!b->are_values_constant() &&
b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
@@ -375,10 +385,6 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
-
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -386,19 +392,31 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32));
}
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(info);
+
// Convert QASYMM8->QASYMM8_SIGNED
- TensorInfo signed_a{};
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
TensorInfo signed_output{};
- bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
+
+ bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
(a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run();
- if (flip_signedness)
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)))
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+ flip_signedness = true;
+ }
- signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+ if (flip_signedness)
+ {
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a));
a_to_use = &signed_a;
a_offset = signed_a.quantization_info().uniform().offset;
@@ -418,8 +436,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
matrix_a_info = &signed_a;
}
- // Initialize assembly kernel meta-data
- const AsmGemmInfo asm_info = init_assembly_metadata(info);
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
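The reordering in configure() and validate() above serves one purpose: the assembly metadata and the signed_a tensor info are prepared earlier so the signedness flip can also be triggered when the inputs are mixed-sign (QASYMM8 input, QASYMM8_SIGNED weights) but CpuGemmAssemblyDispatch::validate() rejects that pairing on the current machine. A condensed, standalone sketch of the decision (assumed parameters, not the member logic):

// Condensed form of the new flip-signedness trigger; inputs are assumed
// standalone parameters rather than CpuGemmLowpMatrixMultiplyCore state.
enum class Dt { QASYMM8, QASYMM8_SIGNED, Other };

bool should_flip_signedness(Dt a, Dt b, bool per_channel_flip, bool asm_supports_mixed_sign)
{
    if (per_channel_flip)
    {
        return true; // pre-existing QSYMM8_PER_CHANNEL path, unchanged by this patch
    }
    return a == Dt::QASYMM8 && b == Dt::QASYMM8_SIGNED && !asm_supports_mixed_sign;
}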
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 38121c9bb4..11fe6f9ef0 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -81,11 +81,13 @@ public:
* |src0 |src1 |src2 |dst |
* |:--------------|:------------------|:--------|:--------------|
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
* |QASYMM8 |QASYMM8 |S32 |S32 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8 |QSYMM8 |S32 |S32 |
+ * |QASYMM8 |QASYMM8_SIGNED |F32 |F32 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index a4c856bb8f..156a798d50 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -45,6 +45,7 @@ namespace
/** Run pretranspose_B_array in parallel (1D static scheduling)
*
* @tparam TypeInput
+ * @tparam TypeWeight
* @tparam TypeOutput
*
* @param[in] gemm_asm GemmCommon kernel to run
@@ -54,14 +55,14 @@ namespace
* @param[in] src_multi_stride Stride in z ("multi")
* @param[in] num_threads Number of threads to run this method. Must be >= 1
*/
-template <typename TypeInput, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
- ITensor *dst,
- const TypeInput *src,
- int src_ld,
- int src_multi_stride,
- unsigned int num_threads,
- bool transpose)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm,
+ ITensor *dst,
+ const TypeWeight *src,
+ int src_ld,
+ int src_multi_stride,
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -91,14 +92,6 @@ using namespace arm_compute::experimental;
namespace
{
-struct free_delete
-{
- void operator()(void *x)
- {
- free(x);
- }
-};
-
struct Params
{
unsigned int M;
@@ -113,14 +106,13 @@ struct Params
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- Params p;
- p.M = d->tensor_shape().y();
- p.K = a->tensor_shape().x();
- p.N = d->tensor_shape().x();
- p.batches = 1;
- p.multis = 1;
- p.sections = 1;
- p.indirect = false;
+ Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()),
+ /* N */ static_cast<unsigned int>(d->tensor_shape().x()),
+ /* K */ static_cast<unsigned int>(a->tensor_shape().x()),
+ /* batches */ 1,
+ /* multis */ 1,
+ /* sections */ 1,
+ /* indirect */ false};
if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
@@ -172,13 +164,10 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
}
/** Fallback in case ACL doesn't have a function */
-template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
- /** Destructor */
- ~Fallback() = default;
-
/** Initialise the functions's input and output.
*
* @param[in] a Input tensor containing the Matrix A.
@@ -222,7 +211,9 @@ public:
bool isVarWeightsKernel() const override
{
if (!_gemm_kernel_asm)
+ {
return false;
+ }
const arm_compute::WeightFormat wf =
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
@@ -251,7 +242,7 @@ private:
/** Operator to transpose B before gemm or pretranspose_B_array*/
std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
@@ -273,22 +264,22 @@ private:
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
/** Indirect buffer */
- std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
- std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{Count};
- bool _B_pretranspose_required{false};
- bool _is_b_constant{true};
- bool _is_c_constant{true};
- bool _run_pre_pretranspose_b{false};
- bool _B_pre_pretranspose_required{false};
+ std::vector<const TypeInput *const *> _indirect_arg{};
+ std::vector<const TypeInput *> _indirect_buf{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{Count};
+ bool _B_pretranspose_required{false};
+ bool _is_b_constant{true};
+ bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
-template <typename TypeInput, typename TypeOutput, class OutputStage>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
@@ -305,8 +296,8 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec
return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
@@ -343,14 +334,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
_indirect_pad.data();
}
else
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
}
}
@@ -361,11 +350,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *d,
- const AsmGemmInfo &info)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
@@ -375,13 +364,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
zeropad = a->quantization_info().uniform().offset;
}
- const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]);
- const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]);
- const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
- const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
- const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
- const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
- const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
+ const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]);
+ const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]);
+ const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+ const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
+ const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
+ const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]);
+ const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]);
_cp = {input_width,
input_height,
@@ -392,6 +381,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
output_height,
info.ps_info.stride().first,
info.ps_info.stride().second,
+ 1,
+ 1,
info.padding_top,
info.padding_left,
zeropad};
@@ -414,10 +405,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
- _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
- reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
- _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
- reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+ _indirect_buf = std::vector<const TypeInput *>(multi_size * multis);
+ _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches);
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
// Set indirect argument
@@ -428,29 +417,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
{
for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
{
- (_indirect_arg.get())[pos++] =
- _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+ _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw];
}
}
}
- _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
+ _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data());
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *c,
- ITensorInfo *d,
- arm_gemm::GemmArgs args,
- const AsmGemmInfo &gemm_info,
- const OutputStage &os)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
+ const OutputStage &os)
{
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os);
if (_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
@@ -460,7 +448,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();
// arm_compute wrapper for the Gemm object (see above)
- auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
+ auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
@@ -549,8 +537,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
@@ -588,17 +576,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeWeight *>(
+ b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
@@ -616,20 +604,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const
{
return _optimised_kernel != nullptr;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const
{
return _aux_mem;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -663,8 +651,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
- const TypeInput *in1_ptr = nullptr;
- auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const TypeWeight *in1_ptr = nullptr;
+ auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
const ITensor *b_to_use = b;
@@ -686,8 +674,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
- in1_ptr =
- reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -706,8 +694,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
+ const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
@@ -720,7 +708,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
else
{
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
@@ -744,7 +732,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
if (split_dim != IScheduler::split_dimensions_all)
{
// Make sure the kernel does not expect more threads than we can actually spawn
- const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
+ const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim);
num_threads = std::min(num_iterations, num_threads);
}
_gemm_kernel_asm->set_nthreads(num_threads);
@@ -775,7 +763,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -794,12 +782,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>();
fallback->configure(a, b, c, d, args, info);
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -820,7 +808,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();
// Configure requantization info
const GEMMLowpOutputStageInfo os_info = info.output_stage;
@@ -832,7 +820,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -852,7 +840,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>();
// Configure requantization info
const int32_t negation = info.negated_offsets ? 1 : -1;
@@ -905,12 +893,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
- // TODO: Incorporate info.transpose_b COMPMID-6595
+ // TODO(COMPMID-6595): Incorporate info.transpose_b
switch (a->data_type())
{
case DataType::F32:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -919,13 +907,22 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
}
+ else if (b->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
+ "We could not find an optimized kernel for U8 input with S8 weights and U8 output");
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -934,13 +931,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8 input and S8 output");
}
break;
@@ -952,13 +951,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::BFLOAT16)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
}
break;
@@ -968,7 +969,8 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ENABLE_FP16_KERNELS)
case DataType::F16:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for F16 input and F16 output");
break;
#endif /* ENABLE_FP16_KERNELS */
@@ -1009,7 +1011,7 @@ Status CpuGemmAssemblyDispatch::validate(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
}
- else
+ else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
@@ -1024,12 +1026,13 @@ Status CpuGemmAssemblyDispatch::validate(
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
"Only S32 output supported for S8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
- (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
- "Only QASYMM8/S32 output supported for QASYMM8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ a->data_type() == DataType::QASYMM8 &&
+ (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
+ "Only QASYMM8/S32/F32 output supported for QASYMM8 input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
+ if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
@@ -1062,33 +1065,44 @@ void CpuGemmAssemblyDispatch::configure(
switch (a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if (d->data_type() == DataType::S32)
+ if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ if (d->data_type() == DataType::F32)
+ {
+ create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ }
+ }
+ else if (d->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if (d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
else if (d->data_type() == DataType::F32)
{
- create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* __aarch64__ */
@@ -1096,17 +1110,17 @@ void CpuGemmAssemblyDispatch::configure(
case DataType::BFLOAT16:
if (d->data_type() == DataType::BFLOAT16)
{
- create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* ENABLE_FP16_KERNELS */
default:
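Taken together, the dispatch changes above route U8/QASYMM8 inputs with S8/QASYMM8_SIGNED weights to the new mixed-sign fallbacks, while the matched-sign paths keep their previous behaviour. A rough summary of the resulting template instantiations for u8 inputs (illustrative mapping only; other destination types fall back as in the switch above):

// Rough map of the u8-input dispatch after this patch (illustration only).
enum class DstType { S32, F32, QASYMM8 };

const char *u8_dispatch(bool signed_weights, DstType d)
{
    if (signed_weights) // b is S8 or QASYMM8_SIGNED
    {
        return d == DstType::F32 ? "Fallback<uint8_t, int8_t, float, DequantizeFloat>"
                                 : "Fallback<uint8_t, int8_t, uint8_t, Requantize32>";
    }
    return d == DstType::S32 ? "Fallback<uint8_t, uint8_t, uint32_t>"
                             : "Fallback<uint8_t, uint8_t, uint8_t, Requantize32>";
}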