aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators')
-rw-r--r--src/cpu/operators/CpuConv2d.h13
-rw-r--r--src/cpu/operators/CpuGemm.cpp4
-rw-r--r--src/cpu/operators/CpuGemmConv2d.cpp10
-rw-r--r--src/cpu/operators/CpuGemmConv2d.h3
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp75
-rw-r--r--src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h2
-rw-r--r--src/cpu/operators/CpuWinogradConv2d.cpp4
-rw-r--r--src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp274
8 files changed, 217 insertions, 168 deletions
diff --git a/src/cpu/operators/CpuConv2d.h b/src/cpu/operators/CpuConv2d.h
index 71b9e15dc1..0012ff6609 100644
--- a/src/cpu/operators/CpuConv2d.h
+++ b/src/cpu/operators/CpuConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2021, 2023 Arm Limited.
+ * Copyright (c) 2017-2021, 2023-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -21,6 +21,10 @@
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
* SOFTWARE.
*/
+
+#ifndef ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
+#define ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
+
#include "arm_compute/function_info/ActivationLayerInfo.h"
#include "src/core/common/Macros.h"
@@ -81,6 +85,7 @@ public:
* |F16 |F16 |F16 |F16 |
* |F32 |F32 |F32 |F32 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
@@ -89,7 +94,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported: Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] biases Biases tensor info. Shared biases supported. Biases are 1D tensor with dimensions [OFM].
* Data type supported: Same as @p src, except for input of QASYMM8/QASYMM8_SIGNED type where biases should be of S32 type.
* @param[out] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
@@ -135,7 +140,7 @@ public:
* while every optional dimension from 4 and above represent a batch of inputs.
* Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
* @param[in] weights Weights tensor info. Weights are 4D tensor with dimensions [kernel_x, kernel_y, IFM, OFM].
- * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL if input is QASYMM8/QASYMM8_SIGNED.
+ * Data type supported:Same as @p src, also could be QSYMM8_PER_CHANNEL or QASYMM8_SIGNED if input is QASYMM8/QASYMM8_SIGNED.
* @param[in] dst Destination tensor info. 3 lower dimensions represent a single output [width, height, OFM], while the rest represent batch of outputs.
* Data types supported: Same as @p src.
* @param[in] conv_info Contains padding and stride information described in @ref PadStrideInfo.
@@ -167,3 +172,5 @@ private:
};
} // namespace cpu
} // namespace arm_compute
+
+#endif // ACL_SRC_CPU_OPERATORS_CPUCONV2D_H
diff --git a/src/cpu/operators/CpuGemm.cpp b/src/cpu/operators/CpuGemm.cpp
index 905e86c185..c489b256b8 100644
--- a/src/cpu/operators/CpuGemm.cpp
+++ b/src/cpu/operators/CpuGemm.cpp
@@ -174,8 +174,8 @@ void CpuGemm::configure(const ITensorInfo *a,
// Configure rhs transpose1xw kernel
_transpose1xW_b_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
_transpose1xW_b_kernel->configure(b_to_use, &_tmp_b);
- _aux_mem[Transposed1xWRHS] =
- MemoryInfo(offset_int_vec(Transposed1xWRHS), MemoryLifetime::Persistent, _tmp_b.total_size());
+ const auto lifetime = _reshape_b_only_on_first_run ? MemoryLifetime::Persistent : MemoryLifetime::Temporary;
+ _aux_mem[Transposed1xWRHS] = MemoryInfo(offset_int_vec(Transposed1xWRHS), lifetime, _tmp_b.total_size());
// Use a and b here instead of _tmp_a and _tmp_b because CpuGemmMatrixMultiplyKernel requires the original m,n,k in case of interleaved a and transposed1xw b
const int m = a->dimension(1);
diff --git a/src/cpu/operators/CpuGemmConv2d.cpp b/src/cpu/operators/CpuGemmConv2d.cpp
index 55d950ff4a..f3b78f8885 100644
--- a/src/cpu/operators/CpuGemmConv2d.cpp
+++ b/src/cpu/operators/CpuGemmConv2d.cpp
@@ -589,8 +589,14 @@ void CpuGemmConv2d::configure(const ITensorInfo *src,
// WeightsReshaped in prepare
// Otherwise WeightsReshaped is the final transformation of weights and needs to persist
bool gemm_trans_wei = _aux_mem[GemmAsmPretransposedRHS].size > 0;
- gemm_trans_wei = _mm_gemm != nullptr ? _aux_mem[GemmTransposed1xWRHS].size > 0 : gemm_trans_wei;
- gemm_trans_wei = _mm_gemmlowp != nullptr ? _aux_mem[GemmLowpTransposed1xWRHS].size > 0 : gemm_trans_wei;
+ if (_mm_gemm != nullptr)
+ {
+ gemm_trans_wei |= _aux_mem[GemmTransposed1xWRHS].size > 0;
+ }
+ if (_mm_gemmlowp != nullptr)
+ {
+ gemm_trans_wei |= _aux_mem[GemmLowpTransposed1xWRHS].size > 0;
+ }
_aux_mem[WeightsReshaped] = MemoryInfo(offset_int_vec(WeightsReshaped),
gemm_trans_wei ? MemoryLifetime::Prepare : MemoryLifetime::Persistent,
diff --git a/src/cpu/operators/CpuGemmConv2d.h b/src/cpu/operators/CpuGemmConv2d.h
index 48a0d11107..ae5023a71a 100644
--- a/src/cpu/operators/CpuGemmConv2d.h
+++ b/src/cpu/operators/CpuGemmConv2d.h
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -76,6 +76,7 @@ public:
* |F32 |F32 |F32 |F32 |
* |BFLOAT16 |BFLOAT16 |BFLOAT16 |BFLOAT16 |
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
index f3396fbb5c..1dbe3d8a31 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.cpp
@@ -128,24 +128,31 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_reshape_b_only_on_first_run;
_gemm_info = gemm_info;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- // It is not needed if the datatype is symmetric, because there is no offset
- bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+ const ITensorInfo *a_to_use = a;
- _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
+ // Initialize assembly kernel meta-data
+ const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
- const ITensorInfo *a_to_use = a;
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!_flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, dst, asm_info)))
+ {
+ _flip_signedness = true;
+ }
+
+ _asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
// Convert to QASYMM8 -> QASYMM8_SIGNED and back
if (_flip_signedness)
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
-
- _signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
_convert_to_signed_asymm = std::make_unique<kernels::CpuConvertQuantizedSignednessKernel>();
_convert_to_signed_asymm->configure(a_to_use, &_signed_a);
a_to_use = &_signed_a;
@@ -166,6 +173,11 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
matrix_a = &_signed_a;
}
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ // It is not needed if the datatype is symmetric, because there is no offset
+ bool a_offset_kernel_needed = _a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = _b_offset != 0 || b->quantization_info().is_dynamic();
+
// If GEMMLowpOutputStage != NONE, fuse the offset contribution with the output stage
if (info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE)
{
@@ -173,8 +185,6 @@ void CpuGemmLowpMatrixMultiplyCore::configure(
_mm_result_s32 = TensorInfo(dst->tensor_shape(), 1, DataType::S32);
}
- // Initialize assembly kernel meta-data
- const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);
#ifdef __aarch64__
if (!(!b->are_values_constant() &&
b->tensor_shape().z() > 1)) // Disable batch matmul as optimized GeMM handles batching differently.
@@ -375,10 +385,6 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
int32_t a_offset = a->quantization_info().uniform().offset;
int32_t b_offset = b->quantization_info().uniform().offset;
- // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
- bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
- bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
-
bool fuse_output_stage = info.gemmlowp_output_stage().type != GEMMLowpOutputStageType::NONE;
if (fuse_output_stage)
{
@@ -386,19 +392,31 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
a->clone()->set_tensor_shape(output->tensor_shape()).set_data_type(DataType::S32));
}
+ // Initialize assembly kernel meta-data
+ const AsmGemmInfo asm_info = init_assembly_metadata(info);
+
// Convert QASYMM8->QASYMM8_SIGNED
- TensorInfo signed_a{};
+ const int32_t offset_correction = 128;
+ const DataType dt = DataType::QASYMM8_SIGNED;
+ const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+
+ TensorInfo signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
+ QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
TensorInfo signed_output{};
- bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
+
+ bool flip_signedness = is_data_type_quantized_per_channel(b->data_type()) &&
(a->data_type() == DataType::QASYMM8) && info.reshape_b_only_on_first_run();
- if (flip_signedness)
+
+ // If inputs are mixed-sign but this machine does not support mixed sign kernels,
+ // flip the sign so matched-sign kernels can be used.
+ if (!flip_signedness && a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED &&
+ !bool(CpuGemmAssemblyDispatch::validate(a_to_use, b, c, output, asm_info)))
{
- const int32_t offset_correction = 128;
- const DataType dt = DataType::QASYMM8_SIGNED;
- const UniformQuantizationInfo iqinfo = a_to_use->quantization_info().uniform();
+ flip_signedness = true;
+ }
- signed_a = a_to_use->clone()->set_data_type(dt).set_quantization_info(
- QuantizationInfo(iqinfo.scale, iqinfo.offset + offset_correction));
+ if (flip_signedness)
+ {
ARM_COMPUTE_RETURN_ON_ERROR(kernels::CpuConvertQuantizedSignednessKernel::validate(a_to_use, &signed_a));
a_to_use = &signed_a;
a_offset = signed_a.quantization_info().uniform().offset;
@@ -418,8 +436,9 @@ Status CpuGemmLowpMatrixMultiplyCore::validate(const ITensorInfo *a,
matrix_a_info = &signed_a;
}
- // Initialize assembly kernel meta-data
- const AsmGemmInfo asm_info = init_assembly_metadata(info);
+ // Offset kernel is need if offset is non-zero or it may change (i.e. dynamic).
+ bool a_offset_kernel_needed = a_offset != 0 || a->quantization_info().is_dynamic();
+ bool b_offset_kernel_needed = b_offset != 0 || b->quantization_info().is_dynamic();
// Check if we need to run the optimized assembly kernel
bool run_optimised = false;
diff --git a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
index 38121c9bb4..11fe6f9ef0 100644
--- a/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
+++ b/src/cpu/operators/CpuGemmLowpMatrixMultiplyCore.h
@@ -81,11 +81,13 @@ public:
* |src0 |src1 |src2 |dst |
* |:--------------|:------------------|:--------|:--------------|
* |QASYMM8 |QASYMM8 |S32 |QASYMM8 |
+ * |QASYMM8 |QASYMM8_SIGNED |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |QASYMM8 |
* |QASYMM8 |QSYMM8 |S32 |QASYMM8 |
* |QASYMM8 |QASYMM8 |S32 |S32 |
* |QASYMM8 |QSYMM8_PER_CHANNEL |S32 |S32 |
* |QASYMM8 |QSYMM8 |S32 |S32 |
+ * |QASYMM8 |QASYMM8_SIGNED |F32 |F32 |
* |QASYMM8_SIGNED |QASYMM8_SIGNED |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32 |QASYMM8_SIGNED |
* |QASYMM8_SIGNED |QSYMM8 |S32 |QASYMM8_SIGNED |
diff --git a/src/cpu/operators/CpuWinogradConv2d.cpp b/src/cpu/operators/CpuWinogradConv2d.cpp
index 7d81aee0e9..7ed2f14ac5 100644
--- a/src/cpu/operators/CpuWinogradConv2d.cpp
+++ b/src/cpu/operators/CpuWinogradConv2d.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2021-2023 Arm Limited.
+ * Copyright (c) 2021-2024 Arm Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -309,7 +309,7 @@ void CpuWinogradConv2d::configure(const ITensorInfo *src,
std::max(input_workspace_size, output_workspace_size));
_aux_mem[PermutedWeights] =
MemoryInfo(offset_int_vec(PermutedWeights), MemoryLifetime::Prepare, _weights_hwio.total_size());
- _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Persistent,
+ _aux_mem[TransformedWeights] = MemoryInfo(offset_int_vec(TransformedWeights), MemoryLifetime::Prepare,
wds.weight_matrix_size_bytes, storage_alignment);
if (_data_layout == DataLayout::NCHW)
{
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index a4c856bb8f..785837dbc6 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -45,6 +45,7 @@ namespace
/** Run pretranspose_B_array in parallel (1D static scheduling)
*
* @tparam TypeInput
+ * @tparam TypeWeight
* @tparam TypeOutput
*
* @param[in] gemm_asm GemmCommon kernel to run
@@ -54,14 +55,14 @@ namespace
* @param[in] src_multi_stride Stride in z ("multi")
* @param[in] num_threads Number of threads to run this method. Must be >= 1
*/
-template <typename TypeInput, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
- ITensor *dst,
- const TypeInput *src,
- int src_ld,
- int src_multi_stride,
- unsigned int num_threads,
- bool transpose)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput> *gemm_asm,
+ ITensor *dst,
+ const TypeWeight *src,
+ int src_ld,
+ int src_multi_stride,
+ unsigned int num_threads,
+ bool transpose)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -91,14 +92,6 @@ using namespace arm_compute::experimental;
namespace
{
-struct free_delete
-{
- void operator()(void *x)
- {
- free(x);
- }
-};
-
struct Params
{
unsigned int M;
@@ -113,14 +106,13 @@ struct Params
Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- Params p;
- p.M = d->tensor_shape().y();
- p.K = a->tensor_shape().x();
- p.N = d->tensor_shape().x();
- p.batches = 1;
- p.multis = 1;
- p.sections = 1;
- p.indirect = false;
+ Params p{/* M */ static_cast<unsigned int>(d->tensor_shape().y()),
+ /* N */ static_cast<unsigned int>(d->tensor_shape().x()),
+ /* K */ static_cast<unsigned int>(a->tensor_shape().x()),
+ /* batches */ 1,
+ /* multis */ 1,
+ /* sections */ 1,
+ /* indirect */ false};
if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
@@ -172,13 +164,10 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
}
/** Fallback in case ACL doesn't have a function */
-template <typename TypeInput, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage = arm_gemm::Nothing>
class Fallback : public CpuGemmAssemblyDispatch::IFallback
{
public:
- /** Destructor */
- ~Fallback() = default;
-
/** Initialise the functions's input and output.
*
* @param[in] a Input tensor containing the Matrix A.
@@ -222,7 +211,9 @@ public:
bool isVarWeightsKernel() const override
{
if (!_gemm_kernel_asm)
+ {
return false;
+ }
const arm_compute::WeightFormat wf =
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
@@ -251,7 +242,7 @@ private:
/** Operator to transpose B before gemm or pretranspose_B_array*/
std::unique_ptr<CpuTranspose> _pre_pretranspose_b{nullptr};
/** Assembly Gemm kernel */
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeWeight, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
@@ -273,22 +264,22 @@ private:
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
/** Indirect buffer */
- std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
- std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{Count};
- bool _B_pretranspose_required{false};
- bool _is_b_constant{true};
- bool _is_c_constant{true};
- bool _run_pre_pretranspose_b{false};
- bool _B_pre_pretranspose_required{false};
+ std::vector<const TypeInput *const *> _indirect_arg{};
+ std::vector<const TypeInput *> _indirect_buf{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{Count};
+ bool _B_pretranspose_required{false};
+ bool _is_b_constant{true};
+ bool _is_c_constant{true};
+ bool _run_pre_pretranspose_b{false};
+ bool _B_pre_pretranspose_required{false};
};
-template <typename TypeInput, typename TypeOutput, class OutputStage>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
@@ -305,8 +296,8 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec
return std::make_tuple(need_left, left_shifts.data(), right_shifts.data(), _multipliers.data());
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
@@ -343,14 +334,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
_indirect_pad.data();
}
else
{
- _indirect_buf
- .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
}
}
@@ -361,11 +350,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *d,
- const AsmGemmInfo &info)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
@@ -375,13 +364,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
zeropad = a->quantization_info().uniform().offset;
}
- const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]);
- const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]);
- const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
- const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
- const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
- const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
- const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
+ const auto input_width = static_cast<int64_t>(a->tensor_shape()[1]);
+ const auto input_height = static_cast<int64_t>(a->tensor_shape()[2]);
+ const auto input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+ const auto kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
+ const auto kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
+ const auto output_width = static_cast<int64_t>(d->tensor_shape()[1]);
+ const auto output_height = static_cast<int64_t>(d->tensor_shape()[2]);
_cp = {input_width,
input_height,
@@ -392,6 +381,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
output_height,
info.ps_info.stride().first,
info.ps_info.stride().second,
+ 1,
+ 1,
info.padding_top,
info.padding_left,
zeropad};
@@ -414,10 +405,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
- _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
- reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
- _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
- reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+ _indirect_buf = std::vector<const TypeInput *>(multi_size * multis);
+ _indirect_arg = std::vector<const TypeInput *const *>(sizeof(TypeInput **) * kernel_hw * multis * batches);
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
// Set indirect argument
@@ -428,29 +417,28 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
{
for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
{
- (_indirect_arg.get())[pos++] =
- _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+ _indirect_arg[pos++] = &_indirect_buf[m * multi_stride + b * batch_stride + kernel_xy * output_hw];
}
}
}
- _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
+ _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.data());
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
- const ITensorInfo *b,
- const ITensorInfo *c,
- ITensorInfo *d,
- arm_gemm::GemmArgs args,
- const AsmGemmInfo &gemm_info,
- const OutputStage &os)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
+ const OutputStage &os)
{
_is_b_constant = b->are_values_constant();
_is_c_constant = c ? c->are_values_constant() : true;
- _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
+ _gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeWeight, TypeOutput, OutputStage>(args, os);
if (_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
@@ -460,7 +448,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
arm_gemm::GemmConfig gemm_cfg = _gemm_kernel_asm->get_config();
// arm_compute wrapper for the Gemm object (see above)
- auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeOutput>>();
+ auto acl_gemm_wrapper = std::make_unique<kernel::CpuGemmAssemblyWrapperKernel<TypeInput, TypeWeight, TypeOutput>>();
ARM_COMPUTE_ERROR_ON(acl_gemm_wrapper == nullptr);
acl_gemm_wrapper->configure(_gemm_kernel_asm.get(), gemm_cfg.filter);
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
@@ -531,8 +519,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
_pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
- _aux_mem[Pretranspose] =
- MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
+ MemoryLifetime lifetime = _is_b_constant ? MemoryLifetime::Persistent : MemoryLifetime::Temporary;
+ _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), lifetime, B_pretranspose_size, alignment);
}
// Handle indirect GEMM convolution
@@ -549,8 +537,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
if (!_is_prepared)
{
@@ -588,17 +576,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
// Fixed format kernels need no pretranspose.
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
+ const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
+ const auto in1_ptr = reinterpret_cast<const TypeWeight *>(
+ b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
@@ -616,20 +604,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
}
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+bool Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::is_configured() const
{
return _optimised_kernel != nullptr;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-experimental::MemoryRequirements Fallback<TypeInput, TypeOutput, OutputStage>::workspace() const
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+experimental::MemoryRequirements Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::workspace() const
{
return _aux_mem;
}
-template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
+template <typename TypeInput, typename TypeWeight, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeWeight, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -663,8 +651,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
- const TypeInput *in1_ptr = nullptr;
- auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
+ const TypeWeight *in1_ptr = nullptr;
+ auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
const ITensor *b_to_use = b;
@@ -686,8 +674,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
- in1_ptr =
- reinterpret_cast<const TypeInput *>(b_to_use->buffer() + b_to_use->info()->offset_first_element_in_bytes());
+ in1_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
@@ -706,8 +694,8 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
const int ldb = b_to_use->info()->strides_in_bytes().y() / b_to_use->info()->element_size();
- const auto b_ptr = reinterpret_cast<const TypeInput *>(b_to_use->buffer() +
- b_to_use->info()->offset_first_element_in_bytes());
+ const auto b_ptr = reinterpret_cast<const TypeWeight *>(b_to_use->buffer() +
+ b_to_use->info()->offset_first_element_in_bytes());
const int multi_stride_b = b_to_use->info()->strides_in_bytes().z() / b_to_use->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
@@ -720,7 +708,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
else
{
const bool kernel_supports_transpose = _gemm_kernel_asm->B_pretranspose_supports_transpose();
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(
+ run_parallel_pretranspose_B_array<TypeInput, TypeWeight, TypeOutput>(
_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b,
NEScheduler::get().num_threads(), _B_pre_pretranspose_required && kernel_supports_transpose);
}
@@ -744,7 +732,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
if (split_dim != IScheduler::split_dimensions_all)
{
// Make sure the kernel does not expect more threads than we can actually spawn
- const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
+ const unsigned int num_iterations = _optimised_kernel->window().num_iterations(split_dim);
num_threads = std::min(num_iterations, num_threads);
}
_gemm_kernel_asm->set_nthreads(num_threads);
@@ -775,7 +763,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -794,12 +782,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput>>();
fallback->configure(a, b, c, d, args, info);
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -820,7 +808,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::DequantizeFloat>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::DequantizeFloat>>();
// Configure requantization info
const GEMMLowpOutputStageInfo os_info = info.output_stage;
@@ -832,7 +820,7 @@ void create_arm_gemm_dequant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback>
arm_gemm = std::move(fallback);
}
-template <typename TypeInput, typename TypeOutput>
+template <typename TypeInput, typename TypeWeight, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
const ITensorInfo *a,
const ITensorInfo *b,
@@ -852,7 +840,7 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
// Create arm_gemm fallback
- auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
+ auto fallback = std::make_unique<Fallback<TypeInput, TypeWeight, TypeOutput, arm_gemm::Requantize32>>();
// Configure requantization info
const int32_t negation = info.negated_offsets ? 1 : -1;
@@ -905,12 +893,12 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
info.fixed_format, info.fast_mode, info.accumulate, &cfg);
- // TODO: Incorporate info.transpose_b COMPMID-6595
+ // TODO(COMPMID-6595): Incorporate info.transpose_b
switch (a->data_type())
{
case DataType::F32:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float, float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
"We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
@@ -919,13 +907,22 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
}
+ else if (b->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<uint8_t, int8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
+ "We could not find an optimized kernel for U8 input with S8 weights and U8 output");
+ }
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for U8 input and U8 output");
}
break;
@@ -934,13 +931,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::S32)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for S8 input and S8 output");
}
break;
@@ -952,13 +951,15 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
if (d->data_type() == DataType::BFLOAT16)
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, bfloat16, arm_gemm::Nothing>(arm_gemm_expected_wf,
+ args, {})),
"We could not find an optimized kernel for BFLOAT16 input and BFLOAT16 output");
}
else
{
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<bfloat16, bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for BFLOAT16 input and F32 output");
}
break;
@@ -968,7 +969,8 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
#if defined(ENABLE_FP16_KERNELS)
case DataType::F16:
ARM_COMPUTE_RETURN_ERROR_ON_MSG(
- !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ !(arm_gemm::has_opt_gemm<float16_t, float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args,
+ {})),
"We could not find an optimized kernel for F16 input and F16 output");
break;
#endif /* ENABLE_FP16_KERNELS */
@@ -1009,7 +1011,7 @@ Status CpuGemmAssemblyDispatch::validate(
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
}
- else
+ else if (!(a->data_type() == DataType::QASYMM8 && b->data_type() == DataType::QASYMM8_SIGNED))
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
@@ -1024,12 +1026,13 @@ Status CpuGemmAssemblyDispatch::validate(
"Only U32 output supported for U8 input");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
"Only S32 output supported for S8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
- (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
- "Only QASYMM8/S32 output supported for QASYMM8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ a->data_type() == DataType::QASYMM8 &&
+ (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32 && d->data_type() != DataType::F32),
+ "Only QASYMM8/S32/F32 output supported for QASYMM8 input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
+ if (bool(ret) && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
@@ -1062,33 +1065,44 @@ void CpuGemmAssemblyDispatch::configure(
switch (a->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float, float, float>(_arm_gemm, a, b, c, d, act, info);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if (d->data_type() == DataType::S32)
+ if (b->data_type() == DataType::S8 || b->data_type() == DataType::QASYMM8_SIGNED)
+ {
+ if (d->data_type() == DataType::F32)
+ {
+ create_arm_gemm_dequant<uint8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ }
+ else
+ {
+ create_arm_gemm_quant<uint8_t, int8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ }
+ }
+ else if (d->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<uint8_t, uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<uint8_t, uint8_t, uint8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if (d->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<int8_t, int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}
else if (d->data_type() == DataType::F32)
{
- create_arm_gemm_dequant<int8_t, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_dequant<int8_t, int8_t, float>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm_quant<int8_t, int8_t, int8_t>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* __aarch64__ */
@@ -1096,17 +1110,17 @@ void CpuGemmAssemblyDispatch::configure(
case DataType::BFLOAT16:
if (d->data_type() == DataType::BFLOAT16)
{
- create_arm_gemm<bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, bfloat16>(_arm_gemm, a, b, c, d, act, info);
}
else
{
- create_arm_gemm<bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<bfloat16, bfloat16, float>(_arm_gemm, a, b, c, d, act, info);
}
break;
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef ENABLE_FP16_KERNELS
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
+ create_arm_gemm<float16_t, float16_t, float16_t>(_arm_gemm, a, b, c, d, act, info);
break;
#endif /* ENABLE_FP16_KERNELS */
default: