about summary refs log tree commit diff
path: root/src/armnn/LoadedNetwork.cpp
diff options
context:
space:
mode:
authorMike Kelly <mike.kelly@arm.com>2023-07-07 15:43:06 +0100
committerMike Kelly <mike.kelly@arm.com>2023-07-14 00:00:53 +0100
commit4cc341cf8b5a6e6bb0543504cbbfde6fa11a2cdb (patch)
tree7cac128e9ec6f2fd27f1afdb55f44b870f39e0b3 /src/armnn/LoadedNetwork.cpp
parent6963b33221c23af4a8eff19ff4a5773230b0befd (diff)
downloadarmnn-4cc341cf8b5a6e6bb0543504cbbfde6fa11a2cdb.tar.gz
IVGCVSW-7830 Add backend optimizations to remove Reshapes where possible
* Added optimization to remove reshapes for Neon and Ref Backends by using overridden TensorInfos
* Added ability to delete Subgraphs during Optimization
* Fixed naming error in NeonEndToEndTests and CLEndToEndTests
* Added LayerNameAndTypeCheck for testing.
* Fixed error where layers were not marked as altered when removed in CLBackend

Signed-off-by: Mike Kelly <mike.kelly@arm.com>
Change-Id: I1ac25cd4ec9821470d961831ae2c8d24882276cc
Diffstat (limited to 'src/armnn/LoadedNetwork.cpp')
-rw-r--r--  src/armnn/LoadedNetwork.cpp | 28
1 file changed, 22 insertions(+), 6 deletions(-)
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index 3f4aa34a5b..3d84054b69 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -955,10 +955,10 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
syncDesc.m_Inputs.push_back(inputTensorHandle);
WorkloadInfo info;
info.m_InputTensorInfos.push_back(
- outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo());
+ outputLayer->GetInputSlot(0).GetTensorInfo());
auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
- m_OutputQueue.push_back(move(syncWorkload));
+ m_OutputQueue.push_back(std::move(syncWorkload));
importedOutputIdIndex++;
}
else
@@ -1089,7 +1089,7 @@ void LoadedNetwork::EnqueueInput(const BindableLayer& layer, ITensorHandle* tens
timelineUtils->Commit();
}
- m_InputQueue.push_back(move(inputWorkload));
+ m_InputQueue.push_back(std::move(inputWorkload));
}
}
@@ -1149,7 +1149,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
info.m_InputTensorInfos.push_back(inputTensorInfo);
auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
- m_OutputQueue.push_back(move(syncWorkload));
+ m_OutputQueue.push_back(std::move(syncWorkload));
}
else
{
@@ -1177,7 +1177,7 @@ void LoadedNetwork::EnqueueOutput(const BindableLayer& layer, ITensorHandle* ten
timelineUtils->Commit();
}
- m_OutputQueue.push_back(move(outputWorkload));
+ m_OutputQueue.push_back(std::move(outputWorkload));
}
}
@@ -1650,7 +1650,7 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
const InputSlot& inputSlot = layer->GetInputSlots()[0];
ITensorHandleFactory::FactoryId factoryId = inputSlot.GetConnectedOutputSlot()->GetTensorHandleFactoryId();
- const TensorInfo& tensorInfo = inputSlot.GetConnectedOutputSlot()->GetTensorInfo();
+ const TensorInfo& tensorInfo = inputSlot.GetTensorInfo();
ITensorHandleFactory* handleFactory = m_TensorHandleFactoryRegistry.GetFactory(factoryId);
ARMNN_ASSERT(handleFactory);
@@ -2093,6 +2093,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
if (found != m_ConstantTensorHandles.end())
{
ITensorHandle* tensorHandle = found->second;
+ if (slot.IsTensorInfoOverridden())
+ {
+ ITensorHandle* decorated = tensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get();
+ if (decorated)
+ {
+ tensorHandle = decorated;
+ }
+ }
workingMemDescriptor.m_Inputs.push_back(tensorHandle);
// Odd case where a constant layer is connected to an output layer
@@ -2113,6 +2121,14 @@ std::unique_ptr<IWorkingMemHandle> LoadedNetwork::CreateWorkingMemHandle(Network
HandleInfo& handleInfo = outputToHandleInfoMap.at(outputSlot);
ITensorHandle* inputTensorHandle = handleInfo.m_TensorHandle;
+ if (slot.IsTensorInfoOverridden())
+ {
+ ITensorHandle* decorated = inputTensorHandle->DecorateTensorHandle(slot.GetTensorInfo()).get();
+ if (decorated)
+ {
+ inputTensorHandle = decorated;
+ }
+ }
workingMemDescriptor.m_Inputs.push_back(inputTensorHandle);
// Store the LayerBindingId of the OutputLayer