diff options
Diffstat (limited to 'src/backends/aclCommon')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 23 | ||||
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.hpp | 5 |
2 files changed, 28 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index e476eb38a1..1960332ccf 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -8,6 +8,8 @@ #include "armnn/Exceptions.hpp" #include <armnn/Descriptors.hpp> +#include <fmt/format.h> + namespace armnn { namespace armcomputetensorutils @@ -342,5 +344,26 @@ arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo } } +unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout, + const arm_compute::TensorShape& weightsShape, + const arm_compute::TensorShape& inputShape) +{ + unsigned int depthMultiplier; + if (layout == armnn::DataLayout::NHWC) + { + depthMultiplier = static_cast<uint32_t>(weightsShape[0]) / static_cast<uint32_t>(inputShape[0]); + } + else if (layout == armnn::DataLayout::NCHW) + { + depthMultiplier = static_cast<uint32_t>(weightsShape[2]) / static_cast<uint32_t>(inputShape[2]); + } + else + { + throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}", + GetDataLayoutName(layout))); + } + return depthMultiplier; +} + } // namespace armcomputetensorutils } // namespace armnn diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.hpp b/src/backends/aclCommon/ArmComputeTensorUtils.hpp index 31992b93c6..ee8240f3b8 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.hpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.hpp @@ -77,6 +77,11 @@ arm_compute::Size2D BuildArmComputeSize2D(const unsigned int width, const unsign /// Gets the appropriate PixelValue for the TensorInfo DataType arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo, float pixelValue); +/// Computes the depth multiplier parameter for the Depthwise Conv2d ACL workload. +unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout, + const arm_compute::TensorShape& weightsShape, + const arm_compute::TensorShape& inputShape); + /// Utility function used to setup an arm_compute::PadStrideInfo object from an armnn layer descriptor. template <typename Descriptor> arm_compute::PadStrideInfo BuildArmComputePadStrideInfo(const Descriptor &descriptor) |