path: root/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
Diffstat (limited to 'src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--  src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp  393
1 file changed, 236 insertions, 157 deletions
diff --git a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 3069d6b541..343ef21c0b 100644
--- a/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -24,12 +24,13 @@
#include "src/cpu/operators/internal/CpuGemmAssemblyDispatch.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+
#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/core/helpers/MemoryHelpers.h"
+#include "src/core/NEON/kernels/arm_gemm/utils.hpp"
#include "src/core/utils/AssemblyUtils.h"
-#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
#include "src/cpu/kernels/assembly/arm_gemm.hpp"
+#include "src/cpu/kernels/assembly/CpuGemmAssemblyWrapperKernel.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"
#include <arm_neon.h>
@@ -53,7 +54,12 @@ namespace
* @param[in] num_threads Number of threads to run this method. Must be >= 1
*/
template <typename TypeInput, typename TypeOutput>
-void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm, ITensor *dst, const TypeInput *src, int src_ld, int src_multi_stride, unsigned int num_threads)
+void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutput> *gemm_asm,
+ ITensor *dst,
+ const TypeInput *src,
+ int src_ld,
+ int src_multi_stride,
+ unsigned int num_threads)
{
ARM_COMPUTE_ERROR_ON(gemm_asm == nullptr);
ARM_COMPUTE_ERROR_ON(num_threads == 0);
@@ -61,14 +67,14 @@ void run_parallel_pretranspose_B_array(arm_gemm::GemmCommon<TypeInput, TypeOutpu
const unsigned int wsize = gemm_asm->get_B_pretranspose_window_size();
std::vector<IScheduler::Workload> workloads(num_threads);
- for(unsigned int t = 0; t < num_threads; ++t)
+ for (unsigned int t = 0; t < num_threads; ++t)
{
- workloads[t] = [ = ](const ThreadInfo & info)
+ workloads[t] = [=](const ThreadInfo &info)
{
const unsigned int start = (info.thread_id * wsize) / num_threads;
const unsigned int end = ((info.thread_id + 1) * wsize) / num_threads;
- if(start < end)
+ if (start < end)
{
gemm_asm->pretranspose_B_array_part(dst->buffer(), src, src_ld, src_multi_stride, start, end);
}
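
For reference, a minimal standalone sketch of the even window split used above: thread t is given the half-open range [t*wsize/num_threads, (t+1)*wsize/num_threads), so the full window [0, wsize) is covered exactly once with no overlap. The wsize and num_threads values below are hypothetical and not taken from the library.

    // Sketch of the per-thread window partitioning in run_parallel_pretranspose_B_array.
    #include <cstdio>

    int main()
    {
        const unsigned int wsize       = 13; // hypothetical pretranspose window size
        const unsigned int num_threads = 4;  // hypothetical thread count
        for (unsigned int t = 0; t < num_threads; ++t)
        {
            const unsigned int start = (t * wsize) / num_threads;
            const unsigned int end   = ((t + 1) * wsize) / num_threads;
            if (start < end) // a thread may receive an empty range when wsize < num_threads
            {
                std::printf("thread %u -> [%u, %u)\n", t, start, end);
            }
        }
        return 0;
    }
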
@@ -113,7 +119,7 @@ Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITen
p.sections = 1;
p.indirect = false;
- if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
+ if (info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
p.indirect = true;
p.sections = b->tensor_shape()[2] * b->tensor_shape()[3];
@@ -125,7 +131,7 @@ Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITen
}
// Update M in case of GEMM3D for output
- if(info.depth_output_gemm3d != 0)
+ if (info.depth_output_gemm3d != 0)
{
p.M = d->tensor_shape().y() * d->tensor_shape().z();
p.batches = d->tensor_shape().total_size_upper(3) / p.multis;
@@ -139,19 +145,24 @@ IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataTyp
// Schedule assembly kernel
const int granule_threshold = 200;
IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX);
- if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32)
+ if (method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32)
{
scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
}
- else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8))
+ else if (method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D &&
+ (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 ||
+ data_type == DataType::S8))
{
//GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions
- scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+ scheduling_hint =
+ IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
}
- else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED))
+ else if (method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D &&
+ (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED))
{
//special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case
- scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+ scheduling_hint =
+ IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
}
return scheduling_hint;
@@ -175,8 +186,12 @@ public:
* @param[in] gemm_info GEMM meta-data
* @param[in] os Output stage meta-data.
*/
- void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
- arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
+ void configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
const OutputStage &os = {});
/** Set requantization shifts to be used
@@ -193,19 +208,20 @@ public:
*
* @return A tuple with the pointers to the shift and multiplier data respectively
*/
- std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers);
+ std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
+ set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run(ITensorPack &tensors) override;
- void prepare(ITensorPack &tensors) override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
experimental::MemoryRequirements workspace() const override;
bool isVarWeightsKernel() const override
{
- if(!_gemm_kernel_asm)
+ if (!_gemm_kernel_asm)
return false;
- const arm_compute::WeightFormat wf = assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
+ const arm_compute::WeightFormat wf =
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format);
return wf != arm_compute::WeightFormat::UNSPECIFIED && wf != arm_compute::WeightFormat::ANY;
}
@@ -229,15 +245,15 @@ private:
void prepare_indirect_buffer(ITensorPack &tensors);
/** Assembly Gemm kernel */
- std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
+ std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{nullptr};
/** Optimised Arm® Neon™ kernel */
- std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
+ std::unique_ptr<INEKernel> _optimised_kernel{nullptr};
/** Assembly GEMM workspace tensor info */
TensorInfo _workspace_info{};
/** Pre-transpose tensor info */
TensorInfo _pretranspose_info{};
/** Prepared flag */
- bool _is_prepared{ false };
+ bool _is_prepared{false};
/** GEMM meta-data */
AsmGemmInfo _gemm_info{};
/** GEMM kernel description */
@@ -251,26 +267,27 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
- experimental::MemoryRequirements _aux_mem{ Count };
- bool _B_pretranspose_required{ false };
- bool _is_b_constant{ true };
- bool _is_c_constant{ true };
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+ experimental::MemoryRequirements _aux_mem{Count};
+ bool _B_pretranspose_required{false};
+ bool _is_b_constant{true};
+ bool _is_c_constant{true};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
-Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
+Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
+ const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
bool need_left = false;
- for(const auto s : _shifts)
+ for (const auto s : _shifts)
{
left_shifts.push_back(std::max(-s, int32_t(0)));
right_shifts.push_back(std::min(-s, int32_t(0)));
- if(s < 0 && !need_left)
+ if (s < 0 && !need_left)
{
need_left = true;
}
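
A small self-contained sketch of the shift splitting performed by set_requantize_data, using hypothetical per-channel shift values: a negative stored shift s contributes a left shift of -s, a non-negative one contributes a right shift, and need_left records whether any left shift occurred.

    // Sketch of splitting signed requantization shifts into left/right shift vectors.
    #include <algorithm>
    #include <cstdint>
    #include <vector>

    int main()
    {
        const std::vector<int32_t> shifts = {-2, 0, 3}; // hypothetical per-channel shifts
        std::vector<int32_t>       left_shifts, right_shifts;
        bool                       need_left = false;
        for (const auto s : shifts)
        {
            left_shifts.push_back(std::max(-s, int32_t(0)));  // used only when s < 0
            right_shifts.push_back(std::min(-s, int32_t(0))); // used only when s >= 0
            if (s < 0)
            {
                need_left = true;
            }
        }
        return need_left ? 0 : 1;
    }
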
@@ -295,32 +312,35 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInput);
- for(int64_t m = 0; m < multis; m++)
+ for (int64_t m = 0; m < multis; m++)
{
- for(int64_t b = 0; b < batches; b++)
+ for (int64_t b = 0; b < batches; b++)
{
- for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
+ for (int64_t output_y = 0; output_y < _cp.output_height; output_y++)
{
- for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
+ for (int64_t output_x = 0; output_x < _cp.output_width; output_x++)
{
int64_t output_xy = (output_y * _cp.output_width) + output_x;
- for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
+ for (int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
{
- for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
+ for (int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
{
int64_t input_x = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
int64_t input_y = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
int64_t input_xy = (input_y * _cp.input_width) + input_x;
- if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
+ if (input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
{
- _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
+ _indirect_buf
+ .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_pad.data();
}
else
{
- _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ _indirect_buf
+ .get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
}
}
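
The loop nest above implements an im2col-style index mapping. As a rough illustration, with made-up convolution parameters, each output/kernel coordinate pair is mapped back to an input pixel, and out-of-range positions are redirected to a shared zero-padding row rather than real input data.

    // Sketch of the coordinate mapping used when filling the indirect buffer.
    #include <cstdio>

    int main()
    {
        const int stride_w = 2, stride_h = 2, pad_left = 1, pad_top = 1; // hypothetical
        const int input_w = 5, input_h = 5;                              // hypothetical

        const int output_x = 0, output_y = 0, kernel_x = 0, kernel_y = 0;
        const int input_x = output_x * stride_w + kernel_x - pad_left;
        const int input_y = output_y * stride_h + kernel_y - pad_top;

        const bool use_pad = (input_x < 0 || input_x >= input_w || input_y < 0 || input_y >= input_h);
        std::printf("input (%d, %d) -> %s\n", input_x, input_y, use_pad ? "padding row" : "real pixel");
        return 0;
    }
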
@@ -332,12 +352,15 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITens
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
float zeropad = 0.f;
- if(is_data_type_quantized(a->data_type()))
+ if (is_data_type_quantized(a->data_type()))
{
zeropad = a->quantization_info().uniform().offset;
}
@@ -350,16 +373,25 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
- _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
- info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
- };
-
- if(info.method == AsmConvMethod::Conv)
+ _cp = {input_width,
+ input_height,
+ input_channels,
+ kernel_width,
+ kernel_height,
+ output_width,
+ output_height,
+ info.ps_info.stride().first,
+ info.ps_info.stride().second,
+ info.padding_top,
+ info.padding_left,
+ zeropad};
+
+ if (info.method == AsmConvMethod::Conv)
{
_gemm_kernel_asm->set_convolution_parameters(_cp);
}
- if(info.method == AsmConvMethod::Indirect)
+ if (info.method == AsmConvMethod::Indirect)
{
const unsigned int multis = 1;
const unsigned int batches = a->tensor_shape().total_size_upper(3);
@@ -372,19 +404,22 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
const int multi_size = batch_size * batches;
const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
- _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
- _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+ _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(
+ reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
+ _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(
+ reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
_indirect_pad = std::vector<TypeInput>(_cp.input_channels, TypeInput(zeropad));
// Set indirect argument
int64_t pos = 0;
- for(int64_t m = 0; m < multis; m++)
+ for (int64_t m = 0; m < multis; m++)
{
- for(int64_t b = 0; b < batches; b++)
+ for (int64_t b = 0; b < batches; b++)
{
- for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
+ for (int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
{
- (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+ (_indirect_arg.get())[pos++] =
+ _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
}
}
}
@@ -394,8 +429,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
- arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::GemmArgs args,
+ const AsmGemmInfo &gemm_info,
const OutputStage &os)
{
ARM_COMPUTE_UNUSED(c);
@@ -404,7 +443,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_is_c_constant = c ? c->are_values_constant() : true;
_gemm_kernel_asm = arm_gemm::gemm<TypeInput, TypeOutput, OutputStage>(args, os);
- if(_gemm_kernel_asm == nullptr)
+ if (_gemm_kernel_asm == nullptr)
{
//configuration not supported: Leave function unconfigured:
return;
@@ -419,13 +458,14 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
const size_t workspace_size = _gemm_kernel_asm->get_working_size();
const unsigned int alignment = 4096;
_workspace_info = TensorInfo(TensorShape(workspace_size), 1, DataType::U8);
- _aux_mem[AsmGemmWorkspace] = MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);
+ _aux_mem[AsmGemmWorkspace] =
+ MemoryInfo(offset_int_vec(AsmGemmWorkspace), MemoryLifetime::Temporary, workspace_size, alignment);
//if we disable this code below in brackets then ConvLayer deadlocks when threads > 1 and
//the shapes are In=1x1x1024 Weights=1x1x1024x1001 Biases=1001 Out=1x1x1001
{
const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
- if(window_size < static_cast<unsigned int>(args._maxthreads))
+ if (window_size < static_cast<unsigned int>(args._maxthreads))
{
_gemm_kernel_asm->set_nthreads(window_size);
}
@@ -434,18 +474,19 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
_optimised_kernel = std::move(acl_gemm_wrapper);
_gemm_info = gemm_info;
// Check for pre-transposed support
- if(_gemm_kernel_asm->B_pretranspose_required())
+ if (_gemm_kernel_asm->B_pretranspose_required())
{
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
_pretranspose_info = TensorInfo(TensorShape(B_pretranspose_size), 1, DataType::U8);
- _aux_mem[Pretranspose] = MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
- _B_pretranspose_required = true;
+ _aux_mem[Pretranspose] =
+ MemoryInfo(offset_int_vec(Pretranspose), MemoryLifetime::Persistent, B_pretranspose_size, alignment);
+ _B_pretranspose_required = true;
}
// Handle indirect GEMM convolution
- if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
+ if (gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
{
configure_indirect(a, b, d, gemm_info);
}
@@ -454,34 +495,39 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
- if(!_is_prepared)
+ if (!_is_prepared)
{
auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
- if(c && c->info()->data_type() == DataType::S32)
+ if (c && c->info()->data_type() == DataType::S32)
{
- _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
+ _gemm_kernel_asm->set_quantized_bias(
+ reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
// Pretranspose B if required
- if(_gemm_kernel_asm->B_pretranspose_required())
+ if (_gemm_kernel_asm->B_pretranspose_required())
{
// Fixed format kernels need no pretranspose.
- ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ ARM_COMPUTE_ERROR_ON(arm_compute::is_fixed_format(
+ assembly_utils::map_to_arm_compute_weight_format(_gemm_kernel_asm->get_config().weight_format)));
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ const auto in1_ptr =
+ reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, false);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), in1_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
+ in1_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads());
b->mark_as_unused();
}
- if(_gemm_info.method == AsmConvMethod::Indirect)
+ if (_gemm_info.method == AsmConvMethod::Indirect)
{
prepare_indirect_buffer(tensors);
}
@@ -526,12 +572,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
int multi_stride_b = 0;
const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / d->info()->element_size();
- auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
+ auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
// Check if B is pre-tranposed and de-reference if not
- if(!_gemm_kernel_asm->B_is_pretransposed())
+ if (!_gemm_kernel_asm->B_is_pretransposed())
{
ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
@@ -539,30 +585,34 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
// If necessary, run pretranspose every time if either weights or biases are non-constant
- if((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
+ if ((b && !_is_b_constant) || (c && !_is_c_constant && c->info()->data_type() == DataType::S32))
{
- if(c && c->info()->data_type() == DataType::S32)
+ if (c && c->info()->data_type() == DataType::S32)
{
- _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
+ _gemm_kernel_asm->set_quantized_bias(
+ reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
// Pretranspose B if required
- if(_B_pretranspose_required)
+ if (_B_pretranspose_required)
{
- const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
- const auto b_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
+ const int ldb = b->info()->strides_in_bytes().y() / b->info()->element_size();
+ const auto b_ptr =
+ reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / b->info()->element_size();
CpuAuxTensorHandler pretranspose(offset_int_vec(Pretranspose), _pretranspose_info, tensors, true);
ARM_COMPUTE_ERROR_ON(pretranspose.get()->buffer() == nullptr);
- if(_is_b_constant)
+ if (_is_b_constant)
{
_gemm_kernel_asm->requantize_bias(pretranspose.get()->buffer(), b_ptr, ldb, multi_stride_b);
}
else
{
- run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(), b_ptr, ldb, multi_stride_b, NEScheduler::get().num_threads());
+ run_parallel_pretranspose_B_array<TypeInput, TypeOutput>(_gemm_kernel_asm.get(), pretranspose.get(),
+ b_ptr, ldb, multi_stride_b,
+ NEScheduler::get().num_threads());
}
}
}
@@ -571,17 +621,17 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
CpuAuxTensorHandler workspace(offset_int_vec(AsmGemmWorkspace), _workspace_info, tensors, false);
- if(workspace.get()->buffer() != nullptr)
+ if (workspace.get()->buffer() != nullptr)
{
_gemm_kernel_asm->set_working_space(reinterpret_cast<void *>(workspace.get()->buffer()));
const unsigned int split_dim = scheduling_hint.split_dimension();
const unsigned int window_size = _gemm_kernel_asm->get_window_size().total_size();
unsigned int num_threads = NEScheduler::get().num_threads();
- if(window_size < num_threads)
+ if (window_size < num_threads)
{
num_threads = window_size;
}
- if(split_dim != IScheduler::split_dimensions_all)
+ if (split_dim != IScheduler::split_dimensions_all)
{
// Make sure the kernel does not expect more threads than we can actually spawn
const unsigned int num_iterations = _optimised_kernel.get()->window().num_iterations(split_dim);
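
A simplified sketch of this clamping, with hypothetical values: the thread count handed to the kernel never exceeds the kernel window size and, when a single split dimension is used, presumably also not the number of iterations along that dimension (inferred from the comment above; the clamp itself falls outside this hunk).

    // Sketch of the thread-count clamping in run().
    #include <algorithm>
    #include <cstdio>

    int main()
    {
        unsigned int       num_threads    = 8; // hypothetical scheduler thread count
        const unsigned int window_size    = 6; // hypothetical total kernel window size
        const unsigned int num_iterations = 4; // hypothetical iterations along the split dimension

        num_threads = std::min(num_threads, window_size);
        num_threads = std::min(num_threads, num_iterations); // assumed clamp for a single split dimension

        std::printf("threads used: %u\n", num_threads);
        return 0;
    }
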
@@ -595,12 +645,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
TypeOutput *bias = nullptr;
- if(c && c->info()->data_type() != DataType::S32)
+ if (c && c->info()->data_type() != DataType::S32)
{
bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
}
- if(_gemm_info.method == AsmConvMethod::Indirect)
+ if (_gemm_info.method == AsmConvMethod::Indirect)
{
in0_ptr = nullptr;
lda = 0;
@@ -609,18 +659,20 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
}
// Set gemm parameters
- _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
- in1_ptr, ldb, multi_stride_b,
- out_ptr, ldd, batch_stride_d, multi_stride_d,
- bias, 0);
+ _gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a, in1_ptr, ldb, multi_stride_b, out_ptr,
+ ldd, batch_stride_d, multi_stride_d, bias, 0);
// Schedule
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
- const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
- arm_gemm::Activation activation, const AsmGemmInfo &info)
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::Activation activation,
+ const AsmGemmInfo &info)
{
Params p = extract_parameters(a, b, d, info);
const CPUInfo &ci = NEScheduler::get().cpu_info();
@@ -628,7 +680,8 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
+ info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput>>();
@@ -638,8 +691,12 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm,
- const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
- arm_gemm::Activation activation, const AsmGemmInfo &info)
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ ITensorInfo *d,
+ arm_gemm::Activation activation,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_UNUSED(activation);
Params p = extract_parameters(a, b, d, info);
@@ -648,7 +705,8 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads, info.fixed_format, info.fast_mode, &cfg);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads,
+ info.fixed_format, info.fast_mode, &cfg);
// Create arm_gemm fallback
auto fallback = std::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
@@ -660,22 +718,20 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
const GEMMLowpOutputStageInfo os_info = info.output_stage;
arm_gemm::Requantize32 gemm_requant_info{};
- if(os_info.gemmlowp_shifts.size() > 1)
+ if (os_info.gemmlowp_shifts.size() > 1)
{
- const auto requantize_data = fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
- gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
- a_offset, b_offset, os_info.gemmlowp_offset,
- (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr,
- std::get<2>(requantize_data),
- std::get<3>(requantize_data),
- os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ const auto requantize_data =
+ fallback->set_requantize_data(os_info.gemmlowp_shifts, os_info.gemmlowp_multipliers);
+ gemm_requant_info = arm_gemm::Requantize32(
+ nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset,
+ (std::get<0>(requantize_data)) ? std::get<1>(requantize_data) : nullptr, std::get<2>(requantize_data),
+ std::get<3>(requantize_data), os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
}
else
{
- gemm_requant_info = arm_gemm::Requantize32(nullptr, 0,
- a_offset, b_offset, os_info.gemmlowp_offset,
- -os_info.gemmlowp_shift, os_info.gemmlowp_multiplier,
- os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
+ gemm_requant_info =
+ arm_gemm::Requantize32(nullptr, 0, a_offset, b_offset, os_info.gemmlowp_offset, -os_info.gemmlowp_shift,
+ os_info.gemmlowp_multiplier, os_info.gemmlowp_min_bound, os_info.gemmlowp_max_bound);
}
// Configure fallback
@@ -684,13 +740,16 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
}
} //namespace
-CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch()
- : _arm_gemm(nullptr)
+CpuGemmAssemblyDispatch::CpuGemmAssemblyDispatch() : _arm_gemm(nullptr)
{
}
-Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
- const AsmGemmInfo &info)
+Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected_weight_format,
+ const ITensorInfo *a,
+ const ITensorInfo *b,
+ const ITensorInfo *c,
+ const ITensorInfo *d,
+ const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_UNUSED(c);
@@ -701,53 +760,61 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
arm_gemm::GemmConfig cfg;
cfg.weight_format = assembly_utils::map_to_arm_gemm_weight_format(info.weight_format);
arm_gemm::WeightFormat arm_gemm_expected_wf = assembly_utils::map_to_arm_gemm_weight_format(expected_weight_format);
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads, info.fixed_format, info.fast_mode, &cfg);
- switch(a->data_type())
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, act, num_threads,
+ info.fixed_format, info.fast_mode, &cfg);
+ switch (a->data_type())
{
case DataType::F32:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for F32 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<float, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for F32 input");
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if(d->data_type() == DataType::S32)
+ if (d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<uint8_t, uint32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for U8/QASYMM8 input and U32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for U8 input and U8 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<uint8_t, uint8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for U8 input and U8 output");
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
- if(d->data_type() == DataType::S32)
+ if (d->data_type() == DataType::S32)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<int8_t, int32_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for S8/QASYMM8_SIGNED input and S32 output");
}
else
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for S8 input and S8 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<int8_t, int8_t, arm_gemm::Requantize32>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for S8 input and S8 output");
}
break;
#endif /* __aarch64__ */
#if defined(ARM_COMPUTE_ENABLE_BF16)
case DataType::BFLOAT16:
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for BFLOAT16 input and F32 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<bfloat16, float, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for BFLOAT16 input and F32 output");
break;
}
#endif /* defined(ARM_COMPUTE_ENABLE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
- "We could not find an optimized kernel for F16 input and F16 output");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(arm_gemm::has_opt_gemm<float16_t, float16_t, arm_gemm::Nothing>(arm_gemm_expected_wf, args, {})),
+ "We could not find an optimized kernel for F16 input and F16 output");
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default:
@@ -759,26 +826,30 @@ Status CpuGemmAssemblyDispatch::has_opt_impl(arm_compute::WeightFormat &expected
return Status{};
}
-Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
+Status CpuGemmAssemblyDispatch::validate(
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_UNUSED(c, info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run), "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(!(info.reshape_b_only_on_first_run),
+ "Assembly kernel will not be executed when reshape_b_only_on_first_run is false");
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8,
- DataType::BFLOAT16, DataType::F16, DataType::F32);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8,
- DataType::BFLOAT16, DataType::F16, DataType::F32);
- if(is_data_type_quantized_per_channel(b->data_type()))
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::U8, DataType::QASYMM8,
+ DataType::QASYMM8_SIGNED, DataType::S8, DataType::BFLOAT16,
+ DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(
+ b, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8_PER_CHANNEL, DataType::S8,
+ DataType::BFLOAT16, DataType::F16, DataType::F32);
+ if (is_data_type_quantized_per_channel(b->data_type()))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8_SIGNED, DataType::S8);
}
- else if(is_fixed_format_fast_math(info.weight_format))
+ else if (is_fixed_format_fast_math(info.weight_format))
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
@@ -787,22 +858,29 @@ Status CpuGemmAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32, "Only F32 output supported for F32 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16, "Only F16 output supported for F16 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32, "Only F32 output supported for BFLOAT16 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32, "Only U32 output supported for U8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32, "Only S32 output supported for S8 input");
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 && (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F32 && d->data_type() != DataType::F32,
+ "Only F32 output supported for F32 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::F16 && d->data_type() != DataType::F16,
+ "Only F16 output supported for F16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::BFLOAT16 && d->data_type() != DataType::F32,
+ "Only F32 output supported for BFLOAT16 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::U8 && d->data_type() != DataType::U32,
+ "Only U32 output supported for U8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::S8 && d->data_type() != DataType::S32,
+ "Only S32 output supported for S8 input");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() == DataType::QASYMM8 &&
+ (d->data_type() != DataType::QASYMM8 && d->data_type() != DataType::S32),
"Only QASYMM8/S32 output supported for QASYMM8 input");
arm_compute::WeightFormat expected_weight_format = arm_compute::WeightFormat::UNSPECIFIED;
const Status ret = CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, info);
- if((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
+ if ((bool)ret && expected_weight_format != arm_compute::WeightFormat::ANY)
{
// Correctness check: if the format expected by the kernel is
// not "any", make sure that the one found matches the format
// intended by the caller.
- ARM_COMPUTE_RETURN_ERROR_ON_MSG((expected_weight_format != info.weight_format),
- "The format expected by the kernel does not correspond with the one requested by the user.");
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ (expected_weight_format != info.weight_format),
+ "The format expected by the kernel does not correspond with the one requested by the user.");
}
return ret;
}
@@ -813,18 +891,19 @@ bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo
return act.type != arm_gemm::Activation::Type::None;
}
-void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
+void CpuGemmAssemblyDispatch::configure(
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
arm_gemm::Activation act = assembly_utils::map_to_arm_gemm_activation(info.activation_info);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
+ if (!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
{
return;
}
- switch(a->data_type())
+ switch (a->data_type())
{
case DataType::F32:
create_arm_gemm<float, float>(_arm_gemm, a, b, c, d, act, info);
@@ -832,7 +911,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if(d->data_type() == DataType::S32)
+ if (d->data_type() == DataType::S32)
{
create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, a, b, c, d, act, info);
}
@@ -843,7 +922,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
- if(d->data_type() == DataType::S32)
+ if (d->data_type() == DataType::S32)
{
create_arm_gemm<int8_t, int32_t>(_arm_gemm, a, b, c, d, act, info);
}