Diffstat (limited to 'src/cpu/operators/CpuMatMul.cpp')
 src/cpu/operators/CpuMatMul.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index 87cb6c6b54..515b511044 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -25,9 +25,9 @@
#include "src/cpu/operators/CpuMatMul.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
-#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h"
#include "src/common/utils/Log.h"
@@ -45,7 +45,6 @@ namespace cpu
{
namespace
{
-
Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act,
GEMMLowpOutputStageInfo &gemmlowp_output_stage_info)
{
@@ -74,15 +73,14 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo
return Status{};
}
-
-}
+} // namespace
CpuMatMul::CpuMatMul()
: _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape()
{
}
-Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings)
+Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
@@ -100,7 +98,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const
TensorInfo rhs_transposed{};
auto gemm_info = AsmGemmInfo();
- gemm_info.activation_info = info.fused_activation();
+ gemm_info.activation_info = act_info;
gemm_info.fast_mode = settings.fast_math();
// Validate and then permute a/b
@@ -139,7 +137,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const
return Status{};
}
-void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings)
+void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info)
{
ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings);
@@ -189,7 +187,7 @@ void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst,
// -----------------------------------------------------
// Use transposed tensors if the corresponding transpose flags are set
// Fill AsmGemmInfo class object before configuration
- _gemm_info.activation_info = info.fused_activation();
+ _gemm_info.activation_info = act_info;
_gemm_info.fast_mode = settings.fast_math();
_gemm_info.negated_offsets = false;
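
For reference, a minimal usage sketch (not part of the patch) of the updated validate()/configure() signatures, which now take the activation as an explicit ActivationLayerInfo argument instead of reading it from MatMulInfo::fused_activation(). The tensor shapes and the RELU activation below are illustrative assumptions.

// Minimal sketch, assuming F32 inputs and a RELU activation; shapes are illustrative.
#include "arm_compute/core/Error.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/runtime/NEON/functions/NEMatMul.h" // CpuMatMulSettings
#include "src/cpu/operators/CpuMatMul.h"

using namespace arm_compute;

void configure_matmul_with_activation()
{
    // Illustrative 2D tensor infos: lhs (K=32, M=16), rhs (N=64, K=32), dst (N=64, M=16)
    TensorInfo lhs(TensorShape(32U, 16U), 1, DataType::F32);
    TensorInfo rhs(TensorShape(64U, 32U), 1, DataType::F32);
    TensorInfo dst(TensorShape(64U, 16U), 1, DataType::F32);

    const MatMulInfo          mm_info{};
    const CpuMatMulSettings   settings{};
    const ActivationLayerInfo act(ActivationLayerInfo::ActivationFunction::RELU);

    // Activation is now passed as a separate argument to validate()/configure().
    ARM_COMPUTE_ERROR_THROW_ON(cpu::CpuMatMul::validate(&lhs, &rhs, &dst, mm_info, settings, act));

    cpu::CpuMatMul matmul{};
    matmul.configure(&lhs, &rhs, &dst, mm_info, settings, act);
}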