aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
diff options
context:
space:
mode:
authorSang-Hoon Park <sang-hoon.park@arm.com>2021-05-17 17:04:50 +0100
committerSang-Hoon Park <sang-hoon.park@arm.com>2021-05-26 10:16:05 +0000
commitd89e2faa60d148f3c04e57032a28f1065a1be0e8 (patch)
treec95eb97f9c79198cb5db1232b497491df10614f2 /src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
parent8b83d4684249bb96e27f95e11cf8f38a1c33b82b (diff)
downloadComputeLibrary-d89e2faa60d148f3c04e57032a28f1065a1be0e8.tar.gz
Create CpuGemmDirectConv2d
As the first phase of making NEGEMMConv2d stateless, CpuGemmDirectConv2d operator is created. Kernels and operators used by the operator use TensorInfo pointers instead of Tensor pointers. The CpuGemmDirectConv2d isn't completely stateless because it manages one intermediate tensor internally. This will be resolved by implementing memory injection mechanism with the following patches. Also, weight manager of CpuGemmAssemblyDispatch is disabled to enable this work. Implements: COMPMID-4506 Change-Id: Iec3ca6de29d98bef7ea95e8f4473d6dc0024a140 Signed-off-by: Sang-Hoon Park <sang-hoon.park@arm.com> Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/5672 Tested-by: Arm Jenkins <bsgcomp@arm.com> Reviewed-by: Michele Di Giorgio <michele.digiorgio@arm.com> Reviewed-by: Georgios Pinitas <georgios.pinitas@arm.com> Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp')
-rw-r--r--src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp190
1 files changed, 98 insertions, 92 deletions
diff --git a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
index 36c1bbb1b3..0c511ff548 100644
--- a/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
+++ b/src/runtime/cpu/operators/internal/CpuGemmAssemblyDispatch.cpp
@@ -56,14 +56,13 @@ struct Params
bool indirect;
};
-Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d, const AsmGemmInfo &info)
+Params extract_parameters(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
-
Params p;
- p.M = d->info()->tensor_shape().y();
- p.K = a->info()->tensor_shape().x();
- p.N = d->info()->tensor_shape().x();
+ p.M = d->tensor_shape().y();
+ p.K = a->tensor_shape().x();
+ p.N = d->tensor_shape().x();
p.batches = 1;
p.multis = 1;
p.sections = 1;
@@ -72,19 +71,19 @@ Params extract_parameters(const ITensor *a, const ITensor *b, const ITensor *d,
if(info.method == AsmConvMethod::Conv || info.method == AsmConvMethod::Indirect)
{
p.indirect = true;
- p.sections = b->info()->tensor_shape()[2] * b->info()->tensor_shape()[3];
+ p.sections = b->tensor_shape()[2] * b->tensor_shape()[3];
}
else
{
- p.multis = b->info()->tensor_shape().z();
- p.batches = d->info()->tensor_shape().total_size_upper(2) / p.multis;
+ p.multis = b->tensor_shape().z();
+ p.batches = d->tensor_shape().total_size_upper(2) / p.multis;
}
// Update M in case of GEMM3D for output
if(info.depth_output_gemm3d != 0)
{
- p.M = d->info()->tensor_shape().y() * d->info()->tensor_shape().z();
- p.batches = d->info()->tensor_shape().total_size_upper(3) / p.multis;
+ p.M = d->tensor_shape().y() * d->tensor_shape().z();
+ p.batches = d->tensor_shape().total_size_upper(3) / p.multis;
}
return p;
@@ -205,11 +204,11 @@ public:
}
private:
- Tensor _output{};
- int _ldb{};
- const TypeInput *_in1_ptr{};
- int _multi_stride_b{};
- size_t _B_pretranspose_size{};
+ Tensor _output{};
+ int _ldb{};
+ const TypeInput *_in1_ptr{};
+ int _multi_stride_b{};
+ size_t _B_pretranspose_size{};
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
};
@@ -221,8 +220,7 @@ public:
/** Destructor */
~Fallback()
{
- // Release memory if we have allocated the memory ourselves
- if(_pretranspose && !(_weights_manager && _weights_manager->are_weights_managed(_b)))
+ if(_pretranspose && !(is_weight_managed()))
{
delete _pretranspose;
}
@@ -240,7 +238,7 @@ public:
* @param[in] weights_manager Weights manager to be used by the function.
* @param[in] os Output stage meta-data.
*/
- void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
+ void configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os = {});
@@ -262,8 +260,8 @@ public:
const std::vector<int32_t> &multipliers);
// Inherited methods overridden:
- void run() override;
- void prepare() override;
+ void run(ITensorPack &tensors) override;
+ void prepare(ITensorPack &tensors) override;
bool is_configured() const override;
private:
@@ -283,28 +281,12 @@ private:
*/
void configure_indirect(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *d, const AsmGemmInfo &info);
/** Prepare the indirect buffer */
- void prepare_indirect_buffer();
+ void prepare_indirect_buffer(ITensorPack &tensors);
/** Assembly Gemm kernel */
std::shared_ptr<arm_gemm::GemmCommon<TypeInput, TypeOutput>> _gemm_kernel_asm{ nullptr };
/** Optimised Arm® Neon™ kernel */
std::unique_ptr<INEKernel> _optimised_kernel{ nullptr };
- /** Input A */
- const ITensor *_a
- {
- nullptr
- };
- /** Input B */
- const ITensor *_b
- {
- nullptr
- };
- const ITensor *_c
- {
- nullptr
- };
- /** Output */
- ITensor *_d{ nullptr };
/** GEMM workspace */
Tensor _workspace{};
/** Pre-transpose tensor */
@@ -328,8 +310,27 @@ private:
/** Indirect buffer */
std::unique_ptr<const TypeInput *const *, free_delete> _indirect_arg{};
std::unique_ptr<const TypeInput *, free_delete> _indirect_buf{};
- std::vector<TypeInput> _indirect_pad{};
- arm_gemm::ConvolutionParameters _cp{};
+ std::vector<TypeInput> _indirect_pad{};
+ arm_gemm::ConvolutionParameters _cp{};
+
+ bool is_weight_managed()
+ {
+ // TODO (COMPMID-4539): This function should do the following:
+ // _weights_manager && _weights_manager->are_weights_managed(_b)
+ // , where _b is the second Tensor that is used to be given to the configure().
+ // Currently, however, weight manager is disabled to make this class stateless.
+ // This should be revisited in the future.
+ return false;
+ }
+
+ void acquire_managed_weight()
+ {
+ // TODO (COMPMID-4539): This function should do the following:
+ // _pretranspose = _weights_manager->acquire(_b, &_weights_transform);
+ // , where _b is the second Tensor that is used to be given to the configure().
+ // Currently, however, weight manager is disabled to make this class stateless.
+ _pretranspose = nullptr;
+ }
};
template <typename TypeInput, typename TypeOutput, class OutputStage>
@@ -352,14 +353,15 @@ Fallback<TypeInput, TypeOutput, OutputStage>::set_requantize_data(const std::vec
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer()
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare_indirect_buffer(ITensorPack &tensors)
{
- const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(_a->buffer());
+ auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ const TypeInput *A_ptr = reinterpret_cast<TypeInput *>(a->buffer());
const int multis = 1;
- const int batches = _a->info()->tensor_shape().total_size_upper(3);
- const size_t stride_A = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const size_t batch_stride_A = _a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
- const size_t multi_stride_A = _a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
+ const int batches = a->info()->tensor_shape().total_size_upper(3);
+ const size_t stride_A = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const size_t batch_stride_A = a->info()->strides_in_bytes()[3] / sizeof(TypeInput);
+ const size_t multi_stride_A = a->info()->strides_in_bytes()[4] / sizeof(TypeInput);
const size_t output_hw = _cp.output_height * _cp.output_width;
const int batch_size = _cp.kernel_height * _cp.kernel_width * output_hw * sizeof(TypeInput);
@@ -466,10 +468,11 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure_indirect(const ITen
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d,
+void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d,
arm_gemm::GemmArgs args, const AsmGemmInfo &gemm_info,
MemoryGroup &memory_group, IWeightsManager *weights_manager, const OutputStage &os)
{
+ ARM_COMPUTE_UNUSED(c);
arm_gemm::GemmConfig gemm_cfg;
_kernel_info = arm_gemm::get_gemm_method<TypeInput, TypeOutput, OutputStage>(args, os);
_weights_manager = weights_manager;
@@ -508,10 +511,6 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
}
_optimised_kernel = std::move(acl_gemm_wrapper);
- _a = a;
- _b = b;
- _c = c;
- _d = d;
_gemm_info = gemm_info;
// Check for pre-transposed support
if(_gemm_kernel_asm->B_pretranspose_required())
@@ -519,10 +518,10 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
// Forcing 128-byte alignment (required by 32-bit kernels)
const unsigned int alignment = 128;
const size_t B_pretranspose_size = _gemm_kernel_asm->get_B_pretransposed_array_size();
- if(weights_manager && _weights_manager->are_weights_managed(b))
+ if(is_weight_managed())
{
_weights_transform.configure(B_pretranspose_size, alignment);
- _pretranspose = _weights_manager->acquire(b, &_weights_transform);
+ acquire_managed_weight();
}
else
{
@@ -534,32 +533,34 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::configure(const ITensor *a, c
// Handle indirect GEMM convolution
if(gemm_info.method == AsmConvMethod::Conv || gemm_info.method == AsmConvMethod::Indirect)
{
- configure_indirect(a->info(), b->info(), d->info(), gemm_info);
+ configure_indirect(a, b, d, gemm_info);
}
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
+void Fallback<TypeInput, TypeOutput, OutputStage>::prepare(ITensorPack &tensors)
{
+ auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
if(!_is_prepared)
{
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
- if(_c && _c->info()->data_type() == DataType::S32)
+ if(c && c->info()->data_type() == DataType::S32)
{
- _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(_c->buffer() + _c->info()->offset_first_element_in_bytes()), 0);
+ _gemm_kernel_asm->set_quantized_bias(reinterpret_cast<const int32_t *>(c->buffer() + c->info()->offset_first_element_in_bytes()), 0);
}
// Pretranspose B if required
if(_gemm_kernel_asm->B_pretranspose_required())
{
- const int ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- const auto in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
- const int multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ const int ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ const auto in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
+ const int multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- if(_weights_manager && _weights_manager->are_weights_managed(_b))
+ if(is_weight_managed())
{
_weights_transform.set_args(ldb, in1_ptr, multi_stride_b, _gemm_kernel_asm);
- _weights_manager->run(_b, &_weights_transform);
+ _weights_manager->run(b, &_weights_transform);
// If we didn't run the reshape function, set the pretransposed buffer
if(!_weights_transform.is_reshape_run())
@@ -572,13 +573,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::prepare()
static_cast<Tensor *>(_pretranspose)->allocator()->allocate();
ARM_COMPUTE_ERROR_ON(_pretranspose->buffer() == nullptr);
_gemm_kernel_asm->pretranspose_B_array(_pretranspose->buffer(), in1_ptr, ldb, multi_stride_b);
- _b->mark_as_unused();
+ b->mark_as_unused();
}
}
if(_gemm_info.method == AsmConvMethod::Indirect)
{
- prepare_indirect_buffer();
+ prepare_indirect_buffer(tensors);
}
_is_prepared = true;
@@ -601,37 +602,42 @@ bool Fallback<TypeInput, TypeOutput, OutputStage>::is_configured() const
}
template <typename TypeInput, typename TypeOutput, class OutputStage>
-void Fallback<TypeInput, TypeOutput, OutputStage>::run()
+void Fallback<TypeInput, TypeOutput, OutputStage>::run(ITensorPack &tensors)
{
- int lda = _a->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ auto a = tensors.get_const_tensor(TensorType::ACL_SRC_0);
+ auto b = tensors.get_const_tensor(TensorType::ACL_SRC_1);
+ auto c = tensors.get_const_tensor(TensorType::ACL_SRC_2);
+ auto d = tensors.get_tensor(TensorType::ACL_DST);
+
+ int lda = a->info()->strides_in_bytes().y() / sizeof(TypeInput);
int ldb = 0;
- const int ldd = _d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
+ const int ldd = d->info()->strides_in_bytes().y() / sizeof(TypeOutput);
const size_t a_batch_idx = _gemm_info.reinterpret_input_as_3d != 0 ? 3 : 2;
const size_t a_multi_idx = a_batch_idx + 1;
const size_t d_batch_idx = _gemm_info.depth_output_gemm3d != 0 ? 3 : 2;
const size_t d_multi_idx = d_batch_idx + 1;
- int batch_stride_a = _a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
- const int batch_stride_d = _d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
+ int batch_stride_a = a->info()->strides_in_bytes()[a_batch_idx] / sizeof(TypeInput);
+ const int batch_stride_d = d->info()->strides_in_bytes()[d_batch_idx] / sizeof(TypeOutput);
- int multi_stride_a = _a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
+ int multi_stride_a = a->info()->strides_in_bytes()[a_multi_idx] / sizeof(TypeInput);
int multi_stride_b = 0;
- const int multi_stride_d = _d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
+ const int multi_stride_d = d->info()->strides_in_bytes()[d_multi_idx] / sizeof(TypeOutput);
- auto in0_ptr = reinterpret_cast<const TypeInput *>(_a->buffer() + _a->info()->offset_first_element_in_bytes());
+ auto in0_ptr = reinterpret_cast<const TypeInput *>(a->buffer() + a->info()->offset_first_element_in_bytes());
const TypeInput *in1_ptr = nullptr;
- auto out_ptr = reinterpret_cast<TypeOutput *>(_d->buffer() + _d->info()->offset_first_element_in_bytes());
+ auto out_ptr = reinterpret_cast<TypeOutput *>(d->buffer() + d->info()->offset_first_element_in_bytes());
// Check if B is pre-tranposed and de-reference if not
if(!_gemm_kernel_asm->B_is_pretransposed())
{
- ldb = _b->info()->strides_in_bytes().y() / sizeof(TypeInput);
- multi_stride_b = _b->info()->strides_in_bytes().z() / sizeof(TypeInput);
- in1_ptr = reinterpret_cast<const TypeInput *>(_b->buffer() + _b->info()->offset_first_element_in_bytes());
+ ldb = b->info()->strides_in_bytes().y() / sizeof(TypeInput);
+ multi_stride_b = b->info()->strides_in_bytes().z() / sizeof(TypeInput);
+ in1_ptr = reinterpret_cast<const TypeInput *>(b->buffer() + b->info()->offset_first_element_in_bytes());
}
- const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, _d->info()->data_type());
+ const auto scheduling_hint = scheduling_hint_heuristic(_kernel_info.method, d->info()->data_type());
// Set workspace if needed and reset number of threads as buffer manager gets re-created with max_threads
if(_workspace.buffer() != nullptr)
@@ -654,13 +660,13 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
}
// Prepare assembly kernel
- prepare();
+ prepare(tensors);
// Setup up matrix bias in the assembly kernel, it's just a pointer to matrix C.
TypeOutput *bias = nullptr;
- if(_c && _c->info()->data_type() != DataType::S32)
+ if(c && c->info()->data_type() != DataType::S32)
{
- bias = reinterpret_cast<TypeOutput *>(_c->buffer() + _c->info()->offset_first_element_in_bytes());
+ bias = reinterpret_cast<TypeOutput *>(c->buffer() + c->info()->offset_first_element_in_bytes());
}
if(_gemm_info.method == AsmConvMethod::Indirect)
@@ -682,7 +688,7 @@ void Fallback<TypeInput, TypeOutput, OutputStage>::run()
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
Params p = extract_parameters(a, b, d, info);
@@ -699,7 +705,7 @@ void create_arm_gemm(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_ge
template <typename TypeInput, typename TypeOutput>
void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &arm_gemm, MemoryGroup &memory_group,
- const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
+ const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, arm_gemm::Activation activation, const AsmGemmInfo &info,
IWeightsManager *weights_manager)
{
ARM_COMPUTE_UNUSED(activation);
@@ -714,8 +720,8 @@ void create_arm_gemm_quant(std::unique_ptr<CpuGemmAssemblyDispatch::IFallback> &
// Configure requantization info
const int32_t negation = info.negated_offsets ? 1 : -1;
- const int32_t a_offset = -a->info()->quantization_info().uniform().offset * negation;
- const int32_t b_offset = -b->info()->quantization_info().uniform().offset * negation;
+ const int32_t a_offset = -a->quantization_info().uniform().offset * negation;
+ const int32_t b_offset = -b->quantization_info().uniform().offset * negation;
const GEMMLowpOutputStageInfo os_info = info.output_stage;
arm_gemm::Requantize32 gemm_requant_info{};
@@ -786,18 +792,18 @@ bool CpuGemmAssemblyDispatch::is_activation_supported(const ActivationLayerInfo
return act.type != arm_gemm::Activation::Type::None;
}
-void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *d, const AsmGemmInfo &info)
+void CpuGemmAssemblyDispatch::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, const AsmGemmInfo &info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
arm_gemm::Activation act = map_to_arm_gemm_activation(info.activation_info);
//If we don't support a combination of data types, silently return: it is the caller's responsibility to check if configure() was successful via is_configured()
- if(!CpuGemmAssemblyDispatch::validate(a->info(), b->info(), c != nullptr ? c->info() : nullptr, d->info(), info))
+ if(!CpuGemmAssemblyDispatch::validate(a, b, c, d, info))
{
return;
}
- switch(a->info()->data_type())
+ switch(a->data_type())
{
case DataType::F32:
create_arm_gemm<float, float>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
@@ -805,7 +811,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
#ifdef __aarch64__
case DataType::U8:
case DataType::QASYMM8:
- if(d->info()->data_type() == DataType::S32)
+ if(d->data_type() == DataType::S32)
{
create_arm_gemm<uint8_t, uint32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
@@ -816,7 +822,7 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
break;
case DataType::S8:
case DataType::QASYMM8_SIGNED:
- if(d->info()->data_type() == DataType::S32)
+ if(d->data_type() == DataType::S32)
{
create_arm_gemm<int8_t, int32_t>(_arm_gemm, _memory_group, a, b, c, d, act, info, _weights_manager);
}
@@ -841,10 +847,10 @@ void CpuGemmAssemblyDispatch::configure(const ITensor *a, const ITensor *b, cons
}
}
-void CpuGemmAssemblyDispatch::prepare()
+void CpuGemmAssemblyDispatch::prepare(ITensorPack &tensors)
{
ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->prepare();
+ _arm_gemm->prepare(tensors);
}
bool CpuGemmAssemblyDispatch::is_configured() const
@@ -852,12 +858,12 @@ bool CpuGemmAssemblyDispatch::is_configured() const
return _arm_gemm != nullptr && _arm_gemm->is_configured();
}
-void CpuGemmAssemblyDispatch::run()
+void CpuGemmAssemblyDispatch::run(ITensorPack &tensors)
{
MemoryGroupResourceScope scope_mg(_memory_group);
ARM_COMPUTE_ERROR_ON(_arm_gemm == nullptr);
- _arm_gemm->run();
+ _arm_gemm->run(tensors);
}
} // namespace cpu
} // namespace arm_compute