aboutsummaryrefslogtreecommitdiff
path: root/shim
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2022-07-19 12:37:20 +0100
committerNikhil Raj <nikhil.raj@arm.com>2022-07-27 15:57:46 +0100
commit1e276f38e67af7505a25010eee579034ee83d12b (patch)
tree48607813d793d4142c0a2e4bc0b0b4cf15cf8285 /shim
parent07389192266eedac50a64c7d66ef62c1532e06f2 (diff)
downloadarmnn-1e276f38e67af7505a25010eee579034ee83d12b.tar.gz
IVGCVSW-6954 'Arm NN Support Library Implementation'
* Fixed model converting issue * Fixed import memory issue Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Change-Id: Ied61810b308e0c5d5754f122a6ea2bac1d0725f1
Diffstat (limited to 'shim')
-rw-r--r--shim/sl/canonical/ArmnnPreparedModel.cpp30
-rw-r--r--shim/sl/canonical/ConversionUtils.cpp39
-rw-r--r--shim/sl/canonical/ConversionUtils.hpp6
-rw-r--r--shim/sl/canonical/Converter.cpp10
4 files changed, 71 insertions, 14 deletions
diff --git a/shim/sl/canonical/ArmnnPreparedModel.cpp b/shim/sl/canonical/ArmnnPreparedModel.cpp
index 54a019004c..79cd241348 100644
--- a/shim/sl/canonical/ArmnnPreparedModel.cpp
+++ b/shim/sl/canonical/ArmnnPreparedModel.cpp
@@ -393,7 +393,37 @@ ErrorStatus ArmnnPreparedModel::ExecuteGraph(
armnn::Status status;
VLOG(DRIVER) << "ArmnnPreparedModel::ExecuteGraph m_AsyncModelExecutionEnabled false";
importedInputIds = m_Runtime->ImportInputs(m_NetworkId, inputTensors, armnn::MemorySource::Malloc);
+ if (!importedInputIds.empty())
+ {
+ // Some or all of the input tensors have been imported. We need to remove the ones that
+ // were imported from inputTensors.
+ for (armnn::ImportedInputId& importedId : importedInputIds)
+ {
+ inputTensors.erase(
+ std::remove_if(
+ inputTensors.begin(), inputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::ConstTensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ inputTensors.end());
+ }
+ }
importedOutputIds = m_Runtime->ImportOutputs(m_NetworkId, outputTensors, armnn::MemorySource::Malloc);
+ if (!importedOutputIds.empty())
+ {
+ // Some or all of the output tensors have been imported. We need to remove the ones that
+ // were imported from outputTensors.
+ for (armnn::ImportedInputId& importedId : importedOutputIds)
+ {
+ outputTensors.erase(
+ std::remove_if(
+ outputTensors.begin(), outputTensors.end(),
+ [&importedId](std::pair<armnn::LayerBindingId, class armnn::Tensor>& element) {
+ return (element.first == static_cast<int>(importedId));
+ }),
+ outputTensors.end());
+ }
+ }
status = m_Runtime->EnqueueWorkload(m_NetworkId,
inputTensors,
outputTensors,
diff --git a/shim/sl/canonical/ConversionUtils.cpp b/shim/sl/canonical/ConversionUtils.cpp
index 020410d30e..96a8ddca6a 100644
--- a/shim/sl/canonical/ConversionUtils.cpp
+++ b/shim/sl/canonical/ConversionUtils.cpp
@@ -151,7 +151,8 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
const ConversionData& data,
const armnn::PermutationVector& dimensionMappings,
const armnn::TensorShape* overrideTensorShape,
- bool optional)
+ bool optional,
+ const armnn::DataType* overrideDataType)
{
if (!IsOperandTypeSupportedForTensors(operand.type))
{
@@ -180,13 +181,18 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
armnn::TensorInfo tensorInfo = GetTensorInfoForOperand(operand);
- // Make sure isConstant flag is set.
- tensorInfo.SetConstant();
-
- if (overrideTensorShape != nullptr)
+ if (overrideTensorShape)
{
tensorInfo.SetShape(*overrideTensorShape);
}
+
+ if (overrideDataType)
+ {
+ tensorInfo.SetDataType(*overrideDataType);
+ }
+
+ // Make sure isConstant flag is set.
+ tensorInfo.SetConstant();
return ConstTensorPin(tensorInfo, valueStart, operand.location.length, dimensionMappings);
}
@@ -194,7 +200,8 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
uint32_t inputIndex,
const Model& model,
ConversionData& data,
- const armnn::PermutationVector& dimensionMappings)
+ const armnn::PermutationVector& dimensionMappings,
+ const LayerInputHandle* inputHandle)
{
const Operand* operand = GetInputOperand(operation, inputIndex, model);
@@ -268,8 +275,26 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
case OperandLifeTime::POINTER:
case OperandLifeTime::CONSTANT_REFERENCE:
{
+ auto constantTensorDataType = operandTensorInfo.GetDataType();
+ if (inputHandle)
+ {
+ if ((inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float32
+ || inputHandle->GetTensorInfo().GetDataType() == armnn::DataType::Float16)
+ && (operandTensorInfo.GetDataType() == armnn::DataType::QAsymmU8
+ || operandTensorInfo.GetDataType() == armnn::DataType::QAsymmS8))
+ {
+ constantTensorDataType = inputHandle->GetTensorInfo().GetDataType();
+ }
+ }
+
// The tensor has an already known constant value, and can be converted into an ArmNN Constant layer.
- ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand, model, data, dimensionMappings);
+ ConstTensorPin tensorPin = ConvertOperandToConstTensorPin(*operand,
+ model,
+ data,
+ dimensionMappings,
+ nullptr,
+ false,
+ &constantTensorDataType);
if (tensorPin.IsValid())
{
bool isSupported = false;
diff --git a/shim/sl/canonical/ConversionUtils.hpp b/shim/sl/canonical/ConversionUtils.hpp
index 5847d219d4..8058bcb379 100644
--- a/shim/sl/canonical/ConversionUtils.hpp
+++ b/shim/sl/canonical/ConversionUtils.hpp
@@ -700,7 +700,8 @@ ConstTensorPin ConvertOperandToConstTensorPin(const Operand& operand,
const ConversionData& data,
const armnn::PermutationVector& dimensionMappings = g_DontPermute,
const armnn::TensorShape* overrideTensorShape = nullptr,
- bool optional = false);
+ bool optional = false,
+ const armnn::DataType* overrideDataType = nullptr);
inline ConstTensorPin ConvertOperationInputToConstTensorPin(
const Operation& operation,
@@ -924,7 +925,8 @@ LayerInputHandle ConvertToLayerInputHandle(const Operation& operation,
uint32_t inputIndex,
const Model& model,
ConversionData& data,
- const armnn::PermutationVector& dimensionMappings = g_DontPermute);
+ const armnn::PermutationVector& dimensionMappings = g_DontPermute,
+ const LayerInputHandle* inputHandle = nullptr);
bool SetupAndTrackLayerOutputSlot(const Operation& operation,
uint32_t operationOutputIndex,
diff --git a/shim/sl/canonical/Converter.cpp b/shim/sl/canonical/Converter.cpp
index ade8b4fce6..fc983dc081 100644
--- a/shim/sl/canonical/Converter.cpp
+++ b/shim/sl/canonical/Converter.cpp
@@ -932,15 +932,15 @@ bool Converter::ConvertConv2d(const Operation& operation, const Model& model, Co
}
LayerInputHandle weightsInput = (desc.m_DataLayout == DataLayout::NCHW)
- ? ConvertToLayerInputHandle(operation, 1, model, data, OHWIToOIHW)
- : ConvertToLayerInputHandle(operation, 1, model, data);
+ ? ConvertToLayerInputHandle(operation, 1, model, data, OHWIToOIHW, &input)
+ : ConvertToLayerInputHandle(operation, 1, model, data, g_DontPermute, &input);
if (!weightsInput.IsValid())
{
return Fail("%s: Operation has invalid inputs", __func__);
}
- LayerInputHandle biasInput = ConvertToLayerInputHandle(operation, 2, model, data); // 1D
+ LayerInputHandle biasInput = ConvertToLayerInputHandle(operation, 2, model, data, g_DontPermute, &input); // 1D
if (!biasInput.IsValid())
{
return Fail("%s: Operation has invalid inputs", __func__);
@@ -1165,7 +1165,7 @@ bool Converter::ConvertDepthwiseConv2d(const Operation& operation, const Model&
unsigned int widthIndex = dataLayoutIndexed.GetWidthIndex();
unsigned int heightIndex = dataLayoutIndexed.GetHeightIndex();
- LayerInputHandle weightsInput = ConvertToLayerInputHandle(operation, 1, model, data);
+ LayerInputHandle weightsInput = ConvertToLayerInputHandle(operation, 1, model, data, g_DontPermute, &input);
if (!weightsInput.IsValid())
{
return Fail("%s: Operation has invalid inputs", __func__);
@@ -1177,7 +1177,7 @@ bool Converter::ConvertDepthwiseConv2d(const Operation& operation, const Model&
return Fail("%s: Could not read bias", __func__);
}
- LayerInputHandle biasInput = ConvertToLayerInputHandle(operation, 2, model, data); // 1D
+ LayerInputHandle biasInput = ConvertToLayerInputHandle(operation, 2, model, data, g_DontPermute, &input); // 1D
if (!biasInput.IsValid())
{
return Fail("%s: Operation has invalid inputs", __func__);