aboutsummaryrefslogtreecommitdiff
path: root/src/cpu
diff options
context:
space:
mode:
authorGunes Bayir <gunes.bayir@arm.com>2024-07-02 15:45:01 +0100
committerGunes Bayir <gunes.bayir@arm.com>2024-07-02 16:00:11 +0000
commita3f238a44d9f306c77be0177f13d22ae3f3bcc57 (patch)
tree44bf40fb59fb8c4452d65d25e3a967c035bc6863 /src/cpu
parentf92b0fffa0d32dc08340c1abfa1a7f09c6e53795 (diff)
downloadComputeLibrary-a3f238a44d9f306c77be0177f13d22ae3f3bcc57.tar.gz
Revert "Update CPU kernels and add mixed sign GEMM support"
This reverts commit fc94f4d23abd4bc427b701f54ad85282e9ec7872 and 5d6fff041ade7eb44af0945867212f3979be3d3e (because the latter fixes a build failure caused by the former) Change-Id: I7d07fea8307e9a7033b30874bbb14ba9202b23d8 Signed-off-by: Gunes Bayir <gunes.bayir@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/11815 Benchmark: Arm Jenkins <bsgcomp@arm.com> Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Adnan AlSinan <adnan.alsinan@arm.com>
Diffstat (limited to 'src/cpu')
-rw-r--r--src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp4
-rw-r--r--src/cpu/kernels/CpuKernelSelectionTypes.h2
-rw-r--r--src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h8
-rw-r--r--src/cpu/kernels/assembly/arm_gemm.hpp12
-rw-r--r--src/cpu/kernels/assembly/convolution_parameters.hpp2
-rw-r--r--src/cpu/kernels/assembly/gemm_common.hpp18
-rw-r--r--src/cpu/operators/CpuConv2d.h5
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h1
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp75
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h2
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp270
11 files changed, 178 insertions, 221 deletions
diff --git a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
index 87340e566e..5b88735e7a 100644
--- a/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
+++ b/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp
@@ -684,10 +684,6 @@ Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, cons
DataType::U8);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(src0->data_type() == DataType::QASYMM8_SIGNED &&
- src1->data_type() == DataType::QASYMM8,
- "QASYMM8_SIGNED input with QASYMM8 weights not supported");
-
TensorShape in0_shape = src0->tensor_shape();
TensorShape in1_shape = src1->tensor_shape();
TensorShape out_shape = dst->tensor_shape();
diff --git a/src/cpu/kernels/CpuKernelSelectionTypes.h b/src/cpu/kernels/CpuKernelSelectionTypes.h
index 03a474de53..7c1e4772a6 100644
--- a/src/cpu/kernels/CpuKernelSelectionTypes.h
+++ b/src/cpu/kernels/CpuKernelSelectionTypes.h
@@ -105,7 +105,7 @@ struct SoftmaxKernelDataTypeISASelectorData
cpuinfo::CpuIsaInfo isa;
bool is_log;
int axis;
- uint64_t sme2_vector_length;
+ unsigned long sme2_vector_length;
};
// Selector pointer types
diff --git a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
index 72fafca1bb..e2a27675b3 100644
--- a/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
+++ b/src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h
@@ -52,7 +52,7 @@ namespace kernel
*
*
*/
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
class CpuGemmAssemblyWrapperKernel final : public INEKernel
{
public:
@@ -101,7 +101,7 @@ public:
* @param[in] kernel Pointer to an assembly kernel implementation.
* @param[in] kernel_name_tag Tag to be attacehd to the kernel's name.
*/
- void configure(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *kernel, std::string kernel_name_tag)
+ void configure(arm_gemm::GemmCommon<TypeInput, TypeOutput> *kernel, std::string kernel_name_tag)
{
ARM_COMPUTE_ERROR_ON_NULLPTR((reinterpret_cast<void *>(kernel)));
_kernel = kernel;
@@ -131,8 +131,8 @@ public:
}
private:
- arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *_kernel;
- std::string _name;
+ arm_gemm::GemmCommon<TypeInput, TypeOutput> *_kernel;
+ std::string _name;
};
} // namespace kernel
} // namespace cpu
diff --git a/src/cpu/kernels/assembly/arm_gemm.hpp b/src/cpu/kernels/assembly/arm_gemm.hpp
index cbc8be416e..941fed0ba8 100644
--- a/src/cpu/kernels/assembly/arm_gemm.hpp
+++ b/src/cpu/kernels/assembly/arm_gemm.hpp
@@ -277,8 +277,8 @@ struct Nothing
{
};
-template <typename Tlop, typename Trop, typename Tret>
-using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
+template <typename Top, typename Tret>
+using UniqueGemmCommon = std::unique_ptr<GemmCommon<Top, Tret>>;
/* Low level API calls.
* These are implemented as 'GemmArgs' versions, or with the arguments explicitly listed. */
@@ -288,13 +288,13 @@ using UniqueGemmCommon = std::unique_ptr<GemmCommon<Tlop, Trop, Tret>>;
template <typename Top, typename Tret, class OutputStage = Nothing>
KernelDescription get_gemm_method(const GemmArgs &args, const OutputStage & = {});
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
-UniqueGemmCommon<Tlop, Trop, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
+template <typename Top, typename Tret, class OutputStage = Nothing>
+UniqueGemmCommon<Top, Tret> gemm(const GemmArgs &args, const OutputStage & = {});
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+template <typename Top, typename Tret, class OutputStage = Nothing>
std::vector<KernelDescription> get_compatible_kernels(const GemmArgs &args, const OutputStage & = {});
-template <typename Tlop, typename Trop, typename Tret, class OutputStage = Nothing>
+template <typename Top, typename Tret, class OutputStage = Nothing>
bool has_opt_gemm(WeightFormat &weight_format, const GemmArgs &args, const OutputStage & = {});
} // namespace arm_gemm
diff --git a/src/cpu/kernels/assembly/convolution_parameters.hpp b/src/cpu/kernels/assembly/convolution_parameters.hpp
index 09b73ca409..a6cf96344c 100644
--- a/src/cpu/kernels/assembly/convolution_parameters.hpp
+++ b/src/cpu/kernels/assembly/convolution_parameters.hpp
@@ -61,8 +61,6 @@ struct ConvolutionParameters
int64_t output_stride_w;
int64_t output_stride_h;
// output_channels not included as they do not affect the input.
- int64_t dilation_w;
- int64_t dilation_h;
int64_t padding_top;
int64_t padding_left;
float padding_value;
diff --git a/src/cpu/kernels/assembly/gemm_common.hpp b/src/cpu/kernels/assembly/gemm_common.hpp
index f693021fcb..45d1e43274 100644
--- a/src/cpu/kernels/assembly/gemm_common.hpp
+++ b/src/cpu/kernels/assembly/gemm_common.hpp
@@ -189,7 +189,7 @@ public:
* 'set_arrays' to capture the provided arguments in protected class
* members, as essentially any implementation will need these.
*/
-template <typename To, typename Tw, typename Tr>
+template <typename To, typename Tr>
class GemmCommon : public IGemmCommon
{
protected:
@@ -197,7 +197,7 @@ protected:
int _lda = 0;
int _A_batch_stride = 0;
int _A_multi_stride = 0;
- const Tw *_Bptr = nullptr;
+ const To *_Bptr = nullptr;
int _ldb = 0;
int _B_multi_stride = 0;
Tr *_Cptr = nullptr;
@@ -214,7 +214,7 @@ public:
const int lda,
const int A_batch_stride,
const int A_multi_stride,
- const Tw *B,
+ const To *B,
const int ldb,
/* batches share B */ const int B_multi_stride,
Tr *C,
@@ -254,7 +254,7 @@ public:
const void *bias,
/* no row or batch stride needed */ const int bias_multi_stride) override
{
- set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const Tw *>(B), ldb,
+ set_arrays(static_cast<const To *>(A), lda, A_batch_stride, A_multi_stride, static_cast<const To *>(B), ldb,
B_multi_stride, static_cast<Tr *>(C), ldc, C_batch_stride, C_multi_stride,
static_cast<const Tr *>(bias), bias_multi_stride);
}
@@ -262,17 +262,17 @@ public:
/*** "Pretransposed" interface ***/
/* Compute col sums over all columns */
- virtual void requantize_bias(void *, const Tw *, const int, const int){};
+ virtual void requantize_bias(void *, const To *, const int, const int){};
/* Perform pretranspose - the void * passed in must remain allocated for the duration of any execute calls. */
/* Arguments are: output buffer pointer, source pointer, source row stride, source multi stride */
- virtual void pretranspose_B_array(void *, const Tw *, const int, const int, bool){};
+ virtual void pretranspose_B_array(void *, const To *, const int, const int, bool){};
/* Implementation of the void * overload which casts its arguments to the appropriate type. */
void pretranspose_B_array_generic(
void *out, const void *in, const int row_stride, const int multi_stride, bool transposed) override
{
- pretranspose_B_array(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed);
+ pretranspose_B_array(out, static_cast<const To *>(in), row_stride, multi_stride, transposed);
}
/* Threaded versions of the above.
@@ -280,7 +280,7 @@ public:
* just calls the non-threaded functions to do the work. This is valid as with window size of 1 the only
* legal values for start and end are 0 and 1 respectively. */
virtual void pretranspose_B_array_part(
- void *out, const Tw *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
+ void *out, const To *in, const int row_stride, const int multi_stride, bool transposed, size_t, size_t)
{
pretranspose_B_array(out, in, row_stride, multi_stride, transposed);
};
@@ -293,7 +293,7 @@ public:
size_t start,
size_t end) override
{
- pretranspose_B_array_part(out, static_cast<const Tw *>(in), row_stride, multi_stride, transposed, start, end);
+ pretranspose_B_array_part(out, static_cast<const To *>(in), row_stride, multi_stride, transposed, start, end);
}
/*** Indirect interface ***/
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h
index 0012ff6609..3f98e71896 100644
--- a/src/cpu/operators/CpuConv2d.h
+++ b/src/cpu/operators/CpuConv2d.h
@@ -85,7 +85,6 @@ public:
* |F16 |F16 |F16 |F16 |
* |F32 |F32 |F32 |F32 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
- * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
@@ -94,7 +93,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] biases Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type.
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
@@ -140,7 +139,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p src.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index ae5023a71a..fa16ce860b 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -76,7 +76,6 @@ public:
* |F32 |F32 |F32 |F32 |
* |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
- * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index 1dbe3d8a31..f3396fbb5c 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -128,31 +128,24 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
- const ITensorInfo *a_to_use = a;
-
- // Initialize assembly kernel meta-data
- const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
-
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
-
- _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
-
- // If inputs are mixed-sign but this machine does not support mixed sign kernels,
- // flip the sign so matched-sign kernels can be used.
- if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
- !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info)))
- {
- _flip_signedness = true;
- }
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
+ const ITensorInfo *a_to_use = a;
+
// Convert to QASYMM8 -> QASYMM8_SIGNED and back
if (_flip_signedness)
{
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
_convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>();
_convert_to_signed_asymm->configure(a_to_use, &_signed_a);
a_to_use = &_signed_a;
@@ -173,11 +166,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
matrix_a = &_signed_a;
}
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- // It is not needed if the datatype is symmetric, because there is no offset
- bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
-
// If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
{
@@ -185,6 +173,8 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32);
}
+ // Initialize assembly kernel meta-data
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
if (!(!b->are_values_constant() &&
b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
@@ -385,6 +375,10 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -392,31 +386,19 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32));
}
- // Initialize assembly kernel meta-data
- const AsmGemmInfo asm_info = init_assembly_metadata(info);
-
// Convert QASYMM8->QASYMM8_SIGNED
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
-
- TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+ TensorInfo signed_a{};
TensorInfo signed_output{};
-
- bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
+ bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
(a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run();
-
- // If inputs are mixed-sign but this machine does not support mixed sign kernels,
- // flip the sign so matched-sign kernels can be used.
- if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
- !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)))
- {
- flip_signedness = true;
- }
-
if (flip_signedness)
{
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a));
a_to_use = &signed_a;
a_offset = signed_a.quantization_info().uniform().offset;
@@ -436,9 +418,8 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
matrix_a_info = &signed_a;
}
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(info);
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 11fe6f9ef0..38121c9bb4 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -81,13 +81,11 @@ public:
* |src0 |src1 |src2 |dst |
* |:--------------|:------------------|:--------|:--------------|
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
- * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
* |QASYMM8 |QASYMM8 |S32 |S32 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8 |QSYMM8 |S32 |S32 |
- * |QASYMM8 |QASYMM8_SIGNED |F32 |F32 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 785837dbc6..fb9bc15212 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -45,7 +45,6 @@ namespace
/** Run pretranspose_B_array in parallel (1D static scheduling)
*
* @tparam TypeInput
- * @tparam TypeWeight
* @tparam TypeOutput
*
* @param[in] gemm_asm GemmCommon kernel to run
@@ -55,14 +54,14 @@ namespace
* @param[in] src_multi_stride Stride in z ("multi")
* @param[in] num_threads Number of threads to run this method. Must be >= 1
*/
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm,
- ITensor *dst,
- const TypeWeight *src,
- int src_ld,
- int src_multi_stride,
- unsigned int num_threads,
- bool transpose)
+template <typename TypeInput, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
+ ITensor *dst,
+ const TypeInput *src,
+ int src_ld,
+ int src_multi_stride,
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -92,6 +91,14 @@ using namespace arm_compute::experimental;
namespace
{
+struct free_delete
+{
+ void operator()(void *x)
+ {
+ free(x);
+ }
+};
+
struct Params
{
unsigned int M;
@@ -106,13 +113,14 @@ struct Params
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()),
- /* N */ static_cast<unsigned int>(d->tensor_shape().x()),
- /* K */ static_cast<unsigned int>(a->tensor_shape().x()),
- /* batches */ 1,
- /* multis */ 1,
- /* sections */ 1,
- /* indirect */ false};
+ Params p;
+ p.M = d->tensor_shape().y();
+ p.K = a->tensor_shape().x();
+ p.N = d->tensor_shape().x();
+ p.batches = 1;
+ p.multis = 1;
+ p.sections = 1;
+ p.indirect = false;
if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
@@ -164,10 +172,13 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
}
/** Fallback in case ACL doesn't have a function */
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
+template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
+ /** Destructor */
+ ~Fallback() = default;
+
/** Initialise the functions's input and output.
*
* @param[in] a Input tensor containing the Matrix A.
@@ -211,9 +222,7 @@ public:
bool isVarWeightsKernel() const override
{
if (!_gemm_kernel_asm)
- {
return false;
- }
const arm_compute::WeightFormat wf =
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
@@ -242,7 +251,7 @@ private:
/** Operator to transpose B before gemm or pretranspose_B_array*/
std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
@@ -264,22 +273,22 @@ private:
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
/** Indirect buffer */
- std::vector<const TypeInput *const *> _indirect_arg{};
- std::vector<const TypeInput *> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{Count};
- bool _B_pretranspose_required{false};
- bool _is_b_constant{true};
- bool _is_c_constant{true};
- bool _run_pre_pretranspose_b{false};
- bool _B_pre_pretranspose_required{false};
+ std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
+ std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{Count};
+ bool _B_pretranspose_required{false};
+ bool _is_b_constant{true};
+ bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+template <typename TypeInput, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
@@ -296,8 +305,8 @@ Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(co
return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
@@ -334,12 +343,14 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_
if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
{
- _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf
+ .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
_indirect_pad.data();
}
else
{
- _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf
+ .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
}
}
@@ -350,11 +361,11 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_
}
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *d,
- const AsmGemmInfo &info)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
@@ -364,13 +375,13 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec
zeropad = a->quantization_info().uniform().offset;
}
- const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]);
- const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]);
- const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
- const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
- const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
- const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]);
- const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]);
+ const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]);
+ const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]);
+ const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+ const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
+ const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
+ const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
+ const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
_cp = {input_width,
input_height,
@@ -381,8 +392,6 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec
output_height,
info.ps_info.stride().first,
info.ps_info.stride().second,
- 1,
- 1,
info.padding_top,
info.padding_left,
zeropad};
@@ -405,8 +414,10 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
- _indirect_buf = std::vector<const TypeInput *>(multi_size * multis);
- _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches);
+ _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
+ reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
+ _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
+ reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
// Set indirect argument
@@ -417,28 +428,29 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirec
{
for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
{
- _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw];
+ (_indirect_arg.get())[pos++] =
+ _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
}
}
}
- _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data());
+ _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
}
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *c,
- ITensorInfo *d,
- arm_gemm::GemmArgs args,
- const AsmGemmInfo &gemm_info,
- const OutputStage &os)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
+ const OutputStage &os)
{
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
if (_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
@@ -448,7 +460,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const I
arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();
// arm_compute wrapper for the Gemm object (see above)
- auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>();
+ auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
@@ -537,8 +549,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const I
}
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
@@ -576,17 +588,17 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPa
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto in1_ptr = reinterpret_cast<const TypeWeight *>(
- b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
@@ -604,20 +616,20 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPa
}
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
{
return _optimised_kernel != nullptr;
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
{
return _aux_mem;
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors)
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -651,8 +663,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
- const TypeWeight *in1_ptr = nullptr;
- auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const TypeInput *in1_ptr = nullptr;
+ auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
const ITensor *b_to_use = b;
@@ -674,8 +686,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
{
ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
- in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
+ in1_ptr =
+ reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -694,8 +706,8 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
+ const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
@@ -708,7 +720,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
else
{
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
@@ -732,7 +744,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
if (split_dim != IScheduler::split_dimensions_all)
{
// Make sure the kernel does not expect more threads than we can actually spawn
- const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim);
+ const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
num_threads = std::min(num_iterations, num_threads);
}
_gemm_kernel_asm->set_nthreads(num_threads);
@@ -763,7 +775,7 @@ void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -782,12 +794,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
fallback->configure(a, b, c, d, args, info);
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -808,7 +820,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
// Configure requantization info
const GEMMLowpOutputStageInfo os_info = info.output_stage;
@@ -820,7 +832,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -840,7 +852,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
// Configure requantization info
const int32_t negation = info.negated_offsets ? 1 : -1;
@@ -893,12 +905,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
- // TODO(COMPMID-6595): Incorporate info.transpose_b
+ // TODO: Incorporate info.transpose_b COMPMID-6595
switch (a->data_type())
{
case DataType::F32:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -907,22 +919,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
- {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
}
- else if (b->data_type() == DataType::QASYMM8_SIGNED)
- {
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
- args, {})),
- "We could not find an optimized kernel for U8 input with S8 weights and U8 output");
- }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
- args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -931,15 +934,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
- {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args,
- {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for S8 input and S8 output");
}
break;
@@ -951,15 +952,13 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::BFLOAT16)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf,
- args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
- {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
}
break;
@@ -969,8 +968,7 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ENABLE_FP16_KERNELS)
case DataType::F16:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
- {})),
+ !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F16 input and F16 output");
break;
#endif /* ENABLE_FP16_KERNELS */
@@ -1011,7 +1009,7 @@ Status CpuGemmAssemblyDispatch::validate(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
}
- else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED))
+ else
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
@@ -1026,13 +1024,12 @@ Status CpuGemmAssemblyDispatch::validate(
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
"Only S32 output supported for S8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- a->data_type() == DataType::QASYMM8 &&
- (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
- "Only QASYMM8/S32/F32 output supported for QASYMM8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
+ (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
+ "Only QASYMM8/S32 output supported for QASYMM8 input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
+ if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
@@ -1065,44 +1062,33 @@ void CpuGemmAssemblyDispatch::configure(
switch (a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED)
- {
- if (d->data_type() == DataType::F32)
- {
- create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
- }
- else
- {
- create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
- }
- }
- else if (d->data_type() == DataType::S32)
+ if (d->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if (d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
else if (d->data_type() == DataType::F32)
{
- create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* __aarch64__ */
@@ -1110,17 +1096,17 @@ void CpuGemmAssemblyDispatch::configure(
case DataType::BFLOAT16:
if (d->data_type() == DataType::BFLOAT16)
{
- create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
- create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* ENABLE_FP16_KERNELS */
default: