diff options
Diffstat (limited to 'src/backends/aclCommon/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/aclCommon/ArmComputeTensorUtils.cpp | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/src/backends/aclCommon/ArmComputeTensorUtils.cpp b/src/backends/aclCommon/ArmComputeTensorUtils.cpp index e476eb38a1..1960332ccf 100644 --- a/src/backends/aclCommon/ArmComputeTensorUtils.cpp +++ b/src/backends/aclCommon/ArmComputeTensorUtils.cpp @@ -8,6 +8,8 @@ #include "armnn/Exceptions.hpp" #include <armnn/Descriptors.hpp> +#include <fmt/format.h> + namespace armnn { namespace armcomputetensorutils @@ -342,5 +344,26 @@ arm_compute::PixelValue GetPixelValue(const arm_compute::ITensorInfo* tensorInfo } } +unsigned int ComputeDepthwiseConv2dDepthMultiplier(armnn::DataLayout layout, + const arm_compute::TensorShape& weightsShape, + const arm_compute::TensorShape& inputShape) +{ + unsigned int depthMultiplier; + if (layout == armnn::DataLayout::NHWC) + { + depthMultiplier = static_cast<uint32_t>(weightsShape[0]) / static_cast<uint32_t>(inputShape[0]); + } + else if (layout == armnn::DataLayout::NCHW) + { + depthMultiplier = static_cast<uint32_t>(weightsShape[2]) / static_cast<uint32_t>(inputShape[2]); + } + else + { + throw InvalidArgumentException(fmt::format("Unknown data layout for tensor conversion: {}", + GetDataLayoutName(layout))); + } + return depthMultiplier; +} + } // namespace armcomputetensorutils } // namespace armnn |