From e4563a032aaa71de5efdb83fc04ff2933338e02d Mon Sep 17 00:00:00 2001 From: Adnan AlSinan Date: Wed, 1 Sep 2021 15:32:03 +0100 Subject: Adds Conv3d reference implementation support. Expands the interface with the following items: - Size3D Class. - Conv3dInfo Struct. - Padding3D Struct. - Add 'NDHWC' to supported Tensor Data Layouts. - Add function to compute expected size of Conv3d. Resolves COMPMID-4658 & COMPMID-4657 Signed-off-by: Adnan AlSinan Change-Id: Ic7452c48461eedaa38eaf3ac458f54b031e7dfa8 Reviewed-on: https://review.mlplatform.org/c/ml/ComputeLibrary/+/6187 Reviewed-by: Giorgio Arena Reviewed-by: Gian Marco Iodice Tested-by: Arm Jenkins --- arm_compute/runtime/FunctionDescriptors.h | 28 ++++++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) (limited to 'arm_compute/runtime/FunctionDescriptors.h') diff --git a/arm_compute/runtime/FunctionDescriptors.h b/arm_compute/runtime/FunctionDescriptors.h index 1f4216eb21..07a8f6600e 100644 --- a/arm_compute/runtime/FunctionDescriptors.h +++ b/arm_compute/runtime/FunctionDescriptors.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2020 Arm Limited. + * Copyright (c) 2019-2021 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -52,7 +52,7 @@ struct FFT2DInfo FFTDirection direction{ FFTDirection::Forward }; /**< Direction of the FFT. */ }; -/** Descriptor used by the Convolution function */ +/** Descriptor used by the 2d Convolution function */ struct Conv2dInfo { Conv2dInfo() = default; @@ -72,5 +72,29 @@ struct Conv2dInfo bool enable_fast_math{ false }; unsigned int num_groups{ 1 }; }; + +/** Descriptor used by the 3d Convolution function */ +struct Conv3dInfo +{ + Conv3dInfo() = default; + + Conv3dInfo(const Size3D &stride, + const Padding3D &padding, + const ActivationLayerInfo &act_info, + const Size3D &dilation, + const DimensionRoundingType &round_type, + bool enable_fast_math) + : stride(stride), padding(padding), act_info(act_info), dilation(dilation), round_type(round_type), enable_fast_math(enable_fast_math) + { + } + + Size3D stride{ 1U, 1U, 1U }; + Padding3D padding{}; + ActivationLayerInfo act_info{}; + Size3D dilation{ 1U, 1U, 1U }; + DimensionRoundingType round_type{}; + bool enable_fast_math{ false }; +}; + } // namespace arm_compute #endif /* ARM_COMPUTE_RUNTIME_FUNCTION_DESCRIPTORS_H */ -- cgit v1.2.1