aboutsummaryrefslogtreecommitdiff
path: root/src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp')
-rw-r--r--src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp79
1 files changed, 56 insertions, 23 deletions
diff --git a/src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp b/src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp
index 1852262337..2ebd76e1bf 100644
--- a/src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp
+++ b/src/core/CL/kernels/CLGEMMLowpOffsetContributionOutputStageKernel.cpp
@@ -37,17 +37,12 @@
#include <cstddef>
#include <cstdint>
-using namespace arm_compute;
-
namespace arm_compute
{
-class Coordinates;
-} // namespace arm_compute
-
namespace
{
Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias, const ITensorInfo *output,
- int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage)
+ int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage, const ITensorInfo *output_multipliers, const ITensorInfo *output_shifts)
{
ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(mm_result, 1, DataType::S32);
ARM_COMPUTE_RETURN_ERROR_ON(output_stage.type == GEMMLowpOutputStageType::NONE);
@@ -61,6 +56,16 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto
ARM_COMPUTE_RETURN_ERROR_ON(mm_result->dimension(0) != bias->dimension(0));
}
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_multipliers, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_multipliers->num_dimensions() > 1);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output_shifts, 1, DataType::S32);
+ ARM_COMPUTE_RETURN_ERROR_ON(output_shifts->num_dimensions() > 1);
+ if(output_stage.is_quantized_per_channel)
+ {
+ ARM_COMPUTE_RETURN_ERROR_ON(mm_result->dimension(0) != output_shifts->dimension(0));
+ ARM_COMPUTE_RETURN_ERROR_ON(mm_result->dimension(0) != output_multipliers->dimension(0));
+ }
+
// If a_offset == 0, vector_sum_col can be a nullptr
if(a_offset != 0)
{
@@ -109,11 +114,14 @@ Status validate_arguments(const ITensorInfo *mm_result, const ITensorInfo *vecto
ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(mm_result, output);
}
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(output_stage.gemmlowp_multipliers.size() != output_stage.gemmlowp_shifts.size(),
+ "per channel quantization info is incorrect");
+
return Status{};
}
std::pair<Status, Window> validate_and_configure_window(ITensorInfo *mm_result, ITensorInfo *vector_sum_col, ITensorInfo *vector_sum_row, ITensorInfo *bias, ITensorInfo *output,
- int32_t a_offset, int32_t b_offset)
+ int32_t a_offset, int32_t b_offset, ITensorInfo *output_multipliers, ITensorInfo *output_shifts)
{
constexpr unsigned int num_elems_processed_per_iteration = 4;
bool window_changed = false;
@@ -147,36 +155,55 @@ std::pair<Status, Window> validate_and_configure_window(ITensorInfo *mm_result,
window_changed = window_changed || update_window_and_padding(win, bias_access);
}
+ if(output_multipliers->dimension(0) > 1)
+ {
+ AccessWindowHorizontal output_multipliers_access(output_multipliers, 0, num_elems_processed_per_iteration);
+ AccessWindowHorizontal output_shifts_access(output_shifts, 0, num_elems_processed_per_iteration);
+ window_changed = window_changed || update_window_and_padding(win, output_multipliers_access, output_shifts_access);
+ }
+
Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{};
return std::make_pair(err, win);
}
} // namespace
CLGEMMLowpOffsetContributionOutputStageKernel::CLGEMMLowpOffsetContributionOutputStageKernel()
- : _mm_result(nullptr), _vector_sum_col(nullptr), _vector_sum_row(nullptr), _bias(nullptr), _output(nullptr)
+ : _mm_result(nullptr),
+ _vector_sum_col(nullptr),
+ _vector_sum_row(nullptr),
+ _bias(nullptr),
+ _output(nullptr),
+ _output_multipliers(nullptr),
+ _output_shifts(nullptr),
+ _is_quantized_per_channel(false)
{
}
void CLGEMMLowpOffsetContributionOutputStageKernel::configure(const ICLTensor *mm_result, const ICLTensor *vector_sum_col, const ICLTensor *vector_sum_row, const ICLTensor *bias, ICLTensor *output,
- int32_t k, int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage)
+ int32_t k, int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage,
+ const ICLTensor *output_multipliers, const ICLTensor *output_shifts)
{
// Perform validate step
- ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result, output);
+ ARM_COMPUTE_ERROR_ON_NULLPTR(mm_result, output, output_multipliers, output_shifts);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(mm_result->info(),
vector_sum_col != nullptr ? vector_sum_col->info() : nullptr,
vector_sum_row != nullptr ? vector_sum_row->info() : nullptr,
bias != nullptr ? bias->info() : nullptr,
output->info(),
- a_offset, b_offset, output_stage)); // NOLINT
+ a_offset, b_offset, output_stage,
+ output_multipliers->info(), output_shifts->info())); // NOLINT
const int min = output_stage.gemmlowp_min_bound;
const int max = output_stage.gemmlowp_max_bound;
- _vector_sum_col = vector_sum_col;
- _vector_sum_row = vector_sum_row;
- _mm_result = mm_result;
- _bias = bias;
- _output = output;
+ _vector_sum_col = vector_sum_col;
+ _vector_sum_row = vector_sum_row;
+ _mm_result = mm_result;
+ _bias = bias;
+ _output = output;
+ _output_multipliers = output_multipliers;
+ _output_shifts = output_shifts;
+ _is_quantized_per_channel = output_stage.is_quantized_per_channel;
// Check if input is a 3D reinterpretation
const bool reinterpret_as_3d = vector_sum_row != nullptr
@@ -199,8 +226,9 @@ void CLGEMMLowpOffsetContributionOutputStageKernel::configure(const ICLTensor *m
build_opts.add_option_if(reinterpret_as_3d, "-DDEPTH_INPUT3D=" + support::cpp11::to_string(mm_result->info()->dimension(2)));
build_opts.add_option_if(bias != nullptr, "-DADD_BIAS");
build_opts.add_option("-DRESULT_OFFSET=" + support::cpp11::to_string(output_stage.gemmlowp_offset));
- build_opts.add_option("-DRESULT_MULTIPLIER=" + support::cpp11::to_string(output_stage.gemmlowp_multiplier));
- build_opts.add_option("-DRESULT_SHIFT=" + support::cpp11::to_string(output_stage.gemmlowp_shift));
+ build_opts.add_option("-DRESULT_MULTIPLIER=" + support::cpp11::to_string(output_stage.gemmlowp_multipliers[0]));
+ build_opts.add_option("-DRESULT_SHIFT=" + support::cpp11::to_string(output_stage.gemmlowp_shifts[0]));
+ build_opts.add_option_if(_is_quantized_per_channel, "-DPER_CHANNEL_QUANTIZATION");
build_opts.add_option_if((min != 0) && (min != max), "-DMIN_BOUND=" + support::cpp11::to_string(min));
build_opts.add_option_if((max != 255) && (min != max), "-DMAX_BOUND=" + support::cpp11::to_string(max));
@@ -225,7 +253,8 @@ void CLGEMMLowpOffsetContributionOutputStageKernel::configure(const ICLTensor *m
vector_sum_row != nullptr ? vector_sum_row->info() : nullptr,
bias != nullptr ? bias->info() : nullptr,
output->info(),
- a_offset, b_offset); // NOLINT
+ a_offset, b_offset,
+ output_multipliers->info(), output_shifts->info()); // NOLINT
ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
ICLKernel::configure_internal(win_config.second);
@@ -239,16 +268,17 @@ void CLGEMMLowpOffsetContributionOutputStageKernel::configure(const ICLTensor *m
}
Status CLGEMMLowpOffsetContributionOutputStageKernel::validate(const ITensorInfo *mm_result, const ITensorInfo *vector_sum_col, const ITensorInfo *vector_sum_row, const ITensorInfo *bias,
- const ITensorInfo *output,
- int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage)
+ const ITensorInfo *output, int32_t a_offset, int32_t b_offset, const GEMMLowpOutputStageInfo &output_stage,
+ const ITensorInfo *output_multipliers, const ITensorInfo *output_shifts)
{
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(mm_result, vector_sum_col, vector_sum_row, bias, output, a_offset, b_offset, output_stage));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(mm_result, vector_sum_col, vector_sum_row, bias, output, a_offset, b_offset, output_stage, output_multipliers, output_shifts));
ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(mm_result->clone().get(),
vector_sum_col != nullptr ? vector_sum_col->clone().get() : nullptr,
vector_sum_row != nullptr ? vector_sum_row->clone().get() : nullptr,
bias != nullptr ? bias->clone().get() : nullptr,
output->clone().get(),
- a_offset, b_offset)
+ a_offset, b_offset,
+ output_multipliers->clone().get(), output_shifts->clone().get())
.first); // NOLINT
return Status{};
@@ -285,7 +315,10 @@ void CLGEMMLowpOffsetContributionOutputStageKernel::run(const Window &window, cl
add_2D_tensor_argument_if((_vector_sum_row != nullptr), idx, _vector_sum_row, win_vector_sum_row);
add_1D_tensor_argument_if((_bias != nullptr), idx, _bias, biases_slice);
add_3D_tensor_argument(idx, _output, slice);
+ add_1D_tensor_argument_if(_is_quantized_per_channel, idx, _output_multipliers, biases_slice);
+ add_1D_tensor_argument_if(_is_quantized_per_channel, idx, _output_shifts, biases_slice);
enqueue(queue, *this, slice, lws_hint());
}
while(collapsed.slide_window_slice_3D(slice));
}
+} // namespace arm_compute