aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuMatMul.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuMatMul.cpp')
-rw-r--r--src/cpu/operators/CpuMatMul.cpp52
1 files changed, 50 insertions, 2 deletions
diff --git a/src/cpu/operators/CpuMatMul.cpp b/src/cpu/operators/CpuMatMul.cpp
index b5359e51af..369466b669 100644
--- a/src/cpu/operators/CpuMatMul.cpp
+++ b/src/cpu/operators/CpuMatMul.cpp
@@ -23,7 +23,9 @@
*/
#include "src/cpu/operators/CpuMatMul.h"
+#include "arm_compute/core/Types.h"
#include "arm_compute/core/Validate.h"
+#include "arm_compute/core/utils/quantization/AsymmHelpers.h"
#include "arm_compute/core/experimental/Types.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
@@ -40,6 +42,40 @@ namespace arm_compute
{
namespace cpu
{
+namespace
+{
+
+Status get_gemmlowp_output_stage_info(const ITensorInfo *src, const ITensorInfo *weights, const ITensorInfo *dst, const ActivationLayerInfo &act,
+ GEMMLowpOutputStageInfo &gemmlowp_output_stage_info)
+{
+ const auto data_type = src->data_type();
+ const QuantizationInfo oq_info = dst->quantization_info();
+ const UniformQuantizationInfo iq_unif = src->quantization_info().uniform();
+ const UniformQuantizationInfo wq_unif = weights->quantization_info().uniform();
+ const UniformQuantizationInfo oq_unif = oq_info.uniform();
+
+ float multiplier = (iq_unif.scale * wq_unif.scale) / oq_unif.scale;
+ int32_t output_multiplier;
+ int32_t output_shift;
+
+ ARM_COMPUTE_RETURN_ON_ERROR(quantization::calculate_quantized_multiplier(multiplier, &output_multiplier, &output_shift));
+
+ PixelValue type_min{};
+ PixelValue type_max{};
+ std::tie(type_min, type_max) = quantization::get_quantized_asymmetric_output_min_max(oq_info, act, data_type);
+
+ gemmlowp_output_stage_info.gemmlowp_multiplier = output_multiplier;
+ gemmlowp_output_stage_info.gemmlowp_shift = output_shift;
+ gemmlowp_output_stage_info.gemmlowp_offset = oq_unif.offset;
+ gemmlowp_output_stage_info.type = GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
+ gemmlowp_output_stage_info.gemmlowp_min_bound = type_min.get<int32_t>();
+ gemmlowp_output_stage_info.gemmlowp_max_bound = type_max.get<int32_t>();
+
+ return Status{};
+}
+
+}
+
CpuMatMul::CpuMatMul()
: _transpose_kernel_lhs(), _transpose_kernel_rhs(), _asm_glue(), _lhs_transposed(), _rhs_transposed(), _original_lhs_shape(), _original_rhs_shape(), _original_dst_shape()
{
@@ -47,8 +83,8 @@ CpuMatMul::CpuMatMul()
Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, const MatMulInfo &info, const CpuMatMulSettings &settings)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F32, DataType::F16, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs->are_values_constant(), "LHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_MSG(rhs->are_values_constant(), "RHS Tensor must be dynamic.");
ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
@@ -96,6 +132,12 @@ Status CpuMatMul::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const
ARM_COMPUTE_RETURN_ERROR_ON_MSG(lhs_to_use->dimension(i) != rhs_to_use->dimension(i), "Broadcasting in Batch dimension is unsupported by this operator.");
}
+ // Quantized-specific configuration
+ if(is_data_type_quantized(lhs->data_type()))
+ {
+ ARM_COMPUTE_RETURN_ON_ERROR(get_gemmlowp_output_stage_info(lhs_to_use, rhs_to_use, dst, gemm_info.activation_info, gemm_info.output_stage));
+ }
+
cpu::CpuGemmAssemblyDispatch::validate(lhs_to_use, rhs_to_use, nullptr, dst, gemm_info);
return Status{};
@@ -157,6 +199,12 @@ void CpuMatMul::configure(ITensorInfo *lhs, ITensorInfo *rhs, ITensorInfo *dst,
lhs_to_use = (_adj_lhs) ? _lhs_transposed : lhs_to_use;
rhs_to_use = (_adj_rhs) ? _rhs_transposed : rhs_to_use;
+ // Quantized-specific configuration
+ if(is_data_type_quantized(lhs->data_type()))
+ {
+ get_gemmlowp_output_stage_info(&lhs_to_use, &rhs_to_use, &dst_to_use, _gemm_info.activation_info, _gemm_info.output_stage);
+ }
+
// Configure Asm Kernel
_asm_glue = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
_asm_glue->configure(&lhs_to_use, &rhs_to_use, nullptr, &dst_to_use, _gemm_info); // c is nullptr as bias not supported in MatMul