aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJan Eilers <jan.eilers@arm.com>2021-04-12 13:12:19 +0100
committerJan Eilers <jan.eilers@arm.com>2021-04-19 09:31:02 +0100
commita71c0633eff4791eb98362f72f198a2d1ec3d8f9 (patch)
tree9976eb70e4c5f3a93b4db0b06e02f8d03057a235
parent32fe97ec627a70b6453375fcfc6665c0e1ad2024 (diff)
downloadandroid-nn-driver-a71c0633eff4791eb98362f72f198a2d1ec3d8f9.tar.gz
IVGCVSW-5826 Remove cross-wireing in depthwise
* The permutation of the tensor info is now completely handled in the armnnUtils::Permuted function. That includes quantization informations too !armnn:5411 Signed-off-by: Jan Eilers <jan.eilers@arm.com> Change-Id: I40410141303d950be7888f9e491133251b6f69d8
-rw-r--r--ConversionUtils.cpp4
-rw-r--r--ConversionUtils.hpp7
-rw-r--r--Utils.cpp24
-rw-r--r--Utils.hpp2
4 files changed, 11 insertions, 26 deletions
diff --git a/ConversionUtils.cpp b/ConversionUtils.cpp
index 9cc6e286..4cea7276 100644
--- a/ConversionUtils.cpp
+++ b/ConversionUtils.cpp
@@ -56,7 +56,7 @@ ConstTensorPin::ConstTensorPin(bool optional)
: m_Optional(optional)
{}
-ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
+ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo,
const void* valueStart,
uint32_t numBytes,
const armnn::PermutationVector& mappings)
@@ -73,7 +73,7 @@ ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
- m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data());
+ m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data());
}
else
{
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 439d4a4a..473f6d78 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -90,7 +90,7 @@ public:
// @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with
// the model being converted.
// @param numBytes Number of bytes for the tensor data.
- ConstTensorPin(const armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
+ ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
const armnn::PermutationVector& mappings);
ConstTensorPin(const ConstTensorPin& other) = delete;
@@ -843,11 +843,6 @@ ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand,
}
armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
- // Android datalayout might be different than armnn datalayout, e.g. the kernel for the depthwise convolution.
- if (tensorInfo.HasPerAxisQuantization())
- {
- tensorInfo.SetQuantizationDim(dimensionMappings[tensorInfo.GetQuantizationDim().value()]);
- }
if (overrideTensorShape != nullptr)
{
diff --git a/Utils.cpp b/Utils.cpp
index 18842814..6fd1a785 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -31,25 +31,12 @@ namespace armnn_driver
{
const armnn::PermutationVector g_DontPermute{};
-namespace
-{
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input,
- void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings)
-{
- assert(inTensorShape.GetNumDimensions() == 4U);
-
- armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize);
-}
-
-} // anonymous namespace
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output,
const armnn::PermutationVector& mappings)
{
- assert(tensor.GetNumDimensions() == 4U);
+ assert(tensorInfo.GetNumDimensions() == 4U);
- armnn::DataType dataType = tensor.GetDataType();
+ armnn::DataType dataType = tensorInfo.GetDataType();
switch (dataType)
{
case armnn::DataType::Float16:
@@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
case armnn::DataType::QAsymmS8:
- SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
+ // First swizzle tensor info
+ tensorInfo = armnnUtils::Permuted(tensorInfo, mappings);
+ // Then swizzle tensor data
+ armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType));
break;
default:
ALOGW("Unknown armnn::DataType for swizzling");
diff --git a/Utils.hpp b/Utils.hpp
index f68747b0..893c4a08 100644
--- a/Utils.hpp
+++ b/Utils.hpp
@@ -60,7 +60,7 @@ public:
};
/// Swizzles tensor data in @a input according to the dimension mappings.
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensor, const void* input, void* output,
const armnn::PermutationVector& mappings);
/// Returns a pointer to a specific location in a pool