aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/OnnxParser.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnOnnxParser/OnnxParser.cpp')
-rw-r--r--src/armnnOnnxParser/OnnxParser.cpp67
1 files changed, 46 insertions, 21 deletions
diff --git a/src/armnnOnnxParser/OnnxParser.cpp b/src/armnnOnnxParser/OnnxParser.cpp
index 81d9e3d240..1fb5b96b8f 100644
--- a/src/armnnOnnxParser/OnnxParser.cpp
+++ b/src/armnnOnnxParser/OnnxParser.cpp
@@ -18,6 +18,7 @@
#include <iostream>
#include <numeric>
+#include <armnnUtils/Permute.hpp>
using namespace armnn;
@@ -500,14 +501,46 @@ void OnnxParserImpl::Cleanup()
m_OutputsFusedAndUsed.clear();
}
-std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParserImpl::CreateConstTensor(const std::string name)
+template<typename T>
+std::pair<armnn::ConstTensor, std::unique_ptr<T[]>>
+CreateConstTensorImpl(const T* bufferPtr,
+ armnn::TensorInfo& tensorInfo,
+ const armnn::Optional<armnn::PermutationVector&> permutationVector)
{
- const TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
+ ARMNN_ASSERT_MSG(bufferPtr != nullptr, fmt::format("Buffer for permutation is null").c_str());
+
+ std::unique_ptr<T[]> data(new T[tensorInfo.GetNumElements()]);
+
+ if (permutationVector.has_value() && permutationVector.value().GetSize() > 0)
+ {
+ tensorInfo = armnnUtils::Permuted(tensorInfo, permutationVector.value());
+ armnnUtils::Permute(tensorInfo.GetShape(), permutationVector.value(),
+ reinterpret_cast<const T*>(bufferPtr), data.get(), sizeof(T));
+ }
+ else
+ {
+ ::memcpy(data.get(), bufferPtr, tensorInfo.GetNumBytes());
+ }
+
+ return std::make_pair(ConstTensor(tensorInfo, data.get()), std::move(data));
+}
+
+std::pair<ConstTensor, std::unique_ptr<float[]>>
+OnnxParserImpl::CreateConstTensor(const std::string name,
+ armnn::Optional<armnn::PermutationVector&> permutationVector)
+{
+ TensorInfo tensorInfo = *m_TensorsInfo[name].m_info;
onnx::TensorProto onnxTensor = *m_TensorsInfo[name].m_tensor;
+ // Const tensors requires at least a list of values
+ if (tensorInfo.GetNumElements() == 0)
+ {
+ throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
+ name,
+ CHECK_LOCATION().AsString()));
+ }
+
auto srcData = onnxTensor.float_data().data();
- std::unique_ptr<float[]> tensorData(new float[tensorInfo.GetNumElements()]);
- const size_t tensorSizeInBytes = tensorInfo.GetNumBytes();
// Copy the value list entries into the destination
if (!onnxTensor.has_raw_data())
{
@@ -521,21 +554,14 @@ std::pair<ConstTensor, std::unique_ptr<float[]>> OnnxParserImpl::CreateConstTens
tensorInfo.GetNumElements(),
CHECK_LOCATION().AsString()));
}
- ::memcpy(tensorData.get(), srcData, tensorSizeInBytes);
+ return CreateConstTensorImpl<float>(srcData, tensorInfo, permutationVector);
}
else
{
- ::memcpy(tensorData.get(), onnxTensor.raw_data().c_str(), tensorSizeInBytes);
+ return CreateConstTensorImpl<float>(reinterpret_cast<const float*>(onnxTensor.raw_data().c_str()),
+ tensorInfo,
+ permutationVector);
}
-
- // Const tensors requires at least a list of values
- if (tensorInfo.GetNumElements() == 0)
- {
- throw ParseException(fmt::format("No tensor data found for Const tensor '{}' {}",
- name,
- CHECK_LOCATION().AsString()));
- }
- return std::make_pair(ConstTensor(tensorInfo, tensorData.get()), std::move(tensorData));
}
ModelPtr OnnxParserImpl::LoadModelFromTextFile(const char* graphFile)
@@ -858,11 +884,10 @@ void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node,
desc.m_BiasEnabled = convDesc.m_BiasEnabled;
armnn::IConnectableLayer* layer;
- auto weightTensor = CreateConstTensor(node.input(1));
- TensorShape& weightShape = weightTensor.first.GetShape();
- weightShape[1] = weightShape[0];
- weightShape[0] = 1;
- m_TensorsInfo[node.input(1)].m_info->SetShape(weightShape);
+
+ // weights come in as [O,1,H,W] from ONNX and need to be converted to ArmNNs dephtwise weights layout [1,H,W,O]
+ armnn::PermutationVector perVec {3,0,1,2};
+ auto weightTensor = CreateConstTensor(node.input(1), perVec);
if (node.input_size() == 3)
{
@@ -891,7 +916,7 @@ void OnnxParserImpl::AddConvLayerWithDepthwiseConv(const onnx::NodeProto& node,
auto outputInfo = ComputeOutputInfo({ node.output(0) }, layer,
{ m_TensorsInfo[node.input(0)].m_info->GetShape(),
- m_TensorsInfo[node.input(1)].m_info->GetShape() });
+ weightTensor.first.GetInfo().GetShape() });
layer->GetOutputSlot(0).SetTensorInfo(outputInfo[0]);