aboutsummaryrefslogtreecommitdiff
path: root/src/cpu/kernels/CpuAddKernel.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/cpu/kernels/CpuAddKernel.cpp')
-rw-r--r--src/cpu/kernels/CpuAddKernel.cpp123
1 files changed, 90 insertions, 33 deletions
diff --git a/src/cpu/kernels/CpuAddKernel.cpp b/src/cpu/kernels/CpuAddKernel.cpp
index e756effea9..85ae410a94 100644
--- a/src/cpu/kernels/CpuAddKernel.cpp
+++ b/src/cpu/kernels/CpuAddKernel.cpp
@@ -39,82 +39,127 @@ namespace cpu
{
namespace kernels
{
+bool can_interpret_inputs_as_1d_array(const ITensorInfo &src0, const ITensorInfo &src1)
+{
+ return !src0.has_padding() && !src1.has_padding() && src0.tensor_shape() == src1.tensor_shape();
+}
+
namespace
{
static const std::vector<CpuAddKernel::AddKernel> available_kernels =
{
{
+ "neon_fp32_add_as_1d_array",
+ [](const CpuAddKernelDataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::F32) && data.can_interpret_inputs_as_1d_array == true;
+ },
+ REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon_as_1d_array)
+ },
+ {
+ "neon_fp16_add_as_1d_array",
+ [](const CpuAddKernelDataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::F16) && data.can_interpret_inputs_as_1d_array == true;
+ },
+ REGISTER_FP16_NEON(arm_compute::cpu::add_fp16_neon_as_1d_array)
+ },
+ {
+ "neon_u8_add_as_1d_array",
+ [](const CpuAddKernelDataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::U8) && data.can_interpret_inputs_as_1d_array == true;
+ },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon_as_1d_array)
+ },
+ {
+ "neon_s16_add_as_1d_array",
+ [](const CpuAddKernelDataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::S16) && data.can_interpret_inputs_as_1d_array == true;
+ },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon_as_1d_array)
+ },
+ {
+ "neon_s32_add_as_1d_array",
+ [](const CpuAddKernelDataTypeISASelectorData & data)
+ {
+ return (data.dt == DataType::S32) && data.can_interpret_inputs_as_1d_array == true;
+ },
+ REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon_as_1d_array)
+ },
+ {
"sve2_qu8_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::QASYMM8) && data.isa.sve2;
+ return (data.dt == DataType::QASYMM8) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_QASYMM8_SVE2(arm_compute::cpu::add_qasymm8_sve2)
},
{
"sve2_qs8_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2;
+ return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::add_qasymm8_signed_sve2)
},
{
"sve2_qs16_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::QSYMM16) && data.isa.sve2;
+ return (data.dt == DataType::QSYMM16) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_QSYMM16_SVE2(arm_compute::cpu::add_qsymm16_sve2)
},
{
"sve_fp32_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::F32) && data.isa.sve;
+ return (data.dt == DataType::F32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_FP32_SVE(arm_compute::cpu::add_fp32_sve)
},
{
"sve_fp16_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16;
+ return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16 && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_FP16_SVE(arm_compute::cpu::add_fp16_sve)
},
{
"sve_u8_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::U8) && data.isa.sve;
+ return (data.dt == DataType::U8) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_sve)
},
{
"sve_s16_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::S16) && data.isa.sve;
+ return (data.dt == DataType::S16) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_sve)
},
{
"sve_s32_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
- return (data.dt == DataType::S32) && data.isa.sve;
+ return (data.dt == DataType::S32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false;
},
REGISTER_INTEGER_SVE(arm_compute::cpu::add_s32_sve)
},
{
"neon_fp32_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F32); },
REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon)
},
{
"neon_fp16_add",
- [](const DataTypeISASelectorData & data)
+ [](const CpuAddKernelDataTypeISASelectorData & data)
{
return (data.dt == DataType::F16) && data.isa.fp16;
},
@@ -122,32 +167,32 @@ static const std::vector<CpuAddKernel::AddKernel> available_kernels =
},
{
"neon_u8_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::U8); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon)
},
{
"neon_s16_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S16); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon)
},
{
"neon_s32_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S32); },
REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon)
},
{
"neon_qu8_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); },
REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon)
},
{
"neon_qs8_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); },
REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon)
},
{
"neon_qs16_add",
- [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); },
+ [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); },
REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon)
}
};
@@ -177,7 +222,8 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
"Wrong shape for dst");
}
- const auto *uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() });
+ const auto uk = CpuAddKernel::get_implementation<CpuAddKernelDataTypeISASelectorData>(CpuAddKernelDataTypeISASelectorData{ src0.data_type(),
+ CPUInfo::get().get_isa(), can_interpret_inputs_as_1d_array(src0, src1) });
ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr);
return Status{};
@@ -185,16 +231,25 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons
std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &src0, const ITensorInfo &src1, ITensorInfo &dst)
{
- const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
+ if(can_interpret_inputs_as_1d_array(src0, src1))
+ {
+ Window window;
+ window.set(0, Window::Dimension(0, src0.tensor_shape().total_size()));
+ return std::make_pair(Status{}, window);
+ }
+ else
+ {
+ const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape());
- // Auto initialize dst if not initialized
- set_shape_if_empty(dst, out_shape);
- set_data_type_if_unknown(dst, src0.data_type());
+ // Auto initialize dst if not initialized
+ set_shape_if_empty(dst, out_shape);
+ set_data_type_if_unknown(dst, src0.data_type());
- Window win = calculate_max_window(out_shape, Steps());
+ Window win = calculate_max_window(out_shape, Steps());
- // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped
- return std::make_pair(Status{}, win);
+ // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped
+ return std::make_pair(Status{}, win);
+ }
}
} // namespace
@@ -203,7 +258,9 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I
ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy));
- const auto uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() });
+ _can_interpret_inputs_as_1d_array = can_interpret_inputs_as_1d_array(*src0, *src1);
+ const auto uk = CpuAddKernel::get_implementation<CpuAddKernelDataTypeISASelectorData>(CpuAddKernelDataTypeISASelectorData{ src0->data_type(),
+ CPUInfo::get().get_isa(), _can_interpret_inputs_as_1d_array });
ARM_COMPUTE_ERROR_ON_NULLPTR(uk);