aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h10
-rw-r--r--arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h10
-rw-r--r--arm_compute/runtime/NEON/functions/NEArithmeticAddition.h10
-rw-r--r--arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h10
-rw-r--r--src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp91
-rw-r--r--src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp93
-rw-r--r--src/runtime/NEON/functions/NEArithmeticAddition.cpp4
-rw-r--r--src/runtime/NEON/functions/NEArithmeticSubtraction.cpp4
-rw-r--r--tests/validation/NEON/ArithmeticAddition.cpp35
-rw-r--r--tests/validation/NEON/ArithmeticSubtraction.cpp35
10 files changed, 241 insertions, 61 deletions
diff --git a/arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h b/arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h
index edb7381635..044cf6846b 100644
--- a/arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h
+++ b/arm_compute/core/NEON/kernels/NEArithmeticAdditionKernel.h
@@ -68,6 +68,16 @@ public:
* @param[in] policy Overflow policy.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEArithmeticAdditionKernel
+ *
+ * @param[in] input1 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 An input tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] output The output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32.
+ * @param[in] policy Overflow policy.
+ *
+ * @return an error status
+ */
+ static Error validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
diff --git a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
index d6a219ffde..663f62864d 100644
--- a/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
+++ b/arm_compute/core/NEON/kernels/NEArithmeticSubtractionKernel.h
@@ -68,6 +68,16 @@ public:
* @param[in] policy Overflow policy.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEArithmeticSubtractionKernel
+ *
+ * @param[in] input1 First tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] output Output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] policy Policy to use to handle overflow.
+ *
+ * @return an error status
+ */
+ static Error validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy);
// Inherited methods overridden:
void run(const Window &window, const ThreadInfo &info) override;
diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
index 3d1862389a..866cb5d2c7 100644
--- a/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
+++ b/arm_compute/runtime/NEON/functions/NEArithmeticAddition.h
@@ -43,6 +43,16 @@ public:
* @param[in] policy Policy to use to handle overflow.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEArithmeticAddition
+ *
+ * @param[in] input1 First tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] output Output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] policy Policy to use to handle overflow.
+ *
+ * @return an error status
+ */
+ static Error validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy);
};
}
#endif /*__ARM_COMPUTE_NEARITHMETICADDITION_H__ */
diff --git a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
index b59cca98ab..2231e43bbf 100644
--- a/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
+++ b/arm_compute/runtime/NEON/functions/NEArithmeticSubtraction.h
@@ -43,6 +43,16 @@ public:
* @param[in] policy Policy to use to handle overflow.
*/
void configure(const ITensor *input1, const ITensor *input2, ITensor *output, ConvertPolicy policy);
+ /** Static function to check if given info will lead to a valid configuration of @ref NEArithmeticSubtraction
+ *
+ * @param[in] input1 First tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] input2 Second tensor input. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] output Output tensor. Data types supported: U8/QS8/QS16/S16/F16/F32
+ * @param[in] policy Policy to use to handle overflow.
+ *
+ * @return an error status
+ */
+ static Error validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy);
};
}
#endif /* __ARM_COMPUTE_NEARITHMETICSUBTRACTION_H__ */
diff --git a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
index 8e55994aaa..6452393ca0 100644
--- a/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticAdditionKernel.cpp
@@ -355,6 +355,57 @@ void add_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out
},
input1, input2, output);
}
+
+inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+
+ if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
+ {
+ // Check that all data types are the same and all fixed-point positions are the same
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
+ }
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
+ && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
+ "You called addition with the wrong image formats");
+
+ return Error{};
+}
+
+inline std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+ constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+ bool window_changed = update_window_and_padding(win,
+ AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
+ AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
+ output_access);
+
+ ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
+ input2->valid_region());
+
+ output_access.set_valid_region(win, valid_region);
+
+ Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{};
+ return std::make_pair(err, win);
+}
} // namespace
NEArithmeticAdditionKernel::NEArithmeticAdditionKernel()
@@ -384,17 +435,7 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
}
}
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
- if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- }
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy));
static std::map<std::string, AddFunction *> map_function =
{
@@ -416,7 +457,6 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
{ "add_saturate_F32_F32_F32", &add_F32_F32_F32 },
{ "add_wrap_F16_F16_F16", &add_F16_F16_F16 },
{ "add_saturate_F16_F16_F16", &add_F16_F16_F16 },
-
};
_input1 = input1;
@@ -435,28 +475,19 @@ void NEArithmeticAdditionKernel::configure(const ITensor *input1, const ITensor
{
_func = it->second;
}
- else
- {
- ARM_COMPUTE_ERROR("You called arithmetic addition with the wrong tensor data type");
- }
-
- constexpr unsigned int num_elems_processed_per_iteration = 16;
// Configure kernel window
- Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
- update_window_and_padding(win,
- AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
- AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
- output_access);
-
- ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
- input2->info()->valid_region());
+ auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
+}
- output_access.set_valid_region(win, valid_region);
+Error NEArithmeticAdditionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
- INEKernel::configure(win);
+ return Error{};
}
void NEArithmeticAdditionKernel::run(const Window &window, const ThreadInfo &info)
diff --git a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
index 1d86a35cc4..619669ae35 100644
--- a/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
+++ b/src/core/NEON/kernels/NEArithmeticSubtractionKernel.cpp
@@ -348,6 +348,57 @@ void sub_saturate_U8_U8_S16(const ITensor *in1, const ITensor *in2, ITensor *out
},
input1, input2, output);
}
+
+inline Error validate_arguments(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ ARM_COMPUTE_UNUSED(policy);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
+
+ if(is_data_type_fixed_point(input1->data_type()) || is_data_type_fixed_point(input2->data_type()) || is_data_type_fixed_point(output->data_type()))
+ {
+ // Check that all data types are the same and all fixed-point positions are the same
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
+ }
+
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(
+ !(input1->data_type() == DataType::QS8 && input2->data_type() == DataType::QS8 && output->data_type() == DataType::QS8)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::U8)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::U8 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::U8 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::QS16 && input2->data_type() == DataType::QS16 && output->data_type() == DataType::QS16)
+ && !(input1->data_type() == DataType::S16 && input2->data_type() == DataType::S16 && output->data_type() == DataType::S16)
+ && !(input1->data_type() == DataType::F32 && input2->data_type() == DataType::F32 && output->data_type() == DataType::F32)
+ && !(input1->data_type() == DataType::F16 && input2->data_type() == DataType::F16 && output->data_type() == DataType::F16),
+ "You called subtract with the wrong image formats");
+
+ return Error{};
+}
+
+inline std::pair<Error, Window> validate_and_configure_window(ITensorInfo *input1, ITensorInfo *input2, ITensorInfo *output)
+{
+ constexpr unsigned int num_elems_processed_per_iteration = 16;
+
+ // Configure kernel window
+ Window win = calculate_max_window(*input1, Steps(num_elems_processed_per_iteration));
+ AccessWindowHorizontal output_access(output, 0, num_elems_processed_per_iteration);
+
+ bool window_changed = update_window_and_padding(win,
+ AccessWindowHorizontal(input1, 0, num_elems_processed_per_iteration),
+ AccessWindowHorizontal(input2, 0, num_elems_processed_per_iteration),
+ output_access);
+
+ ValidRegion valid_region = intersect_valid_regions(input1->valid_region(),
+ input2->valid_region());
+
+ output_access.set_valid_region(win, valid_region);
+
+ Error err = (window_changed) ? ARM_COMPUTE_CREATE_ERROR(ErrorCode::RUNTIME_ERROR, "Insufficient Padding!") : Error{};
+ return std::make_pair(err, win);
+}
} // namespace
NEArithmeticSubtractionKernel::NEArithmeticSubtractionKernel()
@@ -377,19 +428,9 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
}
}
- ARM_COMPUTE_ERROR_ON_MISMATCHING_SHAPES(input1, input2, output);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input1, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(input2, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(output, 1, DataType::QS8, DataType::U8, DataType::QS16, DataType::S16, DataType::F16, DataType::F32);
- ARM_COMPUTE_ERROR_ON_MSG(output->info()->data_type() == DataType::U8 && (input1->info()->data_type() != DataType::U8 || input2->info()->data_type() != DataType::U8),
- "Output can only be U8 if both inputs are U8");
- if(is_data_type_fixed_point(input1->info()->data_type()) || is_data_type_fixed_point(input2->info()->data_type()) || is_data_type_fixed_point(output->info()->data_type()))
- {
- // Check that all data types are the same and all fixed-point positions are the same
- ARM_COMPUTE_ERROR_ON_MISMATCHING_FIXED_POINT(input1, input2, output);
- }
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(input1->info(), input2->info(), output->info(), policy));
- static std::map<std::string, SubFunction *> map_function =
+ static std::map<std::string, NEArithmeticSubtractionKernel::SubFunction *> map_function =
{
{ "sub_wrap_QS8_QS8_QS8", &sub_wrap_QS8_QS8_QS8 },
{ "sub_saturate_QS8_QS8_QS8", &sub_saturate_QS8_QS8_QS8 },
@@ -409,7 +450,6 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
{ "sub_saturate_F32_F32_F32", &sub_F32_F32_F32 },
{ "sub_wrap_F16_F16_F16", &sub_F16_F16_F16 },
{ "sub_saturate_F16_F16_F16", &sub_F16_F16_F16 },
-
};
_input1 = input1;
@@ -428,28 +468,19 @@ void NEArithmeticSubtractionKernel::configure(const ITensor *input1, const ITens
{
_func = it->second;
}
- else
- {
- ARM_COMPUTE_ERROR("You called subtract with the wrong image formats");
- }
-
- constexpr unsigned int num_elems_processed_per_iteration = 16;
// Configure kernel window
- Window win = calculate_max_window(*input1->info(), Steps(num_elems_processed_per_iteration));
- AccessWindowHorizontal output_access(output->info(), 0, num_elems_processed_per_iteration);
-
- update_window_and_padding(win,
- AccessWindowHorizontal(input1->info(), 0, num_elems_processed_per_iteration),
- AccessWindowHorizontal(input2->info(), 0, num_elems_processed_per_iteration),
- output_access);
-
- ValidRegion valid_region = intersect_valid_regions(input1->info()->valid_region(),
- input2->info()->valid_region());
+ auto win_config = validate_and_configure_window(input1->info(), input2->info(), output->info());
+ ARM_COMPUTE_ERROR_THROW_ON(win_config.first);
+ INEKernel::configure(win_config.second);
+}
- output_access.set_valid_region(win, valid_region);
+Error NEArithmeticSubtractionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(input1, input2, output, policy));
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_and_configure_window(input1->clone().get(), input2->clone().get(), output->clone().get()).first);
- INEKernel::configure(win);
+ return Error{};
}
void NEArithmeticSubtractionKernel::run(const Window &window, const ThreadInfo &info)
diff --git a/src/runtime/NEON/functions/NEArithmeticAddition.cpp b/src/runtime/NEON/functions/NEArithmeticAddition.cpp
index 11f5aa74e4..85119ea17d 100644
--- a/src/runtime/NEON/functions/NEArithmeticAddition.cpp
+++ b/src/runtime/NEON/functions/NEArithmeticAddition.cpp
@@ -36,3 +36,7 @@ void NEArithmeticAddition::configure(const ITensor *input1, const ITensor *input
k->configure(input1, input2, output, policy);
_kernel = std::move(k);
}
+Error NEArithmeticAddition::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ return NEArithmeticAdditionKernel::validate(input1, input2, output, policy);
+}
diff --git a/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp b/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp
index 37586af751..be264d54b4 100644
--- a/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp
+++ b/src/runtime/NEON/functions/NEArithmeticSubtraction.cpp
@@ -36,3 +36,7 @@ void NEArithmeticSubtraction::configure(const ITensor *input1, const ITensor *in
k->configure(input1, input2, output, policy);
_kernel = std::move(k);
}
+Error NEArithmeticSubtraction::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output, ConvertPolicy policy)
+{
+ return NEArithmeticSubtractionKernel::validate(input1, input2, output, policy);
+}
diff --git a/tests/validation/NEON/ArithmeticAddition.cpp b/tests/validation/NEON/ArithmeticAddition.cpp
index 4431371326..21a8c4b79f 100644
--- a/tests/validation/NEON/ArithmeticAddition.cpp
+++ b/tests/validation/NEON/ArithmeticAddition.cpp
@@ -66,6 +66,41 @@ TEST_SUITE(ArithmeticAddition)
template <typename T>
using NEArithmeticAdditionFixture = ArithmeticAdditionValidationFixture<Tensor, Accessor, NEArithmeticAddition, T>;
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ }),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ })),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ })),
+ framework::dataset::make("Expected", { false, false, true, true, true, true, false })),
+ input1_info, input2_info, output_info, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(NEArithmeticAddition::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
TEST_SUITE(U8)
DATA_TEST_CASE(Configuration, framework::DatasetMode::ALL, combine(framework::dataset::concat(datasets::SmallShapes(), datasets::LargeShapes()), framework::dataset::make("ConvertPolicy", { ConvertPolicy::SATURATE, ConvertPolicy::WRAP })),
shape, policy)
diff --git a/tests/validation/NEON/ArithmeticSubtraction.cpp b/tests/validation/NEON/ArithmeticSubtraction.cpp
index 0c2a7be60b..1a31defb46 100644
--- a/tests/validation/NEON/ArithmeticSubtraction.cpp
+++ b/tests/validation/NEON/ArithmeticSubtraction.cpp
@@ -70,6 +70,41 @@ const auto ArithmeticSubtractionFP32Dataset = combine(combine(framework::dataset
TEST_SUITE(NEON)
TEST_SUITE(ArithmeticSubtraction)
+// *INDENT-OFF*
+// clang-format off
+DATA_TEST_CASE(Validate, framework::DatasetMode::ALL, zip(zip(zip(
+ framework::dataset::make("Input1Info", { TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8), // Window shrink
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8), // Invalid data type combination
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::F32), // Mismatching shapes
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2), // Mismatching fixed point
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ }),
+ framework::dataset::make("Input2Info",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ })),
+ framework::dataset::make("OutputInfo",{ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::S16),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(27U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::U8),
+ TensorInfo(TensorShape(48U, 11U, 2U), 1, DataType::F32),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 3),
+ TensorInfo(TensorShape(32U, 13U, 2U), 1, DataType::QS8, 2),
+ })),
+ framework::dataset::make("Expected", { false, false, true, true, true, true, false })),
+ input1_info, input2_info, output_info, expected)
+{
+ ARM_COMPUTE_EXPECT(bool(NEArithmeticSubtraction::validate(&input1_info.clone()->set_is_resizable(false), &input2_info.clone()->set_is_resizable(false), &output_info.clone()->set_is_resizable(false), ConvertPolicy::WRAP)) == expected, framework::LogLevel::ERRORS);
+}
+// clang-format on
+// *INDENT-ON*
+
template <typename T1, typename T2 = T1, typename T3 = T1>
using NEArithmeticSubtractionFixture = ArithmeticSubtractionValidationFixture<Tensor, Accessor, NEArithmeticSubtraction, T1, T2, T3>;