aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ConversionUtils.cpp4
-rw-r--r--ConversionUtils.hpp7
-rw-r--r--Utils.cpp24
-rw-r--r--Utils.hpp2
4 files changed, 11 insertions, 26 deletions
diff --git a/ConversionUtils.cpp b/ConversionUtils.cpp
index 9cc6e286..4cea7276 100644
--- a/ConversionUtils.cpp
+++ b/ConversionUtils.cpp
@@ -56,7 +56,7 @@ ConstTensorPin::ConstTensorPin(bool optional)
: m_Optional(optional)
{}
-ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
+ConstTensorPin::ConstTensorPin(armnn::TensorInfo& tensorInfo,
const void* valueStart,
uint32_t numBytes,
const armnn::PermutationVector& mappings)
@@ -73,7 +73,7 @@ ConstTensorPin::ConstTensorPin(const armnn::TensorInfo& tensorInfo,
m_SwizzledTensorData.resize(tensorInfo.GetNumBytes());
SwizzleAndroidNn4dTensorToArmNn(tensorInfo, valueStart, m_SwizzledTensorData.data(), mappings);
- m_ConstTensor = armnn::ConstTensor(armnnUtils::Permuted(tensorInfo, mappings), m_SwizzledTensorData.data());
+ m_ConstTensor = armnn::ConstTensor(tensorInfo, m_SwizzledTensorData.data());
}
else
{
diff --git a/ConversionUtils.hpp b/ConversionUtils.hpp
index 439d4a4a..473f6d78 100644
--- a/ConversionUtils.hpp
+++ b/ConversionUtils.hpp
@@ -90,7 +90,7 @@ public:
// @param valueStart Start address of tensor data. Belongs to one of the memory pools associated with
// the model being converted.
// @param numBytes Number of bytes for the tensor data.
- ConstTensorPin(const armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
+ ConstTensorPin(armnn::TensorInfo& tensorInfo, const void* valueStart, uint32_t numBytes,
const armnn::PermutationVector& mappings);
ConstTensorPin(const ConstTensorPin& other) = delete;
@@ -843,11 +843,6 @@ ConstTensorPin ConvertOperandToConstTensorPin(const HalOperand& operand,
}
armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
- // Android datalayout might be different than armnn datalayout, e.g. the kernel for the depthwise convolution.
- if (tensorInfo.HasPerAxisQuantization())
- {
- tensorInfo.SetQuantizationDim(dimensionMappings[tensorInfo.GetQuantizationDim().value()]);
- }
if (overrideTensorShape != nullptr)
{
diff --git a/Utils.cpp b/Utils.cpp
index 18842814..6fd1a785 100644
--- a/Utils.cpp
+++ b/Utils.cpp
@@ -31,25 +31,12 @@ namespace armnn_driver
{
const armnn::PermutationVector g_DontPermute{};
-namespace
-{
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorShape& inTensorShape, const void* input,
- void* output, size_t dataTypeSize, const armnn::PermutationVector& mappings)
-{
- assert(inTensorShape.GetNumDimensions() == 4U);
-
- armnnUtils::Permute(armnnUtils::Permuted(inTensorShape, mappings), mappings, input, output, dataTypeSize);
-}
-
-} // anonymous namespace
-
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensorInfo, const void* input, void* output,
const armnn::PermutationVector& mappings)
{
- assert(tensor.GetNumDimensions() == 4U);
+ assert(tensorInfo.GetNumDimensions() == 4U);
- armnn::DataType dataType = tensor.GetDataType();
+ armnn::DataType dataType = tensorInfo.GetDataType();
switch (dataType)
{
case armnn::DataType::Float16:
@@ -57,7 +44,10 @@ void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void
case armnn::DataType::QAsymmU8:
case armnn::DataType::QSymmS8:
case armnn::DataType::QAsymmS8:
- SwizzleAndroidNn4dTensorToArmNn(tensor.GetShape(), input, output, armnn::GetDataTypeSize(dataType), mappings);
+ // First swizzle tensor info
+ tensorInfo = armnnUtils::Permuted(tensorInfo, mappings);
+ // Then swizzle tensor data
+ armnnUtils::Permute(tensorInfo.GetShape(), mappings, input, output, armnn::GetDataTypeSize(dataType));
break;
default:
ALOGW("Unknown armnn::DataType for swizzling");
diff --git a/Utils.hpp b/Utils.hpp
index f68747b0..893c4a08 100644
--- a/Utils.hpp
+++ b/Utils.hpp
@@ -60,7 +60,7 @@ public:
};
/// Swizzles tensor data in @a input according to the dimension mappings.
-void SwizzleAndroidNn4dTensorToArmNn(const armnn::TensorInfo& tensor, const void* input, void* output,
+void SwizzleAndroidNn4dTensorToArmNn(armnn::TensorInfo& tensor, const void* input, void* output,
const armnn::PermutationVector& mappings);
/// Returns a pointer to a specific location in a pool