aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
index e81ab2ffba..9fa253a55a 100644
--- a/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpReductionKernel.cpp
@@ -36,7 +36,7 @@ namespace
Status validate_arguments_matrix_a_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8);
if(output->total_size() > 0)
{
@@ -49,7 +49,7 @@ Status validate_arguments_matrix_a_reduction(const ITensorInfo *input, const ITe
Status validate_arguments_matrix_b_reduction(const ITensorInfo *input, const ITensorInfo *output)
{
ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output);
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8);
if(output->total_size() > 0)
{
@@ -63,6 +63,9 @@ std::pair<Status, Window> validate_and_configure_window_matrix_b_reduction(ITens
{
constexpr unsigned int num_elems_processed_per_iteration = 16;
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*output, TensorShape(input->dimension(0)), 1, DataType::S32);
+
// Configure kernel window
Window win = calculate_max_window(*output, Steps(num_elems_processed_per_iteration));
@@ -94,6 +97,9 @@ void CLGEMMLowpMatrixAReductionKernel::configure(CLCompileContext &compile_conte
ARM_COMPUTE_ERROR_ON_NULLPTR(mtx_a, vector_sum_row);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments_matrix_a_reduction(mtx_a->info(), vector_sum_row->info()));
+ // Output auto initialization if not yet initialized
+ auto_init_if_empty(*vector_sum_row->info(), TensorShape(mtx_a->info()->dimension(1)), 1, DataType::S32);
+
_input = mtx_a;
_output = vector_sum_row;