aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/operators/CpuGemmLowpOutputStage.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/operators/CpuGemmLowpOutputStage.cpp')
-rw-r--r--src/cpu/operators/CpuGemmLowpOutputStage.cpp52
1 files changed, 34 insertions, 18 deletions
diff --git a/src/cpu/operators/CpuGemmLowpOutputStage.cpp b/src/cpu/operators/CpuGemmLowpOutputStage.cpp
index 58f98acff0..4215eed199 100644
--- a/src/cpu/operators/CpuGemmLowpOutputStage.cpp
+++ b/src/cpu/operators/CpuGemmLowpOutputStage.cpp
@@ -26,6 +26,7 @@
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
+
#include "src/common/utils/Log.h"
#include "src/cpu/kernels/CpuGemmLowpQuantizeDownInt32ScaleKernel.h"
#include "src/cpu/kernels/CpuGemmLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h"
@@ -36,36 +37,42 @@ namespace arm_compute
{
namespace cpu
{
-void CpuGemmLowpOutputStage::configure(ITensorInfo *src, ITensorInfo *bias, ITensorInfo *dst, const GEMMLowpOutputStageInfo &info)
+void CpuGemmLowpOutputStage::configure(ITensorInfo *src,
+ ITensorInfo *bias,
+ ITensorInfo *dst,
+ const GEMMLowpOutputStageInfo &info)
{
// Perform validate step
ARM_COMPUTE_ERROR_THROW_ON(CpuGemmLowpOutputStage::validate(src, bias, dst, info));
ARM_COMPUTE_LOG_PARAMS(src, bias, dst, info);
- switch(info.type)
+ switch (info.type)
{
case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
{
- switch(info.output_data_type)
+ switch (info.output_data_type)
{
case DataType::QASYMM8:
{
auto k = std::make_unique<kernels::CpuGemmLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
- k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset,
+ info.gemmlowp_min_bound, info.gemmlowp_max_bound);
_kernel = std::move(k);
break;
}
case DataType::QASYMM8_SIGNED:
{
auto k = std::make_unique<kernels::CpuGemmLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
- k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset,
+ info.gemmlowp_min_bound, info.gemmlowp_max_bound);
_kernel = std::move(k);
break;
}
case DataType::QSYMM16:
{
auto k = std::make_unique<kernels::CpuGemmLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
- k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ k->configure(src, bias, dst, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound,
+ info.gemmlowp_max_bound);
_kernel = std::move(k);
break;
}
@@ -79,7 +86,7 @@ void CpuGemmLowpOutputStage::configure(ITensorInfo *src, ITensorInfo *bias, ITen
}
case GEMMLowpOutputStageType::QUANTIZE_DOWN:
{
- switch(info.output_data_type)
+ switch (info.output_data_type)
{
case DataType::QASYMM8:
case DataType::QASYMM8_SIGNED:
@@ -102,32 +109,41 @@ void CpuGemmLowpOutputStage::configure(ITensorInfo *src, ITensorInfo *bias, ITen
}
}
-Status CpuGemmLowpOutputStage::validate(const ITensorInfo *src, const ITensorInfo *bias, const ITensorInfo *dst, const GEMMLowpOutputStageInfo &info)
+Status CpuGemmLowpOutputStage::validate(const ITensorInfo *src,
+ const ITensorInfo *bias,
+ const ITensorInfo *dst,
+ const GEMMLowpOutputStageInfo &info)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src, dst);
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::UNKNOWN, "CpuGemmLowpOutputStage cannot be used with UNKNOWN output data type.");
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM16);
- ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT));
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(dst->data_type() == DataType::UNKNOWN,
+ "CpuGemmLowpOutputStage cannot be used with UNKNOWN output data type.");
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED,
+ DataType::QSYMM16);
+ ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) &&
+ (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT));
- switch(info.type)
+ switch (info.type)
{
case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
{
- switch(dst->data_type())
+ switch (dst->data_type())
{
case DataType::QASYMM8:
- return kernels::CpuGemmLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ return kernels::CpuGemmLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(
+ src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
case DataType::QASYMM8_SIGNED:
- return kernels::CpuGemmLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ return kernels::CpuGemmLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(
+ src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
case DataType::QSYMM16:
- return kernels::CpuGemmLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ return kernels::CpuGemmLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(
+ src, bias, dst, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
default:
return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
}
}
case GEMMLowpOutputStageType::QUANTIZE_DOWN:
{
- switch(dst->data_type())
+ switch (dst->data_type())
{
case DataType::QASYMM8:
case DataType::QASYMM8_SIGNED:
@@ -146,4 +162,4 @@ void CpuGemmLowpOutputStage::run(ITensorPack &tensors)
NEScheduler::get().schedule_op(_kernel.get(), Window::DimY, _kernel->window(), tensors);
}
} // namespace cpu
-} // namespace arm_compute \ No newline at end of file
+} // namespace arm_compute