aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp328
1 files changed, 267 insertions, 61 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
index 5b0848398d..400fa64438 100644
--- a/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
+++ b/src/runtime/NEON/functions/NEGEMMAssemblyDispatch.cpp
@@ -25,18 +25,70 @@
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/core/CPP/Validate.h"
-#include "src/core/NEON/kernels/assembly/INEGEMMWrapperKernel.h"
#include "src/core/NEON/kernels/assembly/NEGEMMAssemblyWrapperKernel.h"
#include "src/core/NEON/kernels/assembly/arm_gemm.hpp"
#include "support/MemorySupport.h"
#include <arm_neon.h>
+#include <cstdlib>
namespace arm_compute
{
namespace
{
+struct free_delete
+{
+ void operator()(void *x)
+ {
+ free(x);
+ }
+};
+
+struct Params
+{
+ unsigned int M;
+ unsigned int N;
+ unsigned int K;
+ unsigned int batches;
+ unsigned int multis;
+ unsigned int sections;
+ bool indirect;
+};
+
+Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d, const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
+
+ Params p;
+ p.K = a->info()->tensor_shape().x();
+ p.N = d->info()->tensor_shape().x();
+ p.multis = 1;
+ p.indirect = false;
+ p.sections = 1;
+
+ if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
+ {
+ p.indirect = true;
+ p.sections = b->info()->tensor_shape()[2] * b->info()->tensor_shape()[3];
+ }
+ else
+ {
+ p.M = d->info()->tensor_shape().y();
+ p.multis = b->info()->tensor_shape().z();
+ p.batches = d->info()->tensor_shape().total_size_upper(2) / p.multis; //COMPMID-1423: Agree on and document the layout of gemm inputs/outputs
+ }
+
+ // Update M in case of GEMM3D for output
+ if(info.depth_output_gemm3d != 0)
+ {
+ p.M = d->info()->tensor_shape().y() * d->info()->tensor_shape().z();
+ p.batches = d->info()->tensor_shape().total_size_upper(3) / p.multis;
+ }
+
+ return p;
+}
+
arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
{
arm_gemm::Activation gemm_act;
@@ -69,6 +121,29 @@ arm_gemm::Activation map_to_arm_gemm_activation(const ActivationLayerInfo &act)
return gemm_act;
}
+IScheduler::Hints scheduling_hint_heuristic(arm_gemm::GemmMethod method, DataType data_type)
+{
+ // Schedule assembly kernel
+ const int granule_threshold = 200;
+ IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX);
+ if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && data_type == DataType::F32)
+ {
+ scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
+ }
+ else if(method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (data_type == DataType::F32 || data_type == DataType::F16 || data_type == DataType::U8 || data_type == DataType::S8))
+ {
+ //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions
+ scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+ }
+ else if(method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (data_type == DataType::QASYMM8 || data_type == DataType::QASYMM8_SIGNED))
+ {
+ //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case
+ scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
+ }
+
+ return scheduling_hint;
+}
+
template <typename TypeInput, typename TypeOutput>
class FallbackTransform : public ITransformWeights
{
@@ -165,7 +240,7 @@ public:
* @param[in] os Output stage meta-data.
*/
void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
/** Set requantization shifts to be used
@@ -198,6 +273,16 @@ private:
* @param[in] alignment Workspace memory alignment.
*/
void allocate_workspace(size_t workspace_size, MemoryGroup &memory_group, size_t alignment);
+ /** Configure the indirect buffer
+ *
+ * @param[in] a Input tensor containing the Matrix A.
+ * @param[in] b Input tensor containing the Matrix B.
+ * @param[out] d Output tensor to store the result of matrix multiplication.
+ * @param[in] info GEMM meta-data
+ */
+ void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
+ /** Prepare the indirect buffer */
+ void prepare_indirect_buffer();
/** Assembly Gemm kernel */
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
@@ -226,7 +311,7 @@ private:
/** Prepared flag */
bool _is_prepared{ false };
/** GEMM meta-data */
- GEMMInfo _gemm_info{};
+ AsmGemmInfo _gemm_info{};
/** Weights manager */
IWeightsManager *_weights_manager{ nullptr };
/** Weights transform object */
@@ -239,11 +324,16 @@ private:
std::vector<int32_t> left_shifts{};
/** Per channel quantization multipliers */
std::vector<int32_t> _multipliers{};
+ /** Indirect buffer */
+ std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
+ std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
-std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts,
- const std::vector<int32_t> &multipliers)
+std::tuple<bool, const int32_t *, const int32_t *, const int32_t *>
+Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vector<int32_t> &shifts, const std::vector<int32_t> &multipliers)
{
_multipliers = multipliers;
_shifts = shifts;
@@ -261,8 +351,122 @@ std::tuple<bool, const int32_t *, const int32_t *, const int32_t *> Fallback<Typ
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer()
+{
+ const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(_a->buffer());
+ const int multis = 1;
+ const int batches = _a->info()->tensor_shape().total_size_upper(3);
+ const size_t stride_A = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const size_t batch_stride_A = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const size_t multi_stride_A = _a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
+
+ const size_t output_hw = _cp.output_height * _cp.output_width;
+ const int batch_size = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
+ const size_t batch_stride = batch_size / sizeof(TypeInput);
+ const int multi_size = batch_size * batches;
+ const size_t multi_stride = multi_size / sizeof(TypeInput);
+
+ for(int64_t m = 0; m < multis; m++)
+ {
+ for(int64_t b = 0; b < batches; b++)
+ {
+ for(int64_t output_y = 0; output_y < _cp.output_height; output_y++)
+ {
+ for(int64_t output_x = 0; output_x < _cp.output_width; output_x++)
+ {
+ int64_t output_xy = (output_y * _cp.output_width) + output_x;
+
+ for(int64_t kernel_y = 0; kernel_y < _cp.kernel_height; kernel_y++)
+ {
+ for(int64_t kernel_x = 0; kernel_x < _cp.kernel_width; kernel_x++)
+ {
+ int64_t input_x = (output_x * _cp.output_stride_w) + kernel_x - _cp.padding_left;
+ int64_t input_y = (output_y * _cp.output_stride_h) + kernel_y - _cp.padding_top;
+ int64_t kernel_xy = (kernel_y * _cp.kernel_width) + kernel_x;
+ int64_t input_xy = (input_y * _cp.input_width) + input_x;
+
+ if(input_x < 0 || input_x >= _cp.input_width || input_y < 0 || input_y >= _cp.input_height)
+ {
+ _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] = _indirect_pad.data();
+ }
+ else
+ {
+ _indirect_buf.get()[m * multi_stride + b * batch_stride + kernel_xy * output_hw + output_xy] =
+ A_ptr + (m * multi_stride_A + b * batch_stride_A + input_xy * stride_A);
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
+{
+ ARM_COMPUTE_ERROR_ON(!(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect));
+
+ float zeropad = 0.f;
+ if(is_data_type_quantized(a->data_type()))
+ {
+ zeropad = a->quantization_info().uniform().offset;
+ }
+
+ const int64_t input_width = static_cast<int64_t>(a->tensor_shape()[1]);
+ const int64_t input_height = static_cast<int64_t>(a->tensor_shape()[2]);
+ const int64_t input_channels = static_cast<int64_t>(a->tensor_shape()[0]);
+ const int64_t kernel_width = static_cast<int64_t>(b->tensor_shape()[2]);
+ const int64_t kernel_height = static_cast<int64_t>(b->tensor_shape()[3]);
+ const int64_t output_width = static_cast<int64_t>(d->tensor_shape()[1]);
+ const int64_t output_height = static_cast<int64_t>(d->tensor_shape()[2]);
+
+ _cp = { input_width, input_height, input_channels, kernel_width, kernel_height, output_width, output_height,
+ info.ps_info.stride().first, info.ps_info.stride().second, info.padding_top, info.padding_left, zeropad
+ };
+
+ if(info.method == AsmConvMethod::Conv)
+ {
+ _gemm_kernel_asm->set_convolution_parameters(_cp);
+ }
+
+ if(info.method == AsmConvMethod::Indirect)
+ {
+ const unsigned int multis = 1;
+ const unsigned int batches = a->tensor_shape().total_size_upper(3);
+ const unsigned int kernel_hw = _cp.kernel_width * _cp.kernel_height;
+ const unsigned int output_hw = _cp.output_width * _cp.output_height;
+
+ using TypeInputPtr = TypeInput *;
+ const int batch_size = kernel_hw * output_hw * sizeof(TypeInputPtr);
+ const size_t batch_stride = batch_size / sizeof(TypeInputPtr);
+ const int multi_size = batch_size * batches;
+ const size_t multi_stride = multi_size / sizeof(TypeInputPtr);
+
+ _indirect_buf = std::unique_ptr<const TypeInput *, free_delete>(reinterpret_cast<const TypeInput **>(malloc(multi_size * multis)));
+ _indirect_arg = std::unique_ptr<const TypeInput *const *, free_delete>(reinterpret_cast<const TypeInput *const **>(malloc(sizeof(TypeInput **) * kernel_hw * multis * batches)));
+ _indirect_pad = std::vector<TypeInput>(_cp.input_channels, zeropad);
+
+ // Set indirect argument
+ int64_t pos = 0;
+ for(int64_t m = 0; m < multis; m++)
+ {
+ for(int64_t b = 0; b < batches; b++)
+ {
+ for(int64_t kernel_xy = 0; kernel_xy < kernel_hw; kernel_xy++)
+ {
+ (_indirect_arg.get())[pos++] = _indirect_buf.get() + m * multi_stride + b * batch_stride + kernel_xy * output_hw;
+ }
+ }
+ }
+
+ _gemm_kernel_asm->set_indirect_parameters(a->tensor_shape()[0], _indirect_arg.get());
+ }
+}
+
+template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
- arm_gemm::GemmArgs args, const GEMMInfo &gemm_info,
+ arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
arm_gemm::GemmConfig gemm_cfg;
@@ -325,6 +529,12 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
static_cast<Tensor *>(_pretranspose)->allocator()->init(TensorInfo(TensorShape{ (B_pretranspose_size + alignment /* FIXME: remove alignment after COMPMID-1088 */) }, 1, DataType::S8), alignment);
}
}
+
+ // Handle indirect GEMM convolution
+ if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
+ {
+ configure_indirect(a->info(), b->info(), d->info(), gemm_info);
+ }
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -365,6 +575,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
}
}
+ if(_gemm_info.method == AsmConvMethod::Indirect)
+ {
+ prepare_indirect_buffer();
+ }
+
_is_prepared = true;
}
}
@@ -387,23 +602,23 @@ bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
template <typename TypeInput, typename TypeOutput, class OutputStage>
void Fallback<TypeInput, TypeOutput, OutputStage>::run()
{
- const int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
int ldb = 0;
const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
- const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d() != 0 ? 3 : 2;
+ const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
const size_t a_multi_idx = a_batch_idx + 1;
- const size_t d_batch_idx = _gemm_info.depth_output_gemm3d() != 0 ? 3 : 2;
+ const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
const size_t d_multi_idx = d_batch_idx + 1;
- const int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+ int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
- const int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+ int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
int multi_stride_b = 0;
const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
- const auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+ auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
@@ -415,25 +630,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
}
- IScheduler::Hints scheduling_hint = IScheduler::Hints(Window::DimX);
- if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED && _d->info()->data_type() == DataType::F32)
- {
- const int granule_threshold = 200;
- scheduling_hint = IScheduler::Hints(Window::DimX, IScheduler::StrategyHint::DYNAMIC, granule_threshold);
- }
- else if(_kernel_info.method == arm_gemm::GemmMethod::GEMM_INTERLEAVED_2D && (_d->info()->data_type() == DataType::F32 || _d->info()->data_type() == DataType::F16
- || _d->info()->data_type() == DataType::U8 || _d->info()->data_type() == DataType::S8))
- {
- //GEMM_INTERLEAVED supports 2D parallelism, IScheduler::split_dimensions_all signals to parallelise over all window dimensions
- const int granule_threshold = 200;
- scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
- }
- else if(_kernel_info.method == arm_gemm::GemmMethod::QUANTIZE_WRAPPER_2D && (_d->info()->data_type() == DataType::QASYMM8 || _d->info()->data_type() == DataType::QASYMM8_SIGNED))
- {
- //special case for QASYMM8 to support 2D parallelism, scheduler here may be tweaked differently compared to FP32 case
- const int granule_threshold = 200;
- scheduling_hint = IScheduler::Hints(IScheduler::split_dimensions_all, IScheduler::StrategyHint::STATIC, granule_threshold);
- }
+ const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, _d->info()->data_type());
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
if(_workspace.buffer() != nullptr)
@@ -458,57 +655,67 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
// Prepare assembly kernel
prepare();
- TypeOutput *bias = nullptr;
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
+ TypeOutput *bias = nullptr;
if(_c && _c->info()->data_type() != DataType::S32)
{
bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
}
+
+ if(_gemm_info.method == AsmConvMethod::Indirect)
+ {
+ in0_ptr = nullptr;
+ lda = 0;
+ batch_stride_a = 0;
+ multi_stride_a = 0;
+ }
+
// Set gemm parameters
_gemm_kernel_asm->set_arrays(in0_ptr, lda, batch_stride_a, multi_stride_a,
in1_ptr, ldb, multi_stride_b,
out_ptr, ldd, batch_stride_d, multi_stride_d,
bias, 0);
- // Schedule assembly kernel
+ // Schedule
NEScheduler::get().schedule(_optimised_kernel.get(), scheduling_hint);
}
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
- INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
- const CPUInfo &ci = NEScheduler::get().cpu_info();
- unsigned int num_threads = NEScheduler::get().num_threads();
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads);
// Create arm_gemm fallback
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput>>();
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager);
+ fallback->configure(a, b, c, d, args, info, memory_group, weights_manager);
arm_gemm = std::move(fallback);
}
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const GEMMInfo &gemm_info,
+ const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
ARM_COMPUTE_UNUSED(activation);
- INEGEMMWrapperKernel::Params p = INEGEMMWrapperKernel::extract_parameters(a, b, d, gemm_info);
- const CPUInfo &ci = NEScheduler::get().cpu_info();
- unsigned int num_threads = NEScheduler::get().num_threads();
+ Params p = extract_parameters(a, b, d, info);
+ const CPUInfo &ci = NEScheduler::get().cpu_info();
+ unsigned int num_threads = NEScheduler::get().num_threads();
- arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.batches, p.multis, activation, num_threads);
+ arm_gemm::GemmArgs args(&ci, p.M, p.N, p.K, p.sections, p.batches, p.multis, p.indirect, activation, num_threads);
// Create arm_gemm fallback
auto fallback = support::cpp14::make_unique<Fallback<TypeInput, TypeOutput, arm_gemm::Requantize32>>();
// Configure requantization info
- const int32_t a_offset = -a->info()->quantization_info().uniform().offset;
- const int32_t b_offset = -b->info()->quantization_info().uniform().offset;
- const GEMMLowpOutputStageInfo os_info = gemm_info.gemmlowp_output_stage();
+ const int32_t negation = info.negated_offsets ? 1 : -1;
+ const int32_t a_offset = -a->info()->quantization_info().uniform().offset * negation;
+ const int32_t b_offset = -b->info()->quantization_info().uniform().offset * negation;
+ const GEMMLowpOutputStageInfo os_info = info.output_stage;
arm_gemm::Requantize32 gemm_requant_info{};
if(os_info.gemmlowp_shifts.size() > 1)
@@ -530,7 +737,7 @@ void create_arm_gemm_quant(std::unique_ptr<NEGEMMAssemblyDispatch::IFallback> &a
}
// Configure fallback
- fallback->configure(a, b, c, d, args, gemm_info, memory_group, weights_manager, gemm_requant_info);
+ fallback->configure(a, b, c, d, args, info, memory_group, weights_manager, gemm_requant_info);
arm_gemm = std::move(fallback);
}
@@ -541,14 +748,13 @@ NEGEMMAssemblyDispatch::NEGEMMAssemblyDispatch(std::shared_ptr<IMemoryManager> m
{
}
-Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const GEMMInfo &gemm_info)
+Status NEGEMMAssemblyDispatch::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, const AsmGemmInfo &info)
{
- ARM_COMPUTE_UNUSED(c);
+ ARM_COMPUTE_UNUSED(c, info);
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(a, b, d);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
- ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.pretranpose_B());
#ifndef __aarch64__
ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->element_size() == 1, "8bit integer types only supported for aarch64");
#endif /* __aarch64__ */
@@ -579,13 +785,13 @@ bool NEGEMMAssemblyDispatch::is_activation_supported(const ActivationLayerInfo &
return act.type != arm_gemm::Activation::Type::None;
}
-void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const GEMMInfo &gemm_info)
+void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
- arm_gemm::Activation act = map_to_arm_gemm_activation(gemm_info.activation_info());
+ arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), gemm_info))
+ if(!NEGEMMAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info))
{
return;
}
@@ -593,40 +799,40 @@ void NEGEMMAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const
switch(a->info()->data_type())
{
case DataType::F32:
- create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
break;
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
if(d->info()->data_type() == DataType::S32)
{
- create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
else
{
- create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm_quant<uint8_t, uint8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
if(d->info()->data_type() == DataType::S32)
{
- create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
else
{
- create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm_quant<int8_t, int8_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
break;
#endif /* __aarch64__ */
#if defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16)
case DataType::BFLOAT16:
- create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm<bfloat16, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
break;
#endif /* defined(__ARM_FEATURE_BF16_VECTOR_ARITHMETIC) || defined(ARM_COMPUTE_FORCE_BF16) */
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
case DataType::F16:
- create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, gemm_info, _weights_manager);
+ create_arm_gemm<float16_t, float16_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
break;
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */
default: