Diffstat (limited to 'src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp')
 src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)
diff --git a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
index 4c0a521de8..cdb78c291d 100644
--- a/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
+++ b/src/runtime/CL/functions/CLGEMMLowpMatrixMultiplyCore.cpp
@@ -1,5 +1,5 @@
/*
- * Copyright (c) 2017-2019 ARM Limited.
+ * Copyright (c) 2017-2020 ARM Limited.
*
* SPDX-License-Identifier: MIT
*
@@ -206,8 +206,10 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
_gemm_output_stage_multipliers.allocator()->init(TensorInfo(TensorShape(num_filters), 1, DataType::S32));
_gemm_output_stage_shifts.allocator()->init(TensorInfo(TensorShape(num_filters), 1, DataType::S32));
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = gemm_info.gemmlowp_output_stage();
+ gemmlowp_output_stage.output_data_type = _matrix_a->info()->data_type();
_offset_contribution_output_stage_kernel.configure(&_mm_result_s32, _a_offset == 0 ? nullptr : &_vector_sum_col, _b_offset == 0 ? nullptr : &_vector_sum_row, c, output, a->info()->dimension(0),
- _a_offset, _b_offset, gemm_info.gemmlowp_output_stage(), &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
+ _a_offset, _b_offset, gemmlowp_output_stage, &_gemm_output_stage_multipliers, &_gemm_output_stage_shifts);
_gemm_output_stage_multipliers.allocator()->allocate();
_gemm_output_stage_shifts.allocator()->allocate();
@@ -271,13 +273,10 @@ void CLGEMMLowpMatrixMultiplyCore::configure(const ICLTensor *a, const ICLTensor
Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8);
- if(b->data_type() == DataType::QSYMM8_PER_CHANNEL)
- {
- //DataType::QSYMM8_PER_CHANNEL supported only for weights
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->data_type() != DataType::QASYMM8, "Matrix A is not quantized while Matrix B is");
- }
- else
+ ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ //DataType::QSYMM8_PER_CHANNEL supported only for weights
+ if(b->data_type() != DataType::QSYMM8_PER_CHANNEL)
{
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
}
@@ -388,13 +387,15 @@ Status CLGEMMLowpMatrixMultiplyCore::validate(const ITensorInfo *a, const ITenso
const TensorInfo gemm_output_stage_multipliers_shifts_info(TensorInfo(TensorShape(num_filters), 1, DataType::S32));
+ GEMMLowpOutputStageInfo gemmlowp_output_stage = gemm_info.gemmlowp_output_stage();
+ gemmlowp_output_stage.output_data_type = a->data_type();
ARM_COMPUTE_RETURN_ON_ERROR(CLGEMMLowpOffsetContributionOutputStageKernel::validate(&mm_result_s32_info,
a_offset == 0 ? nullptr : &info_vector_sum_col,
b_offset == 0 ? nullptr : &info_vector_sum_row,
c,
output,
a_offset, b_offset,
- gemm_info.gemmlowp_output_stage(),
+ gemmlowp_output_stage,
&gemm_output_stage_multipliers_shifts_info,
&gemm_output_stage_multipliers_shifts_info));
}
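
Net effect of the patch: both configure() and validate() now copy the output stage description out of GEMMInfo and force its output_data_type to follow matrix A, which is what lets QASYMM8_SIGNED inputs (now accepted by the data-type check) flow through the offset-contribution/output-stage kernel. A minimal sketch of that pattern, outside the library; the helper name make_output_stage_info is hypothetical, while GEMMLowpOutputStageInfo, GEMMInfo, DataType and gemmlowp_output_stage() are the identifiers used in the diff:

#include "arm_compute/core/Types.h"

// Hypothetical free helper illustrating the pattern introduced above:
// copy the output stage info from GEMMInfo and make its output data type
// track matrix A (QASYMM8 or QASYMM8_SIGNED).
static arm_compute::GEMMLowpOutputStageInfo make_output_stage_info(const arm_compute::GEMMInfo &gemm_info,
                                                                   arm_compute::DataType        a_data_type)
{
    arm_compute::GEMMLowpOutputStageInfo stage = gemm_info.gemmlowp_output_stage();
    stage.output_data_type                     = a_data_type; // previously left at the GEMMInfo default
    return stage;
}

In configure() the data type comes from _matrix_a->info()->data_type(), in validate() from a->data_type(); the resulting struct is then passed to CLGEMMLowpOffsetContributionOutputStageKernel in place of gemm_info.gemmlowp_output_stage().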