diff options
Diffstat (limited to 'src/cpu/operators')
-rw-r--r-- | src/cpu/operators/CpuMatMul.cpp | 14 | ||||
-rw-r--r-- | src/cpu/operators/CpuMatMul.h | 6 |
2 files changed, 10 insertions, 10 deletions
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp index 87cb6c6b54..515b511044 100644 --- a/src/cpu/operators/CpuMatMul.cpp +++ b/src/cpu/operators/CpuMatMul.cpp @@ -25,9 +25,9 @@ #include "src/cpu/operators/CpuMatMul.h" #include "arm_compute/core/Types.h" #include "arm_compute/core/Validate.h" -#include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/core/experimental/Types.h" #include "arm_compute/core/utils/misc/ShapeCalculator.h" +#include "arm_compute/core/utils/quantization/AsymmHelpers.h" #include "arm_compute/runtime/NEON/NEScheduler.h" #include "arm_compute/runtime/NEON/functions/NEMatMul.h" #include "src/common/utils/Log.h" @@ -45,7 +45,6 @@ namespace cpu { namespace { - Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act, GEMMLowpOutputStageInfo &gemmlowp_output_stage_info) { @@ -74,15 +73,14 @@ Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo return Status{}; } - -} +} // namespace CpuMatMul::CpuMatMul() : _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape() { } -Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info) { ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst); ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED); @@ -100,7 +98,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const TensorInfo rhs_transposed{}; auto gemm_info = AsmGemmInfo(); - gemm_info.activation_info = info.fused_activation(); + gemm_info.activation_info = act_info; gemm_info.fast_mode = settings.fast_math(); // Validate and then permute a/b @@ -139,7 +137,7 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const return Status{}; } -void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings) +void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info) { ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst); ARM_COMPUTE_LOG_PARAMS(lhs, rhs, dst, info, settings); @@ -189,7 +187,7 @@ void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, // ----------------------------------------------------- // Use transposed tensors if the corresponding transpose flags are set // Fill AsmGemmInfo class object before configuration - _gemm_info.activation_info = info.fused_activation(); + _gemm_info.activation_info = act_info; _gemm_info.fast_mode = settings.fast_math(); _gemm_info.negated_offsets = false; diff --git a/src/cpu/operators/CpuMatMul.h b/src/cpu/operators/CpuMatMul.h index 9f5833b24f..475c019fd0 100644 --- a/src/cpu/operators/CpuMatMul.h +++ b/src/cpu/operators/CpuMatMul.h @@ -64,15 +64,17 @@ public: * @param[out] dst Output tensor to store the result of the batched matrix multiplication. Data types supported: same as @p lhs / @p rhs. * @param[in] info Contains MatMul operation information described in @ref MatMulInfo. * @param[in] settings The settings for matmul operation (i.e fast math) + * @param[in] act_info Class containing information about fused activation function. */ - void configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings); + void configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, const ActivationLayerInfo &act_info = ActivationLayerInfo()); /** Static function to check if given info will lead to a valid configuration * * Similar to CpuMatMul::configure() * * @return a status */ - static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings); + static Status validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings, + const ActivationLayerInfo &act_info = ActivationLayerInfo()); // Inherited methods overridden: void run(ITensorPack &tensors) override; |