diff options
Diffstat (limited to 'src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp')
-rw-r--r-- | src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp | 68 |
1 files changed, 51 insertions, 17 deletions
diff --git a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp index 1499df0bec..3fe956d759 100644 --- a/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp +++ b/src/core/CL/kernels/CLGEMMMatrixAdditionKernel.cpp @@ -1,5 +1,5 @@ /* - * Copyright (c) 2017 ARM Limited. + * Copyright (c) 2017-2018 ARM Limited. * * SPDX-License-Identifier: MIT * @@ -36,6 +36,42 @@ using namespace arm_compute; +namespace +{ +std::pair<Status, Window> validate_and_configure_window(ITensorInfo *input, ITensorInfo *output) +{ + const unsigned int num_elems_processed_per_iteration = max_cl_vector_width / data_size_from_type(input->data_type()); + // Configure kernel window + Window win = calculate_max_window(*input, Steps(num_elems_processed_per_iteration)); + + AccessWindowHorizontal input_access(input, 0, num_elems_processed_per_iteration); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, input_access, output_access); + + output_access.set_valid_region(win, input->valid_region()); + + Status err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Status{}; + return std::make_pair(err, win); +} +} // namespace + +namespace +{ +Status validate_arguments(const ITensorInfo *input, const ITensorInfo *output, const float beta) +{ + ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input, output); + + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(0) != output->dimension(0)); + ARM_COMPUTE_RETURN_ERROR_ON(input->dimension(1) != output->dimension(1)); + + ARM_COMPUTE_UNUSED(beta); + return Status{}; +} +} // namespace + CLGEMMMatrixAdditionKernel::CLGEMMMatrixAdditionKernel() : _input(nullptr), _output(nullptr) { @@ -43,14 +79,13 @@ CLGEMMMatrixAdditionKernel::CLGEMMMatrixAdditionKernel() void CLGEMMMatrixAdditionKernel::configure(const ICLTensor *input, ICLTensor *output, float beta) { - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input, 1, DataType::QS8, DataType::QS16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MISMATCHING_DATA_TYPES(input, output); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(0) != output->info()->dimension(0)); - ARM_COMPUTE_ERROR_ON(input->info()->dimension(1) != output->info()->dimension(1)); + ARM_COMPUTE_ERROR_ON_NULLPTR(input, output); + + // Perform validation step + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input->info(), output->info(), beta)); - _input = input; - _output = output; - const unsigned int num_elems_processed_per_iteration = max_cl_vector_width / data_size_from_type(input->info()->data_type()); + _input = input; + _output = output; std::ostringstream ma_arguments; if(is_data_type_fixed_point(input->info()->data_type())) @@ -74,16 +109,15 @@ void CLGEMMMatrixAdditionKernel::configure(const ICLTensor *input, ICLTensor *ou _kernel = static_cast<cl::Kernel>(CLKernelLibrary::get().create_kernel(("gemm_ma_" + data_type_name), build_opts)); // Configure kernel window - Window win = calculate_max_window(*_input->info(), Steps(num_elems_processed_per_iteration)); - - AccessWindowHorizontal input_access(input->info(), 0, num_elems_processed_per_iteration); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, input_access, output_access); - - output_access.set_valid_region(win, input->info()->valid_region()); + auto win_config = validate_and_configure_window(input->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + ICLKernel::configure(win_config.second); +} - ICLKernel::configure(win); +Status CLGEMMMatrixAdditionKernel::validate(const ITensorInfo *input, const ITensorInfo *output, const float beta) +{ + ARM_COMPUTE_RETURN_ERROR_ON(validate_arguments(input, output, beta)); + return Status{}; } void CLGEMMMatrixAdditionKernel::run(const Window &window, cl::CommandQueue &queue) |