diff options
Diffstat (limited to 'src/backends/ArmComputeTensorUtils.cpp')
-rw-r--r-- | src/backends/ArmComputeTensorUtils.cpp | 28 |
1 files changed, 28 insertions, 0 deletions
diff --git a/src/backends/ArmComputeTensorUtils.cpp b/src/backends/ArmComputeTensorUtils.cpp index ba9fb40cfc..e65c4ad35f 100644 --- a/src/backends/ArmComputeTensorUtils.cpp +++ b/src/backends/ArmComputeTensorUtils.cpp @@ -5,6 +5,7 @@ #include "ArmComputeTensorUtils.hpp" #include "ArmComputeUtils.hpp" +#include "armnn/Exceptions.hpp" #include <armnn/Descriptors.hpp> namespace armnn @@ -66,6 +67,33 @@ arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tenso return arm_compute::TensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo); } +arm_compute::DataLayout ConvertDataLayout(armnn::DataLayout dataLayout) +{ + switch(dataLayout) + { + case armnn::DataLayout::NHWC : return arm_compute::DataLayout::NHWC; + + case armnn::DataLayout::NCHW : return arm_compute::DataLayout::NCHW; + + default: throw InvalidArgumentException("Unknown armnn::DataLayout: [" + + std::to_string(static_cast<int>(dataLayout)) + "]"); + } +} + +arm_compute::TensorInfo BuildArmComputeTensorInfo(const armnn::TensorInfo& tensorInfo, + armnn::DataLayout dataLayout) +{ + const arm_compute::TensorShape aclTensorShape = BuildArmComputeTensorShape(tensorInfo.GetShape()); + const arm_compute::DataType aclDataType = GetArmComputeDataType(tensorInfo.GetDataType()); + const arm_compute::QuantizationInfo aclQuantizationInfo(tensorInfo.GetQuantizationScale(), + tensorInfo.GetQuantizationOffset()); + + arm_compute::TensorInfo clTensorInfo(aclTensorShape, 1, aclDataType, aclQuantizationInfo); + clTensorInfo.set_data_layout(ConvertDataLayout(dataLayout)); + + return clTensorInfo; +} + arm_compute::PoolingLayerInfo BuildArmComputePoolingLayerInfo(const Pooling2dDescriptor& descriptor) { using arm_compute::PoolingType; |