diff options
author | Gunes Bayir <gunes.bayir@arm.com> | 2022-07-28 17:44:00 +0100 |
---|---|---|
committer | Gunes Bayir <gunes.bayir@arm.com> | 2022-08-01 20:13:56 +0000 |
commit | 9b921be1ff7283050eb39d9ce1b10b5c8bfc1300 (patch) | |
tree | 0cb274a6c529717b8ef987aa3e270647927e9d89 /src/cpu/kernels/CpuAddKernel.cpp | |
parent | 385dad2bffecbf395aa9aad257809de81c727ac7 (diff) | |
download | ComputeLibrary-9b921be1ff7283050eb39d9ce1b10b5c8bfc1300.tar.gz |
Optimize add layer by considering the input tensors as 1D array
Resolves: COMPMID-5108
Change-Id: I544f8160fbe5b4ffbef348d1fbd3dd626a6e1bdb
Signed-off-by: Gunes Bayir <gunes.bayir@arm.com>
Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/8002
Reviewed-by: Gian Marco Iodice <gianmarco.iodice@arm.com>
Comments-Addressed: Arm Jenkins <bsgcomp@arm.com>
Benchmark: Arm Jenkins <bsgcomp@arm.com>
Tested-by: Arm Jenkins <bsgcomp@arm.com>
Diffstat (limited to 'src/cpu/kernels/CpuAddKernel.cpp')
-rw-r--r-- | src/cpu/kernels/CpuAddKernel.cpp | 123 |
1 files changed, 90 insertions, 33 deletions
diff --git a/src/cpu/kernels/CpuAddKernel.cpp b/src/cpu/kernels/CpuAddKernel.cpp index e756effea9..85ae410a94 100644 --- a/src/cpu/kernels/CpuAddKernel.cpp +++ b/src/cpu/kernels/CpuAddKernel.cpp @@ -39,82 +39,127 @@ namespace cpu { namespace kernels { +bool can_interpret_inputs_as_1d_array(const ITensorInfo &src0, const ITensorInfo &src1) +{ + return !src0.has_padding() && !src1.has_padding() && src0.tensor_shape() == src1.tensor_shape(); +} + namespace { static const std::vector<CpuAddKernel::AddKernel> available_kernels = { { + "neon_fp32_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::F32) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon_as_1d_array) + }, + { + "neon_fp16_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::F16) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_FP16_NEON(arm_compute::cpu::add_fp16_neon_as_1d_array) + }, + { + "neon_u8_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::U8) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon_as_1d_array) + }, + { + "neon_s16_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::S16) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon_as_1d_array) + }, + { + "neon_s32_add_as_1d_array", + [](const CpuAddKernelDataTypeISASelectorData & data) + { + return (data.dt == DataType::S32) && data.can_interpret_inputs_as_1d_array == true; + }, + REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon_as_1d_array) + }, + { "sve2_qu8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QASYMM8) && data.isa.sve2; + return (data.dt == DataType::QASYMM8) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QASYMM8_SVE2(arm_compute::cpu::add_qasymm8_sve2) }, { "sve2_qs8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2; + return (data.dt == DataType::QASYMM8_SIGNED) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QASYMM8_SIGNED_SVE2(arm_compute::cpu::add_qasymm8_signed_sve2) }, { "sve2_qs16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::QSYMM16) && data.isa.sve2; + return (data.dt == DataType::QSYMM16) && data.isa.sve2 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_QSYMM16_SVE2(arm_compute::cpu::add_qsymm16_sve2) }, { "sve_fp32_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::F32) && data.isa.sve; + return (data.dt == DataType::F32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_FP32_SVE(arm_compute::cpu::add_fp32_sve) }, { "sve_fp16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16; + return (data.dt == DataType::F16) && data.isa.sve && data.isa.fp16 && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_FP16_SVE(arm_compute::cpu::add_fp16_sve) }, { "sve_u8_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::U8) && data.isa.sve; + return (data.dt == DataType::U8) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_u8_sve) }, { "sve_s16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::S16) && data.isa.sve; + return (data.dt == DataType::S16) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_s16_sve) }, { "sve_s32_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { - return (data.dt == DataType::S32) && data.isa.sve; + return (data.dt == DataType::S32) && data.isa.sve && data.can_interpret_inputs_as_1d_array == false; }, REGISTER_INTEGER_SVE(arm_compute::cpu::add_s32_sve) }, { "neon_fp32_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F32); }, REGISTER_FP32_NEON(arm_compute::cpu::add_fp32_neon) }, { "neon_fp16_add", - [](const DataTypeISASelectorData & data) + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::F16) && data.isa.fp16; }, @@ -122,32 +167,32 @@ static const std::vector<CpuAddKernel::AddKernel> available_kernels = }, { "neon_u8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::U8); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_u8_neon) }, { "neon_s16_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S16); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_s16_neon) }, { "neon_s32_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::S32); }, REGISTER_INTEGER_NEON(arm_compute::cpu::add_s32_neon) }, { "neon_qu8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8); }, REGISTER_QASYMM8_NEON(arm_compute::cpu::add_qasymm8_neon) }, { "neon_qs8_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QASYMM8_SIGNED); }, REGISTER_QASYMM8_SIGNED_NEON(arm_compute::cpu::add_qasymm8_signed_neon) }, { "neon_qs16_add", - [](const DataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, + [](const CpuAddKernelDataTypeISASelectorData & data) { return (data.dt == DataType::QSYMM16); }, REGISTER_QSYMM16_NEON(arm_compute::cpu::add_qsymm16_neon) } }; @@ -177,7 +222,8 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons "Wrong shape for dst"); } - const auto *uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0.data_type(), CPUInfo::get().get_isa() }); + const auto uk = CpuAddKernel::get_implementation<CpuAddKernelDataTypeISASelectorData>(CpuAddKernelDataTypeISASelectorData{ src0.data_type(), + CPUInfo::get().get_isa(), can_interpret_inputs_as_1d_array(src0, src1) }); ARM_COMPUTE_RETURN_ERROR_ON(uk == nullptr || uk->ukernel == nullptr); return Status{}; @@ -185,16 +231,25 @@ Status validate_arguments(const ITensorInfo &src0, const ITensorInfo &src1, cons std::pair<Status, Window> validate_and_configure_window(const ITensorInfo &src0, const ITensorInfo &src1, ITensorInfo &dst) { - const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); + if(can_interpret_inputs_as_1d_array(src0, src1)) + { + Window window; + window.set(0, Window::Dimension(0, src0.tensor_shape().total_size())); + return std::make_pair(Status{}, window); + } + else + { + const TensorShape &out_shape = TensorShape::broadcast_shape(src0.tensor_shape(), src1.tensor_shape()); - // Auto initialize dst if not initialized - set_shape_if_empty(dst, out_shape); - set_data_type_if_unknown(dst, src0.data_type()); + // Auto initialize dst if not initialized + set_shape_if_empty(dst, out_shape); + set_data_type_if_unknown(dst, src0.data_type()); - Window win = calculate_max_window(out_shape, Steps()); + Window win = calculate_max_window(out_shape, Steps()); - // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped - return std::make_pair(Status{}, win); + // CpuAddKernel doesn't need padding so update_window_and_padding() can be skipped + return std::make_pair(Status{}, win); + } } } // namespace @@ -203,7 +258,9 @@ void CpuAddKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, I ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst); ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(*src0, *src1, *dst, policy)); - const auto uk = CpuAddKernel::get_implementation(DataTypeISASelectorData{ src0->data_type(), CPUInfo::get().get_isa() }); + _can_interpret_inputs_as_1d_array = can_interpret_inputs_as_1d_array(*src0, *src1); + const auto uk = CpuAddKernel::get_implementation<CpuAddKernelDataTypeISASelectorData>(CpuAddKernelDataTypeISASelectorData{ src0->data_type(), + CPUInfo::get().get_isa(), _can_interpret_inputs_as_1d_array }); ARM_COMPUTE_ERROR_ON_NULLPTR(uk); |