From 397d58aa40b02a26923c34d8cd4ba274eac45963 Mon Sep 17 00:00:00 2001 From: Ioan-Cristian Szabo Date: Thu, 30 Nov 2017 15:19:11 +0000 Subject: COMPMID-617: Add validate support for NEON ArithmeticLayer Change-Id: I8b58359487194f4cbf7452df4aea92523b5745bf Reviewed-on: https://eu-gerrit-1.euhpc.arm.com/111351 Tested-by: BSG Visual Compute Jenkins server to access repositories on http://mpd-gerrit.cambridge.arm.com Reviewed-by: Michalis Spyrou Reviewed-by: Anthony Barbier --- .../NEON/kernels/NEArithmeticAdditionKernel.cpp | 91 ++++++++++++++------- .../NEON/kernels/NEArithmeticSubtractionKernel.cpp | 93 ++++++++++++++-------- .../NEON/functions/NEArithmeticAddition.cpp | 4 + .../NEON/functions/NEArithmeticSubtraction.cpp | 4 + 4 files changed, 131 insertions(+), 61 deletions(-) (limited to 'src') diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp index 8e55994aaa..6452393ca0 100644 --- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp @@ -355,6 +355,57 @@ void add_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out }, input1, input2, output); } + +inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + + if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type())) + { + // Check that all data types are the same and all fixed-point positions are the same + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32) + && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16), + "You called addition with the wrong image formats"); + + return Error{}; +} + +inline std::pair validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + // Configure kernel window + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, + AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration), + AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration), + output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} } // namespace NEArithmeticAdditionKernel::NEArithmeticAdditionKernel() @@ -384,17 +435,7 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor } } - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type())) - { - // Check that all data types are the same and all fixed-point positions are the same - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); - } + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy)); static std::map map_function = { @@ -416,7 +457,6 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor { "add_saturate_F32_F32_F32", &add_F32_F32_F32 }, { "add_wrap_F16_F16_F16", &add_F16_F16_F16 }, { "add_saturate_F16_F16_F16", &add_F16_F16_F16 }, - }; _input1 = input1; @@ -435,28 +475,19 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor { _func = it->second; } - else - { - ARM_COMPUTE_ERROR("You called arithmetic addition with the wrong tensor data type"); - } - - constexpr unsigned int num_elems_processed_per_iteration = 16; // Configure kernel window - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, - AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration), - AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration), - output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + INEKernel::configure(win_config.second); +} - output_access.set_valid_region(win, valid_region); +Error NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); - INEKernel::configure(win); + return Error{}; } void NEArithmeticAdditionKernel::run(const Window &window, const ThreadInfo &info) diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp index 1d86a35cc4..619669ae35 100644 --- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp +++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp @@ -348,6 +348,57 @@ void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out }, input1, input2, output); } + +inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_UNUSED(policy); + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); + + if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type())) + { + // Check that all data types are the same and all fixed-point positions are the same + ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); + } + + ARM_COMPUTE_RETURN_ERROR_ON_MSG( + !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16) + && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16) + && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32) + && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16), + "You called subtract with the wrong image formats"); + + return Error{}; +} + +inline std::pair validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output) +{ + constexpr unsigned int num_elems_processed_per_iteration = 16; + + // Configure kernel window + Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration)); + AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration); + + bool window_changed = update_window_and_padding(win, + AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration), + AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration), + output_access); + + ValidRegion valid_region = intersect_valid_regions(input1->valid_region(), + input2->valid_region()); + + output_access.set_valid_region(win, valid_region); + + Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{}; + return std::make_pair(err, win); +} } // namespace NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel() @@ -377,19 +428,9 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens } } - ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32); - ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8), - "Output can only be U8 if both inputs are U8"); - if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type())) - { - // Check that all data types are the same and all fixed-point positions are the same - ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output); - } + ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy)); - static std::map map_function = + static std::map map_function = { { "sub_wrap_QS8_QS8_QS8", &sub_wrap_QS8_QS8_QS8 }, { "sub_saturate_QS8_QS8_QS8", &sub_saturate_QS8_QS8_QS8 }, @@ -409,7 +450,6 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens { "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 }, { "sub_wrap_F16_F16_F16", &sub_F16_F16_F16 }, { "sub_saturate_F16_F16_F16", &sub_F16_F16_F16 }, - }; _input1 = input1; @@ -428,28 +468,19 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens { _func = it->second; } - else - { - ARM_COMPUTE_ERROR("You called subtract with the wrong image formats"); - } - - constexpr unsigned int num_elems_processed_per_iteration = 16; // Configure kernel window - Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration)); - AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration); - - update_window_and_padding(win, - AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration), - AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration), - output_access); - - ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(), - input2->info()->valid_region()); + auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info()); + ARM_COMPUTE_ERROR_THROW_ON(win_config.first); + INEKernel::configure(win_config.second); +} - output_access.set_valid_region(win, valid_region); +Error NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy)); + ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first); - INEKernel::configure(win); + return Error{}; } void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &info) diff --git a/src/runtime/NEON/functions/NEArithmeticAddition.cpp b/src/runtime/NEON/functions/NEArithmeticAddition.cpp index 11f5aa74e4..85119ea17d 100644 --- a/src/runtime/NEON/functions/NEArithmeticAddition.cpp +++ b/src/runtime/NEON/functions/NEArithmeticAddition.cpp @@ -36,3 +36,7 @@ void NEArithmeticAddition::configure(const ITensor *input1, const ITensor *input k->configure(input1, input2, output, policy); _kernel = std::move(k); } +Error NEArithmeticAddition::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + return NEArithmeticAdditionKernel::validate(input1, input2, output, policy); +} diff --git a/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp b/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp index 37586af751..be264d54b4 100644 --- a/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp +++ b/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp @@ -36,3 +36,7 @@ void NEArithmeticSubtraction::configure(const ITensor *input1, const ITensor *in k->configure(input1, input2, output, policy); _kernel = std::move(k); } +Error NEArithmeticSubtraction::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy) +{ + return NEArithmeticSubtractionKernel::validate(input1, input2, output, policy); +} -- cgit v1.2.1