aboutsummaryrefslogtreecommitdiff
path: root/src/core/cpu/kernels/CpuElementwiseKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/core/cpu/kernels/CpuElementwiseKernel.cpp')
-rw-r--r--src/core/cpu/kernels/CpuElementwiseKernel.cpp164
1 files changed, 81 insertions, 83 deletions
diff --git a/src/core/cpu/kernels/CpuElementwiseKernel.cpp b/src/core/cpu/kernels/CpuElementwiseKernel.cpp
index 1ac21acbc0..23e95f72d7 100644
--- a/src/core/cpu/kernels/CpuElementwiseKernel.cpp
+++ b/src/core/cpu/kernels/CpuElementwiseKernel.cpp
@@ -72,9 +72,9 @@ static ElementwiseKernel generate_kernel(UKernelType *ukernel)
template <ArithmeticOperation op>
std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
-configure_arithm_func(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+configure_arithm_func(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_UNUSED(input2, output);
+ ARM_COMPUTE_UNUSED(src1, dst);
static ElementwiseKernel kernels[] =
{
#if defined(__ARM_FEATURE_SVE)
@@ -103,7 +103,7 @@ configure_arithm_func(const ITensorInfo *input1, const ITensorInfo *input2, ITen
for(const auto &uk : kernels)
{
- if(uk.is_selected(input1->data_type()))
+ if(uk.is_selected(src0->data_type()))
{
return uk.ukernel;
}
@@ -113,10 +113,10 @@ configure_arithm_func(const ITensorInfo *input1, const ITensorInfo *input2, ITen
}
template <ComparisonOperation op>
-std::function<void(const ITensor *input1, const ITensor *input2, ITensor *output, const Window &window)>
-configure_comp_func(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+std::function<void(const ITensor *, const ITensor *, ITensor *, const Window &)>
+configure_comp_func(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_UNUSED(input2, output);
+ ARM_COMPUTE_UNUSED(src1, dst);
static ElementwiseKernel kernels[] =
{
#if defined(__ARM_FEATURE_SVE)
@@ -148,7 +148,7 @@ configure_comp_func(const ITensorInfo *input1, const ITensorInfo *input2, ITenso
for(const auto &uk : kernels)
{
- if(uk.is_selected(input1->data_type()))
+ if(uk.is_selected(src0->data_type()))
{
return uk.ukernel;
}
@@ -158,45 +158,43 @@ configure_comp_func(const ITensorInfo *input1, const ITensorInfo *input2, ITenso
}
} // namespace
-Status CpuElementwiseKernel::validate_arguments_common(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+Status CpuElementwiseKernel::validate_arguments_common(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&input1);
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &input2);
+ ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(&src0);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &src1);
- const TensorShape out_shape = TensorShape::broadcast_shape(input1.tensor_shape(), input2.tensor_shape());
+ const TensorShape out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
ARM_COMPUTE_RETURN_ERROR_ON_MSG(out_shape.total_size() == 0, "Inputs are not broadcast compatible");
- // Validate in case of configured output
- if(output.total_size() > 0)
+ // Validate in case of configured dst
+ if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, output.tensor_shape(), 0),
+ ARM_COMPUTE_RETURN_ERROR_ON_MSG(detail::have_different_dimensions(out_shape, dst.tensor_shape(), 0),
"Wrong shape for output");
}
return Status{};
}
-void CpuElementwiseKernel::configure_common(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+void CpuElementwiseKernel::configure_common(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_ERROR_ON_NULLPTR(input1, input2, output);
-
- // Configure kernel window
- const TensorShape &out_shape = TensorShape::broadcast_shape(input1->tensor_shape(), input2->tensor_shape());
-
- // Auto initialize output if not initialized
- auto_init_if_empty(*output, out_shape, 1, input1->data_type());
+ ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
- Window win = calculate_max_window(out_shape);
+ // If any of shapes is dynamic, expect a configured window and dst at run-time.
+ if(src0->is_dynamic() || src1->is_dynamic())
+ {
+ return;
+ }
- ICpuKernel::configure(win);
+ auto shape_and_window = compute_output_shape_and_window(*src0, *src1);
+ auto_init_if_empty(*dst, shape_and_window.first, 1, src0->data_type());
+ ICpuKernel::configure(shape_and_window.second);
}
void CpuElementwiseKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
- ARM_COMPUTE_UNUSED(info, window);
- ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
- ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
+ ARM_COMPUTE_UNUSED(info);
auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
@@ -208,49 +206,49 @@ void CpuElementwiseKernel::run_op(ITensorPack &tensors, const Window &window, co
}
/** Arithmetic operators (min, max, squared_diff) */
-void CpuArithmeticKernel::configure(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+void CpuArithmeticKernel::configure(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
- configure_common(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
+ configure_common(src0, src1, dst);
_op = op;
}
-Status CpuArithmeticKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+Status CpuArithmeticKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
- // Validate in case of configured output
- if(output.total_size() > 0)
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
+ // Validate in case of configured dst
+ if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&input1, &output);
+ ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(&src0, &dst);
}
- return validate_arguments_common(input1, input2, output);
+ return validate_arguments_common(src0, src1, dst);
}
-Status CpuArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+Status CpuArithmeticKernel::validate(ArithmeticOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
ARM_COMPUTE_UNUSED(op);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
return Status{};
}
std::function<CpuElementwiseKernel::ElementwiseFunction>
-CpuArithmeticKernel::get_implementation(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+CpuArithmeticKernel::get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
switch(_op)
{
case ArithmeticOperation::MAX:
- return configure_arithm_func<ArithmeticOperation::MAX>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::MAX>(src0, src1, dst);
case ArithmeticOperation::MIN:
- return configure_arithm_func<ArithmeticOperation::MIN>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::MIN>(src0, src1, dst);
case ArithmeticOperation::SQUARED_DIFF:
- return configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::SQUARED_DIFF>(src0, src1, dst);
case ArithmeticOperation::PRELU:
- return configure_arithm_func<ArithmeticOperation::PRELU>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::PRELU>(src0, src1, dst);
case ArithmeticOperation::DIV:
- return configure_arithm_func<ArithmeticOperation::DIV>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::DIV>(src0, src1, dst);
case ArithmeticOperation::POWER:
- return configure_arithm_func<ArithmeticOperation::POWER>(input1, input2, output);
+ return configure_arithm_func<ArithmeticOperation::POWER>(src0, src1, dst);
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}
@@ -259,91 +257,91 @@ CpuArithmeticKernel::get_implementation(const ITensorInfo *input1, const ITensor
/** The division operator */
-void CpuDivisionKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+void CpuDivisionKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
- configure_common(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
+ configure_common(src0, src1, dst);
_op = ArithmeticOperation::DIV;
}
-Status CpuDivisionKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+Status CpuDivisionKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::S32, DataType::F16, DataType::F32);
- return CpuArithmeticKernel::validate_arguments(input1, input2, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::S32, DataType::F16, DataType::F32);
+ return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
}
-Status CpuDivisionKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+Status CpuDivisionKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
return Status{};
}
/** The power operator */
-void CpuPowerKernel::configure(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+void CpuPowerKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
- configure_common(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
+ configure_common(src0, src1, dst);
_op = ArithmeticOperation::POWER;
}
-Status CpuPowerKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+Status CpuPowerKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::F16, DataType::F32);
- return CpuArithmeticKernel::validate_arguments(input1, input2, output);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::F16, DataType::F32);
+ return CpuArithmeticKernel::validate_arguments(src0, src1, dst);
}
-Status CpuPowerKernel::validate(const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+Status CpuPowerKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
return Status{};
}
/** Comparison operators (equal, not equal, less than, greater than, less than or equal, greater than or equal) */
-void CpuComparisonKernel::configure(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+void CpuComparisonKernel::configure(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
- ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*input1, *input2, *output));
- configure_common(input1, input2, output);
+ ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst));
+ configure_common(src0, src1, dst);
_op = op;
}
-Status CpuComparisonKernel::validate_arguments(const ITensorInfo &input1, const ITensorInfo &input2, const ITensorInfo &output)
+Status CpuComparisonKernel::validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, const ITensorInfo &dst)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&input1, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
- // Validate in case of configured output
- if(output.total_size() > 0)
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&src0, 1, DataType::U8, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S16, DataType::F16, DataType::S32, DataType::F32);
+ // Validate in case of configured dst
+ if(dst.total_size() > 0)
{
- ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&output, 1, DataType::U8);
+ ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(&dst, 1, DataType::U8);
}
- return validate_arguments_common(input1, input2, output);
+ return validate_arguments_common(src0, src1, dst);
}
-Status CpuComparisonKernel::validate(ComparisonOperation op, const ITensorInfo *input1, const ITensorInfo *input2, const ITensorInfo *output)
+Status CpuComparisonKernel::validate(ComparisonOperation op, const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
ARM_COMPUTE_UNUSED(op);
- ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(input1, input2, output);
- ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*input1, *input2, *output));
+ ARM_COMPUTE_RETURN_ERROR_ON_NULLPTR(src0, src1, dst);
+ ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(*src0, *src1, *dst));
return Status{};
}
std::function<CpuElementwiseKernel::ElementwiseFunction>
-CpuComparisonKernel::get_implementation(const ITensorInfo *input1, const ITensorInfo *input2, ITensorInfo *output)
+CpuComparisonKernel::get_implementation(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
switch(_op)
{
case ComparisonOperation::Equal:
- return configure_comp_func<ComparisonOperation::Equal>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::Equal>(src0, src1, dst);
case ComparisonOperation::NotEqual:
- return configure_comp_func<ComparisonOperation::NotEqual>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::NotEqual>(src0, src1, dst);
case ComparisonOperation::Greater:
- return configure_comp_func<ComparisonOperation::Greater>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::Greater>(src0, src1, dst);
case ComparisonOperation::GreaterEqual:
- return configure_comp_func<ComparisonOperation::GreaterEqual>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::GreaterEqual>(src0, src1, dst);
case ComparisonOperation::Less:
- return configure_comp_func<ComparisonOperation::Less>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::Less>(src0, src1, dst);
case ComparisonOperation::LessEqual:
- return configure_comp_func<ComparisonOperation::LessEqual>(input1, input2, output);
+ return configure_comp_func<ComparisonOperation::LessEqual>(src0, src1, dst);
default:
ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
}