diff options
-rw-r--r-- | ConversionUtils.cpp | 4 | ||||
-rw-r--r-- | ConversionUtils.hpp | 7 | ||||
-rw-r--r-- | Utils.cpp | 24 | ||||
-rw-r--r-- | Utils.hpp | 2 |
4 files changed, 11 insertions, 26 deletions
diff --git a/ConversionUtils.cpp b/ConversionUtils.cpp index 9cc6e286..4cea7276 100644 --- a/ConversionUtils.cpp +++ b/ConversionUtils.cpp @@ -56,7 +56,7 @@ ConstTensorPin::ConstTensorPin(bool optional) : m_Optional(optional) {} -ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo, +ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, const armnn::PermutationVector& mappings) @@ -73,7 +73,7 @@ ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo, m_SwizzledTensorData.resize(tensorInfo.GetNumBytes()); SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings); - m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data()); + m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data()); } else { diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp index 439d4a4a..473f6d78 100644 --- a/ConversionUtils.hpp +++ b/ConversionUtils.hpp @@ -90,7 +90,7 @@ public: // @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with // the model being converted. // @param numBytes Number of bytes for the tensor data. - ConstTensorPin(const armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, + ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes, const armnn::PermutationVector& mappings); ConstTensorPin(const ConstTensorPin& other) = delete; @@ -843,11 +843,6 @@ ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand, } armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand); - // Android datalayout might be different than armnn datalayout, e.g. the kernel for the depthwise convolution. - if (tensorInfo.HasPerAxisQuantization()) - { - tensorInfo.SetQuantizationDim(dimensionMappings[tensorInfo.GetQuantizationDim().value()]); - } if (overrideTensorShape != nullptr) { @@ -31,25 +31,12 @@ namespace armnn_driver { const armnn::PermutationVector g_DontPermute{}; -namespace -{ - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input, - void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings) -{ - assert(inTensorShape.GetNumDimensions() == 4U); - - armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize); -} - -} // anonymous namespace - -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output, const armnn::PermutationVector& mappings) { - assert(tensor.GetNumDimensions() == 4U); + assert(tensorInfo.GetNumDimensions() == 4U); - armnn::DataType dataType = tensor.GetDataType(); + armnn::DataType dataType = tensorInfo.GetDataType(); switch (dataType) { case armnn::DataType::Float16: @@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void case armnn::DataType::QAsymmU8: case armnn::DataType::QSymmS8: case armnn::DataType::QAsymmS8: - SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings); + // First swizzle tensor info + tensorInfo = armnnUtils::Permuted(tensorInfo, mappings); + // Then swizzle tensor data + armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType)); break; default: ALOGW("Unknown armnn::DataType for swizzling"); @@ -60,7 +60,7 @@ public: }; /// Swizzles tensor data in @a input according to the dimension mappings. -void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output, +void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensor, const void* input, void* output, const armnn::PermutationVector& mappings); /// Returns a pointer to a specific location in a pool |