aboutsummaryrefslogtreecommitdiff
path: root/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp')
-rw-r--r--src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp147
1 files changed, 87 insertions, 60 deletions
diff --git a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
index 42d2ffce58..43ca7b3fbb 100644
--- a/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
+++ b/src/runtime/NEON/functions/NEGEMMLowpOutputStage.cpp
@@ -24,10 +24,10 @@
#include "arm_compute/runtime/NEON/functions/NEGEMMLowpOutputStage.h"
#include "arm_compute/core/ITensor.h"
+#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ScaleKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel.h"
#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel.h"
-#include "arm_compute/core/NEON/kernels/NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel.h"
#include "arm_compute/core/Validate.h"
#include "support/MemorySupport.h"
@@ -35,14 +35,25 @@ namespace arm_compute
{
void NEGEMMLowpQuantizeDownInt32ToUint8Scale::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_offset, int result_mult_int, int result_shift, int min, int max)
{
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
- k->configure(input, bias, output, result_offset, result_mult_int, result_shift, min, max);
+ GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo();
+ info.gemmlowp_offset = result_offset;
+ info.gemmlowp_multiplier = result_mult_int;
+ info.gemmlowp_shift = result_shift;
+ info.gemmlowp_min_bound = min;
+ info.gemmlowp_max_bound = max;
+
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>();
+ k->configure(input, bias, output, &info);
_kernel = std::move(k);
}
Status NEGEMMLowpQuantizeDownInt32ToUint8Scale::validate(const ITensorInfo *input, const ITensorInfo *bias, const ITensorInfo *output, int min, int max)
{
- return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, min, max);
+ GEMMLowpOutputStageInfo info = GEMMLowpOutputStageInfo();
+ info.gemmlowp_min_bound = min;
+ info.gemmlowp_max_bound = max;
+
+ return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info);
}
void NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPoint::configure(const ITensor *input, const ITensor *bias, ITensor *output, int result_fixedpoint_multiplier, int result_shift,
@@ -89,53 +100,63 @@ void NEGEMMLowpOutputStage::configure(const ITensor *input, const ITensor *bias,
ARM_COMPUTE_ERROR_ON_NULLPTR(input, output);
ARM_COMPUTE_ERROR_THROW_ON(NEGEMMLowpOutputStage::validate(input->info(), bias != nullptr ? bias->info() : nullptr, output->info(), info));
- if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+ switch(info.type)
{
- switch(output->info()->data_type())
+ case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
{
- case DataType::QASYMM8:
+ switch(info.output_data_type)
{
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel>();
- k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- _kernel = std::move(k);
- break;
+ case DataType::QASYMM8:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ case DataType::QASYMM8_SIGNED:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ case DataType::QSYMM16:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
+ k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ _kernel = std::move(k);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported output data type.");
+ break;
+ }
}
- default:
- ARM_COMPUTE_ERROR("Unsupported output data type.");
+ break;
}
- }
- else if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT)
- {
- switch(output->info()->data_type())
+ case GEMMLowpOutputStageType::QUANTIZE_DOWN:
{
- case DataType::QASYMM8:
- {
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel>();
- k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- _kernel = std::move(k);
- break;
- }
- case DataType::QASYMM8_SIGNED:
- {
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel>();
- k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_offset, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- _kernel = std::move(k);
- break;
- }
- case DataType::QSYMM16:
+ switch(info.output_data_type)
{
- auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel>();
- k->configure(input, bias, output, info.gemmlowp_multiplier, info.gemmlowp_shift, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- _kernel = std::move(k);
- break;
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ {
+ auto k = arm_compute::support::cpp14::make_unique<NEGEMMLowpQuantizeDownInt32ScaleKernel>();
+ k->configure(input, bias, output, &info);
+ _kernel = std::move(k);
+ break;
+ }
+ default:
+ {
+ ARM_COMPUTE_ERROR("Unsupported output data type.");
+ break;
+ }
}
- default:
- ARM_COMPUTE_ERROR("Unsupported output data type.");
+ break;
}
- }
- else
- {
- ARM_COMPUTE_ERROR("Unsupported output stage quantization type.");
+ default:
+ ARM_COMPUTE_ERROR("Unsupported GEMMLowpOutputStage type.");
}
}
@@ -147,29 +168,35 @@ Status NEGEMMLowpOutputStage::validate(const ITensorInfo *input, const ITensorIn
ARM_COMPUTE_RETURN_ERROR_ON((info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN) && (info.type != GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT));
- if(info.type == GEMMLowpOutputStageType::QUANTIZE_DOWN)
+ switch(info.type)
{
- switch(output->data_type())
+ case GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT:
{
- case DataType::QASYMM8:
- return NEGEMMLowpQuantizeDownInt32ToUint8ScaleKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- default:
- return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ switch(output->data_type())
+ {
+ case DataType::QASYMM8:
+ return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ case DataType::QASYMM8_SIGNED:
+ return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ case DataType::QSYMM16:
+ return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
+ default:
+ return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ }
}
- }
- else
- {
- switch(output->data_type())
+ case GEMMLowpOutputStageType::QUANTIZE_DOWN:
{
- case DataType::QASYMM8:
- return NEGEMMLowpQuantizeDownInt32ToUint8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- case DataType::QASYMM8_SIGNED:
- return NEGEMMLowpQuantizeDownInt32ToInt8ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- case DataType::QSYMM16:
- return NEGEMMLowpQuantizeDownInt32ToInt16ScaleByFixedPointKernel::validate(input, bias, output, info.gemmlowp_min_bound, info.gemmlowp_max_bound);
- default:
- return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ switch(output->data_type())
+ {
+ case DataType::QASYMM8:
+ case DataType::QASYMM8_SIGNED:
+ return NEGEMMLowpQuantizeDownInt32ScaleKernel::validate(input, bias, output, &info);
+ default:
+ return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported output data type.");
+ }
}
+ default:
+ return ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Unsupported GEMMLowpOutputStage type.");
}
}
} // namespace arm_compute