author      Finn Williams <finn.williams@arm.com>    2022-02-15 20:47:34 +0000
committer   Colm Donelan <colm.donelan@arm.com>      2022-02-16 21:36:45 +0000
commit      840c45d19bff23d64f78a7e466886fb970b4fcc9 (patch)
tree        f3941b3deb75c00ab68cb535f96172a28c2d4b41
parent      a00bad1140223453e72a134388b209f9fa296d00 (diff)
download    armnn-840c45d19bff23d64f78a7e466886fb970b4fcc9.tar.gz
Refactor Forced Import
* Find and replace all workloads associated with imported IO
* Only attempt tensorhandle replacement if supported by all workloads
* Add new RefBaseWorkload to enable forced input for ref backend
* Store imported tensorhandles in preImportedTensorhandles instead of outputHandles
* Create pre-imported tensorhandles at network load-time
* Front load import workload validation to load network time
* Only call ReplaceTensorHandle when needed

Change-Id: I3816a71b7f57ae90388bb16462a75d4ef3544fa7
Signed-off-by: Finn Williams <finn.williams@arm.com>
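In outline: a workload now advertises whether its tensor handles can be swapped at all, and LoadedNetwork only performs the swap when every workload touching an imported input or output reports support. A minimal sketch of a workload that opts in, modelled on the RefBaseWorkload added by this change (ReplaceableWorkload is an illustrative name, not part of the commit):

#include <armnn/backends/Workload.hpp>

// Hypothetical workload that opts into tensor handle replacement,
// mirroring the RefBaseWorkload pattern introduced below.
template <typename QueueDescriptor>
class ReplaceableWorkload : public armnn::BaseWorkload<QueueDescriptor>
{
public:
    using armnn::BaseWorkload<QueueDescriptor>::BaseWorkload;

    // Tell LoadedNetwork that ReplaceInput/OutputTensorHandle are safe to call.
    bool SupportsTensorHandleReplacement() const override { return true; }

    void ReplaceInputTensorHandle(armnn::ITensorHandle* tensorHandle, unsigned int slot) override
    {
        this->m_Data.m_Inputs[slot] = tensorHandle;
    }

    void ReplaceOutputTensorHandle(armnn::ITensorHandle* tensorHandle, unsigned int slot) override
    {
        this->m_Data.m_Outputs[slot] = tensorHandle;
    }
};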
-rw-r--r--  include/armnn/backends/IWorkload.hpp | 4
-rw-r--r--  include/armnn/backends/Workload.hpp | 11
-rw-r--r--  src/armnn/LoadedNetwork.cpp | 447
-rw-r--r--  src/armnn/LoadedNetwork.hpp | 17
-rw-r--r--  src/backends/backendsCommon/WorkloadData.cpp | 1
-rw-r--r--  src/backends/cl/workloads/ClConvolution2dWorkload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/CMakeLists.txt | 1
-rw-r--r--  src/backends/reference/workloads/RefActivationWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefArgMinMaxWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefArgMinMaxWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefBaseWorkload.hpp | 36
-rw-r--r--  src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefCastWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefChannelShuffleWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefComparisonWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefComparisonWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefConcatWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefConstantWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConstantWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvolution2dWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvolution2dWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefConvolution3dWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefConvolution3dWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefDebugWorkload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefDequantizeWorkload.hpp | 8
-rw-r--r--  src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefElementwiseWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefFillWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefFloorWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefFullyConnectedWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefFullyConnectedWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefGatherWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefL2NormalizationWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefL2NormalizationWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefLstmWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefMeanWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefMeanWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefNormalizationWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefNormalizationWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefPadWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefPermuteWorkload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefPooling2dWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefPooling3dWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefPreluWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefPreluWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefQLstmWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefQLstmWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefQuantizeWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefQuantizeWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefRankWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefReduceWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefReduceWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefReshapeWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefResizeWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefShapeWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefSliceWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefSoftmaxWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefSplitterWorkload.hpp | 6
-rw-r--r--  src/backends/reference/workloads/RefStackWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefStackWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefStridedSliceWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefStridedSliceWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp | 4
-rw-r--r--  src/backends/reference/workloads/RefTransposeWorkload.hpp | 2
-rw-r--r--  src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp | 2
-rw-r--r--  src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp | 4
91 files changed, 466 insertions, 375 deletions
diff --git a/include/armnn/backends/IWorkload.hpp b/include/armnn/backends/IWorkload.hpp
index d63e0acc72..ce3914bc5a 100644
--- a/include/armnn/backends/IWorkload.hpp
+++ b/include/armnn/backends/IWorkload.hpp
@@ -31,6 +31,10 @@ public:
virtual profiling::ProfilingGuid GetGuid() const = 0;
+ // SupportsTensorHandleReplacement signals that a given workload is capable of
+ // replacing any of its I/O tensors via ReplaceInput/OutputTensorHandle
+ virtual bool SupportsTensorHandleReplacement() const = 0;
+
// Replace input tensor handle with the given TensorHandle
virtual void ReplaceInputTensorHandle(ITensorHandle* /*input*/, unsigned int /*slot*/) = 0;
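Downstream, LoadedNetwork is expected to gate replacement on this query across every workload touching a given input, falling back to the regular copy path otherwise, since the default implementation throws. A hedged sketch of that gating pattern (TryReplaceInputHandles and its arguments are illustrative, not part of this change):

#include <armnn/backends/IWorkload.hpp>
#include <vector>

// Sketch: only replace tensor handles if every affected workload reports
// support; otherwise leave the original (copying) input path in place.
bool TryReplaceInputHandles(const std::vector<armnn::IWorkload*>& workloads,
                            armnn::ITensorHandle* importedHandle,
                            unsigned int slot)
{
    bool supported = true;
    for (const armnn::IWorkload* workload : workloads)
    {
        supported &= workload->SupportsTensorHandleReplacement();
    }
    if (!supported)
    {
        return false; // caller keeps the normal enqueue/copy behaviour
    }
    for (armnn::IWorkload* workload : workloads)
    {
        workload->ReplaceInputTensorHandle(importedHandle, slot);
    }
    return true;
}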
diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp
index 07e1abb392..21109480dc 100644
--- a/include/armnn/backends/Workload.hpp
+++ b/include/armnn/backends/Workload.hpp
@@ -54,16 +54,23 @@ public:
profiling::ProfilingGuid GetGuid() const final { return m_Guid; }
+ virtual bool SupportsTensorHandleReplacement() const override
+ {
+ return false;
+ }
+
// Replace input tensor handle with the given TensorHandle
void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
{
- m_Data.m_Inputs[slot] = tensorHandle;
+ armnn::IgnoreUnused(tensorHandle, slot);
+ throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload");
}
// Replace output tensor handle with the given TensorHandle
void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
{
- m_Data.m_Outputs[slot] = tensorHandle;
+ armnn::IgnoreUnused(tensorHandle, slot);
+ throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload");
}
protected:
diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp
index fd7279a294..bcceaf4a99 100644
--- a/src/armnn/LoadedNetwork.cpp
+++ b/src/armnn/LoadedNetwork.cpp
@@ -139,6 +139,13 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
bool useExternalMemoryManager = false;
bool useInternalMemoryManager = false;
Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
+
+ if (!networkProperties.m_AsyncEnabled)
+ {
+ m_IsInputImported = std::vector<bool>(order.GetNumInputs(), false);
+ m_IsOutputImported = std::vector<bool>(order.GetNumOutputs(), false);
+ }
+
for (auto&& layer : order)
{
auto const& backendId = layer->GetBackendId();
@@ -312,44 +319,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
}
else
{
- if (layer->GetNumInputSlots() >= 1)
- {
- unsigned int inputSlotIndex = 0;
- for (auto& inputSlot : layer->GetInputSlots())
- {
- if (inputSlot.GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Input)
- {
- auto inputLayer =
- PolymorphicDowncast<InputLayer*>(
- &inputSlot.GetConnectedOutputSlot()->GetOwningLayer());
- m_InputWorkloadSlotPairs[inputLayer->GetBindingId()] =
- std::make_pair(m_WorkloadQueue.size(), inputSlotIndex);
- }
- ++inputSlotIndex;
- }
- }
-
- if (layer->GetNumOutputSlots() >= 1)
- {
- unsigned int outputSlotIndex = 0;
- for (auto& outputSlot : layer->GetOutputSlots())
- {
- for (unsigned int i = 0; i < outputSlot.GetNumConnections(); i++)
- {
- // If any of the connections on this outputSlot are connected to an Output then
- // Add its index within layer->GetOutputSlots() to m_OutputWorkloadSlotPairs
- if (outputSlot.GetConnection(i)->GetOwningLayer().GetType() == LayerType::Output)
- {
- auto outputLayer = PolymorphicDowncast<OutputLayer*>(
- &outputSlot.GetConnection(i)->GetOwningLayer());
- m_OutputWorkloadSlotPairs[outputLayer->GetBindingId()] =
- std::make_pair(m_WorkloadQueue.size(), outputSlotIndex);
- continue;
- }
- }
- ++outputSlotIndex;
- }
- }
m_WorkloadQueue.push_back(std::move(workload));
}
@@ -361,6 +330,100 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr<IOptimizedNetwork> net,
}
}
+ // Gather information about workloads for inputs & outputs
+ if (!networkProperties.m_AsyncEnabled && m_WorkloadQueue.size() != 0)
+ {
+ const int noOfInputs = armnn::numeric_cast<int>(order.GetNumInputs());
+
+ // Get indices of all workloads connected to each input and
+ // check if they support tensor handle replacement
+ for (const BindableLayer* layer: order.GetInputLayers())
+ {
+ const auto bindingId = layer->GetBindingId();
+
+ bool supportsReplacement = true;
+
+ for (const auto inputSlot: layer->GetOutputSlot(0).GetConnections())
+ {
+ auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(inputSlot->GetOwningLayer()));
+ workloadIndex -= noOfInputs;
+
+ m_InputWorkloadSlotPairs[bindingId].emplace_back(WorkloadIndices{
+ armnn::numeric_cast<unsigned int>(workloadIndex), inputSlot->GetSlotIndex()});
+
+ auto workload = m_WorkloadQueue[m_InputWorkloadSlotPairs[bindingId].back().m_WorkloadIndex].get();
+ supportsReplacement &= workload->SupportsTensorHandleReplacement();
+ }
+
+ ITensorHandleFactory::FactoryId factoryId = layer->GetOutputSlot(0).GetTensorHandleFactoryId();
+ // Get matching import factory Id
+ ITensorHandleFactory::FactoryId importFactoryId =
+ m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
+
+ ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
+
+ if (supportsReplacement && importFactory)
+ {
+ m_PreImportedInputHandles.emplace_back(
+ bindingId, importFactory->CreateTensorHandle(layer->GetOutputSlot(0).GetTensorInfo(), false));
+ }
+ else
+ {
+ m_PreImportedInputHandles.emplace_back(bindingId, nullptr);
+ }
+ }
+
+ // Get indices of all workloads connected to each output and
+ // check if they support tensor handle replacement
+ for (const BindableLayer* layer: order.GetOutputLayers())
+ {
+ const auto bindingId = layer->GetBindingId();
+
+ const auto outputSlot = layer->GetInputSlot(0).GetConnectedOutputSlot();
+ auto& indices = m_OutputWorkloadSlotPairs[bindingId];
+
+ auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(outputSlot->GetOwningLayer()));
+ workloadIndex -= noOfInputs;
+
+ indices.m_OutputSlotIndices = WorkloadIndices{numeric_cast<unsigned int>(workloadIndex),
+ outputSlot->CalculateIndexOnOwner()};
+
+ bool supportsReplacement = true;
+ auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
+ supportsReplacement &= outputWorkload->SupportsTensorHandleReplacement();
+
+ for (auto &inputSlot: outputSlot->GetConnections())
+ {
+ if(inputSlot->GetOwningLayer().GetType() != LayerType::Output)
+ {
+ auto inWorkloadIndex = std::distance(order.begin(),
+ order.GetPosInGraph(inputSlot->GetOwningLayer()));
+ inWorkloadIndex -= noOfInputs;
+ indices.m_InputSlotIndices.emplace_back(WorkloadIndices{numeric_cast<unsigned int>(inWorkloadIndex),
+ inputSlot->GetSlotIndex()});
+ auto inputWorkload = m_WorkloadQueue[indices.m_InputSlotIndices.back().m_WorkloadIndex].get();
+ supportsReplacement &= inputWorkload->SupportsTensorHandleReplacement();
+ }
+ }
+
+ ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId();
+ // Get matching import factory Id
+ ITensorHandleFactory::FactoryId importFactoryId =
+ m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
+ ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId);
+
+ if (supportsReplacement && importFactory)
+ {
+ m_PreImportedOutputHandles.emplace_back(
+ bindingId, importFactory->CreateTensorHandle(outputSlot->GetTensorInfo(), false));
+ }
+ else
+ {
+ m_PreImportedOutputHandles.emplace_back(bindingId, nullptr);
+ }
+ }
+ }
+
for (auto&& workloadFactory : m_WorkloadFactories)
{
workloadFactory.second->AfterWorkloadsCreated();
@@ -699,77 +762,133 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
m_InputQueue.clear();
m_InputQueue.reserve(graph.GetNumInputs());
+ if (preImportedInputIds.size() > graph.GetNumInputs())
+ {
+ throw InvalidArgumentException("Invalid number of preImportedInputIds");
+ }
+
+ unsigned int inputIndex = 0;
+ unsigned int importedInputIdIndex = 0;
+ std::sort(preImportedInputIds.begin(), preImportedInputIds.end());
for (const BindableLayer* inputLayer : graph.GetInputLayers())
{
- if (preImportedInputIds.size() > graph.GetNumInputs())
+ if (importedInputIdIndex < preImportedInputIds.size() &&
+ inputIndex == preImportedInputIds[importedInputIdIndex])
{
- throw InvalidArgumentException("Invalid number of preImportedInputIds");
+ // Only replace tensorhandles if they have not already been replaced
+ if (!m_IsInputImported[inputIndex])
+ {
+ auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
+
+ for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
+ {
+ auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
+ workload->ReplaceInputTensorHandle(outputTensorHandle, workloadInfo.m_SlotIndex);
+ }
+ m_IsInputImported[inputIndex] = true;
+ }
+ importedInputIdIndex++;
}
- auto layerBindingId = inputLayer->GetBindingId();
- auto it = std::find_if(preImportedInputIds.begin(), preImportedInputIds.end(),
- [=](auto preImportedInputId)
+ else
{
- return m_PreImportedInputHandles[preImportedInputId].m_LayerBindingId == layerBindingId;
- });
+ if (m_IsInputImported[inputIndex])
+ {
+ OutputHandler& handler = const_cast<OutputHandler&>(inputLayer->GetOutputHandler(0));
+
+ for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()])
+ {
+ auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
+ workload->ReplaceInputTensorHandle(handler.GetData(), workloadInfo.m_SlotIndex);
+ }
+
+ m_IsInputImported[inputIndex] = false;
+ }
- if (it == preImportedInputIds.end())
- {
// InputTensorHandle is not imported yet, process to enqueue input
const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId());
EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
}
+ inputIndex++;
}
}
-
// For each output to the network, call EnqueueOutput with the data passed by the user.
{
ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs");
m_OutputQueue.clear();
m_OutputQueue.reserve(graph.GetNumOutputs());
+ if (preImportedOutputIds.size() > graph.GetNumOutputs())
+ {
+ throw InvalidArgumentException("Invalid number of preImportedOutputIds");
+ }
+
+ unsigned int outputIndex = 0;
+ unsigned int importedOutputIdIndex = 0;
+ std::sort(preImportedOutputIds.begin(), preImportedOutputIds.end());
for (const BindableLayer* outputLayer : graph.GetOutputLayers())
{
- if (preImportedOutputIds.size() > graph.GetNumOutputs())
- {
- throw InvalidArgumentException("Invalid number of preImportedOutputIds");
- }
- auto layerBindingId = outputLayer->GetBindingId();
- auto it = std::find_if(preImportedOutputIds.begin(), preImportedOutputIds.end(),
- [=](auto preImportedOutputId)
+ if (importedOutputIdIndex < preImportedOutputIds.size() &&
+ outputIndex == preImportedOutputIds[importedOutputIdIndex])
{
- return m_PreImportedOutputHandles[preImportedOutputId].m_LayerBindingId == layerBindingId;
- });
+ // Only replace tensorhandles if they have not already been replaced
+ ITensorHandle* inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
- const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId());
+ if (!m_IsOutputImported[outputIndex])
+ {
+ const auto bindingId = outputLayer->GetBindingId();
+ const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
- if (it == preImportedOutputIds.end())
- {
- // OutputTensorHandle is not imported yet, process to enqueue Output
- EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
- }
- else
- {
- // Insert synchronization workload for the imported output
- OutputQueueDescriptor outputQueueDescriptor;
- WorkloadInfo info;
+ auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
- outputQueueDescriptor.m_Outputs.push_back(pin.GetTensorHandle());
- info.m_OutputTensorInfos.push_back(pin.GetTensorInfo());
+ outputWorkload->ReplaceOutputTensorHandle(inputTensorHandle,
+ indices.m_OutputSlotIndices.m_SlotIndex);
- // Gets the output handler from the previous node.
- const OutputHandler& outputHandler =
- outputLayer->GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler();
+ for (const auto& workloadInfo: indices.m_InputSlotIndices)
+ {
+ auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
+ inputWorkload->ReplaceInputTensorHandle(inputTensorHandle, workloadInfo.m_SlotIndex);
+ }
+ m_IsOutputImported[outputIndex] = true;
+ }
- const TensorInfo& inputTensorInfo = outputHandler.GetTensorInfo();
- ITensorHandle* inputTensorHandle = outputHandler.GetData();
ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated.");
MemSyncQueueDescriptor syncDesc;
syncDesc.m_Inputs.push_back(inputTensorHandle);
- info.m_InputTensorInfos.push_back(inputTensorInfo);
+ WorkloadInfo info;
+ info.m_InputTensorInfos.push_back(
+ outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo());
auto syncWorkload = std::make_unique<SyncMemGenericWorkload>(syncDesc, info);
ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created");
m_OutputQueue.push_back(move(syncWorkload));
+ importedOutputIdIndex++;
}
+ else
+ {
+ if (m_IsOutputImported[outputIndex])
+ {
+ const auto bindingId = outputLayer->GetBindingId();
+ const auto& indices = m_OutputWorkloadSlotPairs[bindingId];
+
+ auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get();
+ const OutputHandler& outputHandler =
+ outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOutputHandler();
+
+ outputWorkload->ReplaceOutputTensorHandle(
+ outputHandler.GetData(), indices.m_OutputSlotIndices.m_SlotIndex);
+
+ for (const auto& workloadInfo: indices.m_InputSlotIndices)
+ {
+ auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get();
+ inputWorkload->ReplaceInputTensorHandle(outputHandler.GetData(), workloadInfo.m_SlotIndex);
+ }
+ m_IsOutputImported[outputIndex] = false;
+ }
+
+ const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId());
+ // OutputTensorHandle is not imported yet, process to enqueue Output
+ EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo());
+ }
+ outputIndex++;
}
}
@@ -806,6 +925,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors,
timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS);
timelineUtils->Commit();
}
+
return executionSucceeded ? Status::Success : Status::Failure;
}
@@ -1186,14 +1306,13 @@ const armnn::Tensor GetOutputTensor(const LayerBindingId layerId, const OutputTe
std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inputTensors,
MemorySource forceImportMemorySource)
{
- if (!m_NetworkProperties.m_ImportEnabled)
+ if (!m_NetworkProperties.m_AsyncEnabled)
{
// Cannot import if import is not enabled and forceImportMemorySource is undefined
if (forceImportMemorySource == MemorySource::Undefined)
{
throw MemoryImportException("ImportInputs: Memory Import failed, NetworkProperties.m_ImportEnabled");
}
- // If forceImportMemorySource is defined, try import if memory is aligned
if (inputTensors.size() != m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumInputs())
{
throw MemoryImportException("ImportInputs: Force Import failed, incorrect number of tensors");
@@ -1201,85 +1320,42 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
std::vector<ImportedInputId> importedInputs;
Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
- for (auto inputTensor : inputTensors)
+ unsigned int inputIndex = 0;
+ for (const BindableLayer* inputLayer : graph.GetInputLayers())
{
- auto layerBindingId = inputTensor.first;
- auto it = std::find_if(graph.GetInputLayers().begin(), graph.GetInputLayers().end(), [=](auto* layer)
- {
- return layer->GetBindingId() == layerBindingId;
- });
+ auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get();
- if (it == graph.GetInputLayers().end())
+ if (!outputTensorHandle)
{
- throw MemoryImportException(fmt::format(
- "ImportInputs: Memory Import failed, unknown LayerBindingId: {}", layerBindingId));
+ inputIndex++;
+ continue;
}
- const Layer* layer = *it;
- if (layer->GetType() != LayerType::Input)
+ auto layerBindingId = inputLayer->GetBindingId();
+ auto it = std::find_if(inputTensors.begin(), inputTensors.end(), [=](const auto& inputTensor)
{
- throw InvalidArgumentException("ImportInputs: given layer not an InputLayer");
- }
- const OutputSlot& outputSlot = layer->GetOutputSlots()[0];
- ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId();
- // Get matching import factory Id
- ITensorHandleFactory::FactoryId importFactoryId =
- m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
- ITensorHandleFactory* importFactory =
- m_TensorHandleFactoryRegistry.GetFactory(importFactoryId, forceImportMemorySource);
- if (!importFactory)
+ return inputTensor.first == layerBindingId;
+ });
+
+ if (it == inputTensors.end())
{
- throw MemoryImportException("ImportInputs: Force Import failed, cannot find matching Import Factory");
+ inputIndex++;
+ continue;
}
- OutputHandler& handler = const_cast<OutputHandler&>(layer->GetOutputHandler(0));
- handler.SetAllocatedData();
- handler.CreateTensorHandles(*importFactory, false);
- ITensorHandle* outputTensorHandle = handler.GetData();
+ const auto& inputTensor = *it;
std::unique_ptr<ITensorHandle> passThroughTensorHandle =
std::make_unique<ConstPassthroughTensorHandle>(inputTensor.second.GetInfo(),
inputTensor.second.GetMemoryArea());
- // Check if the input memory can be imported
- if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource))
- {
- passThroughTensorHandle->Unmap();
- if (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))
- {
- passThroughTensorHandle->Unmap();
- try
- {
- m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle(
- outputTensorHandle, m_InputWorkloadSlotPairs[layerBindingId].second);
- importedInputs.push_back(m_CurImportedInputId++);
- // For force import, we want OutputHandler to own the TensorHandle,
- // so we do not move the TensorHandle to m_PreImportedInputHandles as in AsyncEnabled networks
- ImportedTensorHandlePin importedTensorHandlePin{layerBindingId, nullptr};
- m_PreImportedInputHandles.push_back(std::move(importedTensorHandlePin));
- }
- catch(armnn::UnimplementedException& e)
- {
- IgnoreUnused(e);
- // Method not implement, cannot use import tensor and have to use allocated data instead
- handler.UseAllocatedData();
- }
- }
- }
- else
+
+ if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource)
+ && (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource)))
{
- // Cannot import, use allocated data
- handler.UseAllocatedData();
- // Ensure that the workload get correct tensor
- try
- {
- m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle(
- handler.GetData(), m_InputWorkloadSlotPairs[layerBindingId].second);
- }
- catch(armnn::UnimplementedException& e)
- {
- IgnoreUnused(e);
- }
+ importedInputs.push_back(inputIndex);
}
+ passThroughTensorHandle->Unmap();
+ inputIndex++;
}
return importedInputs;
@@ -1363,7 +1439,7 @@ std::vector<ImportedInputId> LoadedNetwork::ImportInputs(const InputTensors& inp
std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors& outputTensors,
MemorySource forceImportMemorySource)
{
- if (!m_NetworkProperties.m_ExportEnabled)
+ if (!m_NetworkProperties.m_AsyncEnabled)
{
// Cannot import if import is not enabled and forceImportMemorySource is undefined
if (forceImportMemorySource == MemorySource::Undefined)
@@ -1377,85 +1453,38 @@ std::vector<ImportedOutputId> LoadedNetwork::ImportOutputs(const OutputTensors&
}
std::vector<ImportedInputId> importedOutputs;
Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort();
- for (auto outputTensor : outputTensors)
+
+ unsigned int outputIndex = 0;
+ for (const BindableLayer* const outputLayer : graph.GetOutputLayers())
{
- auto layerBindingId = outputTensor.first;
- auto it = std::find_if(graph.GetOutputLayers().begin(), graph.GetOutputLayers().end(), [=](auto* layer)
- {
- return layer->GetBindingId() == layerBindingId;
- });
+ auto inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get();
- if (it == graph.GetOutputLayers().end())
+ if (!inputTensorHandle)
{
- throw MemoryImportException(fmt::format("ImportOutputs: Memory Import failed, "
- "unknown LayerBindingId: {}",
- layerBindingId));
+ outputIndex++;
+ continue;
}
- const Layer* layer = *it;
- if (layer->GetType() != LayerType::Output)
+ auto layerBindingId = outputLayer->GetBindingId();
+ auto it = std::find_if(outputTensors.begin(), outputTensors.end(), [=] (const auto& outputTensor)
{
- throw InvalidArgumentException("ImportOutputs: given layer not an OutputLayer");
- }
+ return outputTensor.first == layerBindingId;
+ });
- const OutputSlot* outputSlot = layer->GetInputSlots()[0].GetConnectedOutputSlot();
- ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId();
- ITensorHandleFactory::FactoryId importFactoryId =
- m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId);
- ITensorHandleFactory* importFactory =
- m_TensorHandleFactoryRegistry.GetFactory(importFactoryId, forceImportMemorySource);
- if (!importFactory)
+ if (it == outputTensors.end())
{
- throw MemoryImportException("ImportOutputs: Force Import failed, cannot find matching Import Factory");
+ outputIndex++;
+ continue;
}
- OutputHandler& outputHandler =
- const_cast<OutputHandler&>(layer->GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler());
- outputHandler.SetAllocatedData();
- ITensorHandle* inputTensorHandle = outputHandler.GetData();
- outputHandler.CreateTensorHandles(*importFactory, false);
- inputTensorHandle = outputHandler.GetData();
-
+ const auto outputTensor = *it;
// Check if the output memory can be imported
- if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
- {
- if (inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
- {
- try
- {
- m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()->
- ReplaceOutputTensorHandle(inputTensorHandle,
- m_OutputWorkloadSlotPairs[layerBindingId].second);
- importedOutputs.push_back(m_CurImportedOutputId++);
- // For force import, we want OutputHandler to own the TensorHandle,
- // so we do not move the TensorHandle to m_PreImportedOutputHandles as in AsyncEnabled networks
- ImportedTensorHandlePin importedTensorHandlePin{layerBindingId, nullptr};
- m_PreImportedOutputHandles.push_back(std::move(importedTensorHandlePin));
- }
- catch(armnn::UnimplementedException& e)
- {
- IgnoreUnused(e);
- // Method not implement, cannot use import tensor and have to use allocated data instead
- outputHandler.UseAllocatedData();
- }
- }
- }
- else
+ if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource)
+ && inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource))
{
- // Cannot import, use allocated memory
- outputHandler.UseAllocatedData();
- // Ensure that the workload get correct tensor
- try
- {
- m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()->
- ReplaceOutputTensorHandle(outputHandler.GetData(),
- m_OutputWorkloadSlotPairs[layerBindingId].second);
- }
- catch(armnn::UnimplementedException& e)
- {
- IgnoreUnused(e);
- }
+ importedOutputs.push_back(outputIndex);
}
+ outputIndex++;
}
return importedOutputs;
}
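Taken together, the force-import path now works against handles created at load time: ImportInputs/ImportOutputs return the indices of the tensors whose memory could actually be imported, and those ids are passed back into EnqueueWorkload so it swaps handles for exactly those inputs/outputs and copies the rest. A rough usage sketch against LoadedNetwork, assuming inputTensors/outputTensors are already populated and that the backend in use can import Malloc memory (callers typically reach this via the IRuntime wrapper):

// Sketch only: which tensors end up imported depends on alignment and backend support.
std::vector<armnn::ImportedInputId> importedInputs =
    loadedNetwork.ImportInputs(inputTensors, armnn::MemorySource::Malloc);
std::vector<armnn::ImportedOutputId> importedOutputs =
    loadedNetwork.ImportOutputs(outputTensors, armnn::MemorySource::Malloc);

// Ids returned above select the pre-imported handles; everything else
// still goes through the EnqueueInput/EnqueueOutput copy path.
armnn::Status status = loadedNetwork.EnqueueWorkload(inputTensors, outputTensors,
                                                     importedInputs, importedOutputs);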
diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp
index f637dec8eb..dc2f4dc10f 100644
--- a/src/armnn/LoadedNetwork.hpp
+++ b/src/armnn/LoadedNetwork.hpp
@@ -204,8 +204,21 @@ private:
// A set of vectors to record the workload queue indexes and their corresponding Input/Output Slot indexes
// which are connected to Inputs and Outputs for the network.
- std::unordered_map<LayerBindingId, std::pair<unsigned int, unsigned int>> m_InputWorkloadSlotPairs;
- std::unordered_map<LayerBindingId, std::pair<unsigned int, unsigned int>> m_OutputWorkloadSlotPairs;
+ struct WorkloadIndices
+ {
+ unsigned int m_WorkloadIndex;
+ unsigned int m_SlotIndex;
+ };
+
+ struct OutputWorkloadIndices
+ {
+ WorkloadIndices m_OutputSlotIndices;
+ std::vector<WorkloadIndices> m_InputSlotIndices;
+ };
+ std::unordered_map<LayerBindingId, std::vector<WorkloadIndices>> m_InputWorkloadSlotPairs;
+ std::unordered_map<LayerBindingId, OutputWorkloadIndices> m_OutputWorkloadSlotPairs;
+ std::vector<bool> m_IsInputImported;
+ std::vector<bool> m_IsOutputImported;
};
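For context, a sketch of how one OutputWorkloadIndices entry is consumed when an output handle is swapped, mirroring the EnqueueWorkload changes above (bindingId and importedHandle are illustrative locals):

// The entry records the producing workload plus every other consumer of the tensor.
const OutputWorkloadIndices& indices = m_OutputWorkloadSlotPairs[bindingId];

// Replace the handle on the workload that produces this output...
m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex]
    ->ReplaceOutputTensorHandle(importedHandle, indices.m_OutputSlotIndices.m_SlotIndex);

// ...and on every other workload that reads the same tensor.
for (const WorkloadIndices& inputIndices : indices.m_InputSlotIndices)
{
    m_WorkloadQueue[inputIndices.m_WorkloadIndex]
        ->ReplaceInputTensorHandle(importedHandle, inputIndices.m_SlotIndex);
}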
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 385affa5fa..fc48ffce28 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -583,7 +583,6 @@ void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
{
ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1);
- ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1);
if (m_Inputs.size() != 1)
{
diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
index 891d5096cd..e4177e4327 100644
--- a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
+++ b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp
@@ -40,6 +40,8 @@ public:
arm_compute::ConvolutionMethod GetConvolutionMethod() const;
+ bool SupportsTensorHandleReplacement() const override { return true;};
+
protected:
void Reconfigure() override;
diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt
index 60d8255454..46c2706742 100644
--- a/src/backends/reference/workloads/CMakeLists.txt
+++ b/src/backends/reference/workloads/CMakeLists.txt
@@ -68,6 +68,7 @@ list(APPEND armnnRefBackendWorkloads_sources
RefActivationWorkload.hpp
RefArgMinMaxWorkload.cpp
RefArgMinMaxWorkload.hpp
+ RefBaseWorkload.hpp
RefBatchNormalizationWorkload.cpp
RefBatchNormalizationWorkload.hpp
RefBatchToSpaceNdWorkload.cpp
diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp
index 9814ac172b..8dc2d52d9b 100644
--- a/src/backends/reference/workloads/RefActivationWorkload.hpp
+++ b/src/backends/reference/workloads/RefActivationWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefActivationWorkload : public BaseWorkload<ActivationQueueDescriptor>
+class RefActivationWorkload : public RefBaseWorkload<ActivationQueueDescriptor>
{
public:
- using BaseWorkload<ActivationQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ActivationQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
index 2d635bf6c2..d724273287 100644
--- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
+++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp
@@ -16,7 +16,7 @@ namespace armnn
RefArgMinMaxWorkload::RefArgMinMaxWorkload(
const ArgMinMaxQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload<ArgMinMaxQueueDescriptor>(descriptor, info) {}
+ : RefBaseWorkload<ArgMinMaxQueueDescriptor>(descriptor, info) {}
void RefArgMinMaxWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp
index f3c264469b..97c4b45d60 100644
--- a/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp
+++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp
@@ -5,12 +5,12 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefArgMinMaxWorkload : public BaseWorkload<ArgMinMaxQueueDescriptor>
+class RefArgMinMaxWorkload : public RefBaseWorkload<ArgMinMaxQueueDescriptor>
{
public:
explicit RefArgMinMaxWorkload(const ArgMinMaxQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefBaseWorkload.hpp b/src/backends/reference/workloads/RefBaseWorkload.hpp
new file mode 100644
index 0000000000..824b4ccc67
--- /dev/null
+++ b/src/backends/reference/workloads/RefBaseWorkload.hpp
@@ -0,0 +1,36 @@
+//
+// Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#pragma once
+
+#include <armnn/backends/Workload.hpp>
+
+namespace armnn
+{
+ template <typename QueueDescriptor>
+ class RefBaseWorkload : public BaseWorkload<QueueDescriptor>
+ {
+ public:
+ RefBaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info)
+ : BaseWorkload<QueueDescriptor>(descriptor, info)
+ {}
+
+ virtual bool SupportsTensorHandleReplacement() const override
+ {
+ return true;
+ }
+ // Replace input tensor handle with the given TensorHandle
+ void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
+ {
+ this->m_Data.m_Inputs[slot] = tensorHandle;
+ }
+
+ // Replace output tensor handle with the given TensorHandle
+ void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override
+ {
+ this->m_Data.m_Outputs[slot] = tensorHandle;
+ }
+ };
+} //namespace armnn \ No newline at end of file
diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp
index 282374d89b..a6bd986f1d 100644
--- a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp
+++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
RefBatchNormalizationWorkload::RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
, m_Mean (std::make_unique<ScopedTensorHandle>(*(descriptor.m_Mean)))
, m_Variance(std::make_unique<ScopedTensorHandle>(*(descriptor.m_Variance)))
, m_Beta (std::make_unique<ScopedTensorHandle>(*(descriptor.m_Beta)))
diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp
index 305c0ce573..60dd2a927c 100644
--- a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp
+++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefBatchNormalizationWorkload : public BaseWorkload<BatchNormalizationQueueDescriptor>
+class RefBatchNormalizationWorkload : public RefBaseWorkload<BatchNormalizationQueueDescriptor>
{
public:
explicit RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp b/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp
index 7d18c12476..d7ee6fc81c 100644
--- a/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp
+++ b/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn {
-class RefBatchToSpaceNdWorkload : public BaseWorkload<BatchToSpaceNdQueueDescriptor>
+class RefBatchToSpaceNdWorkload : public RefBaseWorkload<BatchToSpaceNdQueueDescriptor>
{
public:
- using BaseWorkload<BatchToSpaceNdQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<BatchToSpaceNdQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp
index ccafaafac9..6f7e56a6b6 100644
--- a/src/backends/reference/workloads/RefCastWorkload.hpp
+++ b/src/backends/reference/workloads/RefCastWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "RefWorkloadUtils.hpp"
@@ -13,10 +13,10 @@ namespace armnn
{
-class RefCastWorkload : public BaseWorkload<CastQueueDescriptor>
+class RefCastWorkload : public RefBaseWorkload<CastQueueDescriptor>
{
public:
- using BaseWorkload<CastQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<CastQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp b/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp
index 0c8037823a..b459b87592 100644
--- a/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp
+++ b/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefChannelShuffleWorkload : public BaseWorkload<ChannelShuffleQueueDescriptor>
+class RefChannelShuffleWorkload : public RefBaseWorkload<ChannelShuffleQueueDescriptor>
{
public:
- using BaseWorkload<ChannelShuffleQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ChannelShuffleQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp
index 03df7a4c4a..433e3e8ad8 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.cpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp
@@ -21,7 +21,7 @@ namespace armnn
RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc,
const WorkloadInfo& info)
- : BaseWorkload<ComparisonQueueDescriptor>(desc, info)
+ : RefBaseWorkload<ComparisonQueueDescriptor>(desc, info)
{}
void RefComparisonWorkload::PostAllocationConfigure()
diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp
index f2780c7ae5..93cfd1f2b1 100644
--- a/src/backends/reference/workloads/RefComparisonWorkload.hpp
+++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp
@@ -7,16 +7,16 @@
#include "BaseIterator.hpp"
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefComparisonWorkload : public BaseWorkload<ComparisonQueueDescriptor>
+class RefComparisonWorkload : public RefBaseWorkload<ComparisonQueueDescriptor>
{
public:
- using BaseWorkload<ComparisonQueueDescriptor>::m_Data;
+ using RefBaseWorkload<ComparisonQueueDescriptor>::m_Data;
RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info);
void PostAllocationConfigure() override;
diff --git a/src/backends/reference/workloads/RefConcatWorkload.hpp b/src/backends/reference/workloads/RefConcatWorkload.hpp
index cb1ecf06a7..11d6d016ed 100644
--- a/src/backends/reference/workloads/RefConcatWorkload.hpp
+++ b/src/backends/reference/workloads/RefConcatWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefConcatWorkload : public BaseWorkload<ConcatQueueDescriptor>
+class RefConcatWorkload : public RefBaseWorkload<ConcatQueueDescriptor>
{
public:
- using BaseWorkload<ConcatQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ConcatQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp
index 6290237d69..571dbb219a 100644
--- a/src/backends/reference/workloads/RefConstantWorkload.cpp
+++ b/src/backends/reference/workloads/RefConstantWorkload.cpp
@@ -18,7 +18,7 @@ namespace armnn
RefConstantWorkload::RefConstantWorkload(
const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<ConstantQueueDescriptor>(descriptor, info) {}
+ : RefBaseWorkload<ConstantQueueDescriptor>(descriptor, info) {}
void RefConstantWorkload::Execute() const
{
diff --git a/src/backends/reference/workloads/RefConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp
index c158983d7a..181d79d320 100644
--- a/src/backends/reference/workloads/RefConstantWorkload.hpp
+++ b/src/backends/reference/workloads/RefConstantWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/Types.hpp>
@@ -14,7 +14,7 @@ namespace armnn
{
// Base class template providing an implementation of the Constant layer common to all data types.
-class RefConstantWorkload : public BaseWorkload<ConstantQueueDescriptor>
+class RefConstantWorkload : public RefBaseWorkload<ConstantQueueDescriptor>
{
public:
RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp b/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp
index b3af111fa3..8b5c6d56c2 100644
--- a/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp
+++ b/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
diff --git a/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp b/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp
index acb1995b9f..feb442ef5a 100644
--- a/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp
+++ b/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
diff --git a/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp b/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp
index 97a138f49c..cd3cfa4cf3 100644
--- a/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp
+++ b/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
diff --git a/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp b/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp
index 8cc822e7d8..fe137ed62f 100644
--- a/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp
+++ b/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
diff --git a/src/backends/reference/workloads/RefConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefConvolution2dWorkload.cpp
index 20c5c08b17..d57040eaec 100644
--- a/src/backends/reference/workloads/RefConvolution2dWorkload.cpp
+++ b/src/backends/reference/workloads/RefConvolution2dWorkload.cpp
@@ -14,7 +14,7 @@ namespace armnn
{
RefConvolution2dWorkload::RefConvolution2dWorkload(
const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<Convolution2dQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<Convolution2dQueueDescriptor>(descriptor, info)
{
WorkloadInfo detailsInfo;
detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
diff --git a/src/backends/reference/workloads/RefConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefConvolution2dWorkload.hpp
index 880547dc33..3335782f78 100644
--- a/src/backends/reference/workloads/RefConvolution2dWorkload.hpp
+++ b/src/backends/reference/workloads/RefConvolution2dWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
#include "Encoders.hpp"
@@ -13,7 +13,7 @@
namespace armnn
{
-class RefConvolution2dWorkload : public BaseWorkload<Convolution2dQueueDescriptor>
+class RefConvolution2dWorkload : public RefBaseWorkload<Convolution2dQueueDescriptor>
{
public:
explicit RefConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp
index afab88f0a8..5f542807ed 100644
--- a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp
+++ b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp
@@ -14,7 +14,7 @@ namespace armnn
{
RefConvolution3dWorkload::RefConvolution3dWorkload(
const Convolution3dQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<Convolution3dQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<Convolution3dQueueDescriptor>(descriptor, info)
{
WorkloadInfo detailsInfo;
detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos;
diff --git a/src/backends/reference/workloads/RefConvolution3dWorkload.hpp b/src/backends/reference/workloads/RefConvolution3dWorkload.hpp
index 53ce309eb8..6c74675eec 100644
--- a/src/backends/reference/workloads/RefConvolution3dWorkload.hpp
+++ b/src/backends/reference/workloads/RefConvolution3dWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
#include "Encoders.hpp"
@@ -13,7 +13,7 @@
namespace armnn
{
-class RefConvolution3dWorkload : public BaseWorkload<Convolution3dQueueDescriptor>
+class RefConvolution3dWorkload : public RefBaseWorkload<Convolution3dQueueDescriptor>
{
public:
explicit RefConvolution3dWorkload(const Convolution3dQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp
index 66af9a0b0f..a1579599f4 100644
--- a/src/backends/reference/workloads/RefDebugWorkload.hpp
+++ b/src/backends/reference/workloads/RefDebugWorkload.hpp
@@ -7,7 +7,7 @@
#include <armnn/TypesUtils.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
namespace armnn
{
diff --git a/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp b/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp
index 854a564062..bd179d3b9c 100644
--- a/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp
+++ b/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp
@@ -5,15 +5,15 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
namespace armnn
{
-class RefDepthToSpaceWorkload : public BaseWorkload<DepthToSpaceQueueDescriptor>
+class RefDepthToSpaceWorkload : public RefBaseWorkload<DepthToSpaceQueueDescriptor>
{
public:
- using BaseWorkload<DepthToSpaceQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<DepthToSpaceQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp
index b447d1a441..ad5edde7e6 100644
--- a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp
+++ b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp
@@ -17,7 +17,7 @@ namespace armnn
RefDepthwiseConvolution2dWorkload::RefDepthwiseConvolution2dWorkload(
const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<DepthwiseConvolution2dQueueDescriptor>(descriptor, info)
{
m_Weight = std::make_unique<ScopedTensorHandle>(*(descriptor.m_Weight));
const TensorInfo& rFilterInfo = m_Weight->GetTensorInfo();
diff --git a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp
index ae93d03656..5d4b483fa7 100644
--- a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp
+++ b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp
@@ -2,7 +2,7 @@
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
#include "Encoders.hpp"
@@ -12,7 +12,7 @@
namespace armnn
{
-class RefDepthwiseConvolution2dWorkload : public BaseWorkload<DepthwiseConvolution2dQueueDescriptor> {
+class RefDepthwiseConvolution2dWorkload : public RefBaseWorkload<DepthwiseConvolution2dQueueDescriptor> {
public:
explicit RefDepthwiseConvolution2dWorkload(const DepthwiseConvolution2dQueueDescriptor &descriptor,
const WorkloadInfo &info);
diff --git a/src/backends/reference/workloads/RefDequantizeWorkload.hpp b/src/backends/reference/workloads/RefDequantizeWorkload.hpp
index 285c6496bb..8fa8951677 100644
--- a/src/backends/reference/workloads/RefDequantizeWorkload.hpp
+++ b/src/backends/reference/workloads/RefDequantizeWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
namespace armnn
{
-class RefDequantizeWorkload : public BaseWorkload<DequantizeQueueDescriptor>
+class RefDequantizeWorkload : public RefBaseWorkload<DequantizeQueueDescriptor>
{
public:
- using BaseWorkload<DequantizeQueueDescriptor>::m_Data;
- using BaseWorkload<DequantizeQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<DequantizeQueueDescriptor>::m_Data;
+ using RefBaseWorkload<DequantizeQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp
index 4bc9eb1704..5f01db3280 100644
--- a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp
+++ b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
RefDetectionPostProcessWorkload::RefDetectionPostProcessWorkload(
const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<DetectionPostProcessQueueDescriptor>(descriptor, info),
+ : RefBaseWorkload<DetectionPostProcessQueueDescriptor>(descriptor, info),
m_Anchors(std::make_unique<ScopedTensorHandle>(*(descriptor.m_Anchors))) {}
void RefDetectionPostProcessWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp
index 4c3ad42b0f..53b2971063 100644
--- a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp
+++ b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefDetectionPostProcessWorkload : public BaseWorkload<DetectionPostProcessQueueDescriptor>
+class RefDetectionPostProcessWorkload : public RefBaseWorkload<DetectionPostProcessQueueDescriptor>
{
public:
explicit RefDetectionPostProcessWorkload(const DetectionPostProcessQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
index be153636f9..3ea51b9f69 100644
--- a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp
@@ -27,7 +27,7 @@ namespace armnn
RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
const WorkloadInfo& info)
- : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
+ : RefBaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
{}
void RefElementwiseUnaryWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
index e055fd012c..91229b3c58 100644
--- a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp
@@ -7,16 +7,16 @@
#include "BaseIterator.hpp"
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefElementwiseUnaryWorkload : public BaseWorkload<ElementwiseUnaryQueueDescriptor>
+class RefElementwiseUnaryWorkload : public RefBaseWorkload<ElementwiseUnaryQueueDescriptor>
{
public:
- using BaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
+ using RefBaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info);
void Execute() const override;
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
index dd7d325ca5..d14ce075b0 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp
@@ -21,7 +21,7 @@ template <typename Functor, typename ParentDescriptor, typename armnn::StringMap
RefElementwiseWorkload<Functor, ParentDescriptor, DebugString>::RefElementwiseWorkload(
const ParentDescriptor& desc,
const WorkloadInfo& info)
- : BaseWorkload<ParentDescriptor>(desc, info)
+ : RefBaseWorkload<ParentDescriptor>(desc, info)
{
}
diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
index 4b108e4363..065a7833d7 100644
--- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp
+++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp
@@ -6,7 +6,7 @@
#pragma once
#include <armnn/Types.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "BaseIterator.hpp"
#include "ElementwiseFunction.hpp"
@@ -18,12 +18,12 @@ namespace armnn
{
template <typename Functor, typename ParentDescriptor, typename armnn::StringMapping::Id DebugString>
-class RefElementwiseWorkload : public BaseWorkload<ParentDescriptor>
+class RefElementwiseWorkload : public RefBaseWorkload<ParentDescriptor>
{
public:
using InType = typename ElementwiseBinaryFunction<Functor>::InType;
using OutType = typename ElementwiseBinaryFunction<Functor>::OutType;
- using BaseWorkload<ParentDescriptor>::m_Data;
+ using RefBaseWorkload<ParentDescriptor>::m_Data;
RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info);
void Execute() const override;
diff --git a/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp b/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp
index 53b3375a50..85dc6af326 100644
--- a/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp
+++ b/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
diff --git a/src/backends/reference/workloads/RefFillWorkload.hpp b/src/backends/reference/workloads/RefFillWorkload.hpp
index 56d44b85f7..d1e00581cd 100644
--- a/src/backends/reference/workloads/RefFillWorkload.hpp
+++ b/src/backends/reference/workloads/RefFillWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefFillWorkload : public BaseWorkload<FillQueueDescriptor>
+class RefFillWorkload : public RefBaseWorkload<FillQueueDescriptor>
{
public:
- using BaseWorkload<FillQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<FillQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefFloorWorkload.hpp b/src/backends/reference/workloads/RefFloorWorkload.hpp
index 1a532f7a49..6237ff0c61 100644
--- a/src/backends/reference/workloads/RefFloorWorkload.hpp
+++ b/src/backends/reference/workloads/RefFloorWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefFloorWorkload : public BaseWorkload<FloorQueueDescriptor>
+class RefFloorWorkload : public RefBaseWorkload<FloorQueueDescriptor>
{
public:
- using BaseWorkload<FloorQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<FloorQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
index 5a7951ec48..c6ea147043 100644
--- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp
@@ -14,7 +14,7 @@ namespace armnn
{
RefFullyConnectedWorkload::RefFullyConnectedWorkload(
const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info)
- : BaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<FullyConnectedQueueDescriptor>(descriptor, info)
{
}
diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp
index 3ee4a4a83c..432a8879a0 100644
--- a/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp
+++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "BaseIterator.hpp"
#include "Decoders.hpp"
@@ -15,7 +15,7 @@
namespace armnn
{
-class RefFullyConnectedWorkload : public BaseWorkload<FullyConnectedQueueDescriptor>
+class RefFullyConnectedWorkload : public RefBaseWorkload<FullyConnectedQueueDescriptor>
{
public:
explicit RefFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefGatherWorkload.hpp b/src/backends/reference/workloads/RefGatherWorkload.hpp
index a2698e3a25..ec880a5109 100644
--- a/src/backends/reference/workloads/RefGatherWorkload.hpp
+++ b/src/backends/reference/workloads/RefGatherWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/TypesUtils.hpp>
@@ -16,10 +16,10 @@
namespace armnn
{
-class RefGatherWorkload : public BaseWorkload<GatherQueueDescriptor>
+class RefGatherWorkload : public RefBaseWorkload<GatherQueueDescriptor>
{
public:
- using BaseWorkload<GatherQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<GatherQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp
index e642dc9b9a..c103a6b9d3 100644
--- a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp
+++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp
@@ -16,7 +16,7 @@ namespace armnn
RefInstanceNormalizationWorkload::RefInstanceNormalizationWorkload(
const InstanceNormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info) {}
+ : RefBaseWorkload<InstanceNormalizationQueueDescriptor>(descriptor, info) {}
void RefInstanceNormalizationWorkload::Execute() const
{
diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp
index 3283c444d2..a4b2dd39cb 100644
--- a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp
+++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefInstanceNormalizationWorkload : public BaseWorkload<InstanceNormalizationQueueDescriptor>
+class RefInstanceNormalizationWorkload : public RefBaseWorkload<InstanceNormalizationQueueDescriptor>
{
public:
explicit RefInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
index ca31503620..f6fcff3cc5 100644
--- a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp
@@ -22,7 +22,7 @@ namespace armnn
RefL2NormalizationWorkload::RefL2NormalizationWorkload(
const L2NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload<L2NormalizationQueueDescriptor>(descriptor, info) {}
+ : RefBaseWorkload<L2NormalizationQueueDescriptor>(descriptor, info) {}
void RefL2NormalizationWorkload::Execute() const
{
diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
index dd129c663e..c64e2ea0fd 100644
--- a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
+++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefL2NormalizationWorkload : public BaseWorkload<L2NormalizationQueueDescriptor>
+class RefL2NormalizationWorkload : public RefBaseWorkload<L2NormalizationQueueDescriptor>
{
public:
explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp
index 9f87def1bd..91ad5f6c36 100644
--- a/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp
+++ b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefLogSoftmaxWorkload : public BaseWorkload<LogSoftmaxQueueDescriptor>
+class RefLogSoftmaxWorkload : public RefBaseWorkload<LogSoftmaxQueueDescriptor>
{
public:
- using BaseWorkload<LogSoftmaxQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<LogSoftmaxQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp
index f187e0ca31..f0cb846acf 100644
--- a/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp
+++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp
@@ -19,7 +19,7 @@ namespace armnn
RefLogicalBinaryWorkload::RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& desc,
const WorkloadInfo& info)
- : BaseWorkload<LogicalBinaryQueueDescriptor>(desc, info)
+ : RefBaseWorkload<LogicalBinaryQueueDescriptor>(desc, info)
{}
void RefLogicalBinaryWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp
index 053de7daf9..797d937d80 100644
--- a/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp
+++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp
@@ -7,16 +7,16 @@
#include "BaseIterator.hpp"
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefLogicalBinaryWorkload : public BaseWorkload<LogicalBinaryQueueDescriptor>
+class RefLogicalBinaryWorkload : public RefBaseWorkload<LogicalBinaryQueueDescriptor>
{
public:
- using BaseWorkload<LogicalBinaryQueueDescriptor>::m_Data;
+ using RefBaseWorkload<LogicalBinaryQueueDescriptor>::m_Data;
RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& descriptor, const WorkloadInfo& info);
void Execute() const override;
diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp
index bef2bdc668..ec0aa0e454 100644
--- a/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp
+++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp
@@ -19,7 +19,7 @@ namespace armnn
RefLogicalUnaryWorkload::RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc,
const WorkloadInfo& info)
- : BaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
+ : RefBaseWorkload<ElementwiseUnaryQueueDescriptor>(desc, info)
{}
void RefLogicalUnaryWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp
index 008d24fef8..ebd5826cc5 100644
--- a/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp
+++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp
@@ -7,16 +7,16 @@
#include "BaseIterator.hpp"
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefLogicalUnaryWorkload : public BaseWorkload<ElementwiseUnaryQueueDescriptor>
+class RefLogicalUnaryWorkload : public RefBaseWorkload<ElementwiseUnaryQueueDescriptor>
{
public:
- using BaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
+ using RefBaseWorkload<ElementwiseUnaryQueueDescriptor>::m_Data;
RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info);
void Execute() const override;
diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp
index 1ff6f50ed5..8609811253 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
{
RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info)
- : BaseWorkload<LstmQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<LstmQueueDescriptor>(descriptor, info)
, m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
, m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
, m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
diff --git a/src/backends/reference/workloads/RefLstmWorkload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp
index 72f6360281..57526c9ba2 100644
--- a/src/backends/reference/workloads/RefLstmWorkload.hpp
+++ b/src/backends/reference/workloads/RefLstmWorkload.hpp
@@ -7,13 +7,13 @@
#include <armnn/TypesUtils.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefLstmWorkload : public BaseWorkload<LstmQueueDescriptor>
+class RefLstmWorkload : public RefBaseWorkload<LstmQueueDescriptor>
{
public:
explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefMeanWorkload.cpp b/src/backends/reference/workloads/RefMeanWorkload.cpp
index 7941ce2c36..23abaf8ff4 100644
--- a/src/backends/reference/workloads/RefMeanWorkload.cpp
+++ b/src/backends/reference/workloads/RefMeanWorkload.cpp
@@ -16,7 +16,7 @@ namespace armnn
{
RefMeanWorkload::RefMeanWorkload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info)
- :BaseWorkload<MeanQueueDescriptor>(descriptor, info) {}
+ :RefBaseWorkload<MeanQueueDescriptor>(descriptor, info) {}
void RefMeanWorkload::Execute() const
{
diff --git a/src/backends/reference/workloads/RefMeanWorkload.hpp b/src/backends/reference/workloads/RefMeanWorkload.hpp
index 2825d669c4..c4c6a1261c 100644
--- a/src/backends/reference/workloads/RefMeanWorkload.hpp
+++ b/src/backends/reference/workloads/RefMeanWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
@@ -14,7 +14,7 @@
namespace armnn
{
-class RefMeanWorkload : public BaseWorkload<MeanQueueDescriptor>
+class RefMeanWorkload : public RefBaseWorkload<MeanQueueDescriptor>
{
public:
explicit RefMeanWorkload (const MeanQueueDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefNormalizationWorkload.cpp b/src/backends/reference/workloads/RefNormalizationWorkload.cpp
index 36828acfb3..613868de57 100644
--- a/src/backends/reference/workloads/RefNormalizationWorkload.cpp
+++ b/src/backends/reference/workloads/RefNormalizationWorkload.cpp
@@ -158,7 +158,7 @@ namespace armnn
RefNormalizationWorkload::RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
{}
void RefNormalizationWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefNormalizationWorkload.hpp b/src/backends/reference/workloads/RefNormalizationWorkload.hpp
index b152072496..5218e1e43a 100644
--- a/src/backends/reference/workloads/RefNormalizationWorkload.hpp
+++ b/src/backends/reference/workloads/RefNormalizationWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefNormalizationWorkload : public BaseWorkload<NormalizationQueueDescriptor>
+class RefNormalizationWorkload : public RefBaseWorkload<NormalizationQueueDescriptor>
{
public:
explicit RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp
index 18c406a4de..c5871059cc 100644
--- a/src/backends/reference/workloads/RefPadWorkload.hpp
+++ b/src/backends/reference/workloads/RefPadWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefPadWorkload : public BaseWorkload<PadQueueDescriptor>
+class RefPadWorkload : public RefBaseWorkload<PadQueueDescriptor>
{
public:
- using BaseWorkload<PadQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<PadQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp
index 9424441c37..d1e44520a1 100644
--- a/src/backends/reference/workloads/RefPermuteWorkload.hpp
+++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/TypesUtils.hpp>
diff --git a/src/backends/reference/workloads/RefPooling2dWorkload.hpp b/src/backends/reference/workloads/RefPooling2dWorkload.hpp
index 125fea8d4e..a073e3921b 100644
--- a/src/backends/reference/workloads/RefPooling2dWorkload.hpp
+++ b/src/backends/reference/workloads/RefPooling2dWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
@@ -13,10 +13,10 @@
namespace armnn
{
-class RefPooling2dWorkload : public BaseWorkload<Pooling2dQueueDescriptor>
+class RefPooling2dWorkload : public RefBaseWorkload<Pooling2dQueueDescriptor>
{
public:
- using BaseWorkload<Pooling2dQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<Pooling2dQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
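Many of the header hunks above and below only touch two lines: the include and a using-declaration such as using RefBaseWorkload<Pooling2dQueueDescriptor>::RefBaseWorkload;. The using-declaration has to be renamed as well because an inheriting-constructor declaration must name the direct base class. A tiny standalone example of the same pattern (not ArmNN code; all names here are made up for illustration):

// Demonstrates constructor inheritance from a class template base.
#include <iostream>

// Stand-in for BaseWorkload/RefBaseWorkload: a base whose constructor the
// derived class wants to reuse verbatim.
template <typename Descriptor>
struct DemoBaseWorkload
{
    DemoBaseWorkload(const Descriptor& descriptor, int info)
        : m_Data(descriptor), m_Info(info) {}
    Descriptor m_Data;
    int m_Info;
};

struct DemoDescriptor { int value = 42; };

struct DemoRefWorkload : public DemoBaseWorkload<DemoDescriptor>
{
    // Must name the direct base; when the base class changes, this line
    // changes with it, as in the hunks in this patch.
    using DemoBaseWorkload<DemoDescriptor>::DemoBaseWorkload;
};

int main()
{
    DemoRefWorkload workload(DemoDescriptor{}, 7);
    std::cout << workload.m_Data.value << " " << workload.m_Info << "\n"; // prints "42 7"
    return 0;
}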
diff --git a/src/backends/reference/workloads/RefPooling3dWorkload.hpp b/src/backends/reference/workloads/RefPooling3dWorkload.hpp
index 911c438627..92bc4766cf 100644
--- a/src/backends/reference/workloads/RefPooling3dWorkload.hpp
+++ b/src/backends/reference/workloads/RefPooling3dWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
@@ -13,10 +13,10 @@
namespace armnn
{
-class RefPooling3dWorkload : public BaseWorkload<Pooling3dQueueDescriptor>
+class RefPooling3dWorkload : public RefBaseWorkload<Pooling3dQueueDescriptor>
{
public:
- using BaseWorkload<Pooling3dQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<Pooling3dQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefPreluWorkload.cpp b/src/backends/reference/workloads/RefPreluWorkload.cpp
index c1d8de2d01..94eeea1884 100644
--- a/src/backends/reference/workloads/RefPreluWorkload.cpp
+++ b/src/backends/reference/workloads/RefPreluWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
RefPreluWorkload::RefPreluWorkload(const PreluQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
{}
void RefPreluWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefPreluWorkload.hpp b/src/backends/reference/workloads/RefPreluWorkload.hpp
index b5c97dfa90..51ba2c15a7 100644
--- a/src/backends/reference/workloads/RefPreluWorkload.hpp
+++ b/src/backends/reference/workloads/RefPreluWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefPreluWorkload : public BaseWorkload<PreluQueueDescriptor>
+class RefPreluWorkload : public RefBaseWorkload<PreluQueueDescriptor>
{
public:
explicit RefPreluWorkload(const PreluQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.cpp b/src/backends/reference/workloads/RefQLstmWorkload.cpp
index dc29d0b92d..74f5f1ef4c 100644
--- a/src/backends/reference/workloads/RefQLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefQLstmWorkload.cpp
@@ -14,7 +14,7 @@ namespace armnn
{
RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info)
- : BaseWorkload<QLstmQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<QLstmQueueDescriptor>(descriptor, info)
, m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
, m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
, m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
diff --git a/src/backends/reference/workloads/RefQLstmWorkload.hpp b/src/backends/reference/workloads/RefQLstmWorkload.hpp
index 093cfd16af..0e64a38ac9 100644
--- a/src/backends/reference/workloads/RefQLstmWorkload.hpp
+++ b/src/backends/reference/workloads/RefQLstmWorkload.hpp
@@ -7,13 +7,13 @@
#include <armnn/TypesUtils.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefQLstmWorkload : public BaseWorkload<QLstmQueueDescriptor>
+class RefQLstmWorkload : public RefBaseWorkload<QLstmQueueDescriptor>
{
public:
explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.cpp b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
index 35791e65fb..10ef0e5e15 100644
--- a/src/backends/reference/workloads/RefQuantizeWorkload.cpp
+++ b/src/backends/reference/workloads/RefQuantizeWorkload.cpp
@@ -29,7 +29,7 @@ void QuantizeImpl(Decoder<float>& in, Encoder<float>& out, size_t numValues)
} //namespace
RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
, m_NumElements(info.m_InputTensorInfos[0].GetNumElements())
{
}
diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.hpp b/src/backends/reference/workloads/RefQuantizeWorkload.hpp
index a32efa7dd7..e38241067d 100644
--- a/src/backends/reference/workloads/RefQuantizeWorkload.hpp
+++ b/src/backends/reference/workloads/RefQuantizeWorkload.hpp
@@ -5,14 +5,14 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
#include "Encoders.hpp"
namespace armnn {
-class RefQuantizeWorkload : public BaseWorkload<QuantizeQueueDescriptor>
+class RefQuantizeWorkload : public RefBaseWorkload<QuantizeQueueDescriptor>
{
public:
RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info);
diff --git a/src/backends/reference/workloads/RefRankWorkload.hpp b/src/backends/reference/workloads/RefRankWorkload.hpp
index e1f30c5ba5..000828f9e4 100644
--- a/src/backends/reference/workloads/RefRankWorkload.hpp
+++ b/src/backends/reference/workloads/RefRankWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "RefWorkloadUtils.hpp"
@@ -13,10 +13,10 @@
namespace armnn
{
-struct RefRankWorkload : public BaseWorkload<RankQueueDescriptor>
+struct RefRankWorkload : public RefBaseWorkload<RankQueueDescriptor>
{
public:
- using BaseWorkload<RankQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<RankQueueDescriptor>::RefBaseWorkload;
virtual void Execute() const override
{
Execute(m_Data.m_Inputs, m_Data.m_Outputs);
diff --git a/src/backends/reference/workloads/RefReduceWorkload.cpp b/src/backends/reference/workloads/RefReduceWorkload.cpp
index 821e828b6e..62881daaf7 100644
--- a/src/backends/reference/workloads/RefReduceWorkload.cpp
+++ b/src/backends/reference/workloads/RefReduceWorkload.cpp
@@ -16,7 +16,7 @@ namespace armnn
RefReduceWorkload::RefReduceWorkload(
const ReduceQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload<ReduceQueueDescriptor>(descriptor, info) {}
+ : RefBaseWorkload<ReduceQueueDescriptor>(descriptor, info) {}
void RefReduceWorkload::Execute() const
{
diff --git a/src/backends/reference/workloads/RefReduceWorkload.hpp b/src/backends/reference/workloads/RefReduceWorkload.hpp
index d2280cc660..d759bc2ef1 100644
--- a/src/backends/reference/workloads/RefReduceWorkload.hpp
+++ b/src/backends/reference/workloads/RefReduceWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefReduceWorkload : public BaseWorkload<ReduceQueueDescriptor>
+class RefReduceWorkload : public RefBaseWorkload<ReduceQueueDescriptor>
{
public:
explicit RefReduceWorkload(const ReduceQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefReshapeWorkload.hpp b/src/backends/reference/workloads/RefReshapeWorkload.hpp
index 26a86c1d11..7596685336 100644
--- a/src/backends/reference/workloads/RefReshapeWorkload.hpp
+++ b/src/backends/reference/workloads/RefReshapeWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefReshapeWorkload : public BaseWorkload<ReshapeQueueDescriptor>
+class RefReshapeWorkload : public RefBaseWorkload<ReshapeQueueDescriptor>
{
public:
- using BaseWorkload<ReshapeQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ReshapeQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefResizeWorkload.hpp b/src/backends/reference/workloads/RefResizeWorkload.hpp
index 82949ed639..f7747193ec 100644
--- a/src/backends/reference/workloads/RefResizeWorkload.hpp
+++ b/src/backends/reference/workloads/RefResizeWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefResizeWorkload : public BaseWorkload<ResizeQueueDescriptor>
+class RefResizeWorkload : public RefBaseWorkload<ResizeQueueDescriptor>
{
public:
- using BaseWorkload<ResizeQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ResizeQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefShapeWorkload.hpp b/src/backends/reference/workloads/RefShapeWorkload.hpp
index 209cccda68..b7ed761e0c 100644
--- a/src/backends/reference/workloads/RefShapeWorkload.hpp
+++ b/src/backends/reference/workloads/RefShapeWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "RefWorkloadUtils.hpp"
@@ -13,10 +13,10 @@
namespace armnn
{
-struct RefShapeWorkload : public BaseWorkload<ShapeQueueDescriptor>
+struct RefShapeWorkload : public RefBaseWorkload<ShapeQueueDescriptor>
{
public:
- using BaseWorkload<ShapeQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<ShapeQueueDescriptor>::RefBaseWorkload;
virtual void Execute() const override
{
Execute(m_Data.m_Inputs, m_Data.m_Outputs);
diff --git a/src/backends/reference/workloads/RefSliceWorkload.hpp b/src/backends/reference/workloads/RefSliceWorkload.hpp
index 69dae5a1aa..b9dca86c4e 100644
--- a/src/backends/reference/workloads/RefSliceWorkload.hpp
+++ b/src/backends/reference/workloads/RefSliceWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefSliceWorkload : public BaseWorkload<SliceQueueDescriptor>
+class RefSliceWorkload : public RefBaseWorkload<SliceQueueDescriptor>
{
public:
- using BaseWorkload<SliceQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<SliceQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
diff --git a/src/backends/reference/workloads/RefSoftmaxWorkload.hpp b/src/backends/reference/workloads/RefSoftmaxWorkload.hpp
index 42dbb53373..cac102a2bb 100644
--- a/src/backends/reference/workloads/RefSoftmaxWorkload.hpp
+++ b/src/backends/reference/workloads/RefSoftmaxWorkload.hpp
@@ -5,16 +5,16 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefSoftmaxWorkload : public BaseWorkload<SoftmaxQueueDescriptor>
+class RefSoftmaxWorkload : public RefBaseWorkload<SoftmaxQueueDescriptor>
{
public:
- using BaseWorkload<SoftmaxQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<SoftmaxQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp b/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp
index ec764c75bb..eb2d93fb86 100644
--- a/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp
+++ b/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp
@@ -4,17 +4,17 @@
//
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/TypesUtils.hpp>
namespace armnn
{
-class RefSpaceToBatchNdWorkload : public BaseWorkload<SpaceToBatchNdQueueDescriptor>
+class RefSpaceToBatchNdWorkload : public RefBaseWorkload<SpaceToBatchNdQueueDescriptor>
{
public:
- using BaseWorkload<SpaceToBatchNdQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<SpaceToBatchNdQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp b/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp
index bc71fde20d..17f8d2f61e 100644
--- a/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp
+++ b/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp
@@ -4,17 +4,17 @@
//
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/TypesUtils.hpp>
namespace armnn
{
-class RefSpaceToDepthWorkload : public BaseWorkload<SpaceToDepthQueueDescriptor>
+class RefSpaceToDepthWorkload : public RefBaseWorkload<SpaceToDepthQueueDescriptor>
{
public:
- using BaseWorkload<SpaceToDepthQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<SpaceToDepthQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefSplitterWorkload.hpp b/src/backends/reference/workloads/RefSplitterWorkload.hpp
index 28dc83db36..0b72bb9fdc 100644
--- a/src/backends/reference/workloads/RefSplitterWorkload.hpp
+++ b/src/backends/reference/workloads/RefSplitterWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Decoders.hpp"
#include "Encoders.hpp"
@@ -13,10 +13,10 @@
namespace armnn
{
-class RefSplitterWorkload : public BaseWorkload<SplitterQueueDescriptor>
+class RefSplitterWorkload : public RefBaseWorkload<SplitterQueueDescriptor>
{
public:
- using BaseWorkload<SplitterQueueDescriptor>::BaseWorkload;
+ using RefBaseWorkload<SplitterQueueDescriptor>::RefBaseWorkload;
void Execute() const override;
void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override;
private:
diff --git a/src/backends/reference/workloads/RefStackWorkload.cpp b/src/backends/reference/workloads/RefStackWorkload.cpp
index 3f7fd7bda2..f57e6e0f1e 100644
--- a/src/backends/reference/workloads/RefStackWorkload.cpp
+++ b/src/backends/reference/workloads/RefStackWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
RefStackWorkload::RefStackWorkload(const StackQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
{}
void RefStackWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefStackWorkload.hpp b/src/backends/reference/workloads/RefStackWorkload.hpp
index fbca11b2fa..19f4a7be67 100644
--- a/src/backends/reference/workloads/RefStackWorkload.hpp
+++ b/src/backends/reference/workloads/RefStackWorkload.hpp
@@ -5,13 +5,13 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
namespace armnn
{
-class RefStackWorkload : public BaseWorkload<StackQueueDescriptor>
+class RefStackWorkload : public RefBaseWorkload<StackQueueDescriptor>
{
public:
explicit RefStackWorkload(const StackQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
index 336a687d5c..41fe4c3a1c 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp
@@ -12,7 +12,7 @@ namespace armnn
RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload(descriptor, info)
+ : RefBaseWorkload(descriptor, info)
{}
void RefStridedSliceWorkload::Execute() const
diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
index d2ffca7414..ea443cf80d 100644
--- a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
+++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp
@@ -5,12 +5,12 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
namespace armnn
{
-class RefStridedSliceWorkload : public BaseWorkload<StridedSliceQueueDescriptor>
+class RefStridedSliceWorkload : public RefBaseWorkload<StridedSliceQueueDescriptor>
{
public:
RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info);
diff --git a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp
index 8665648fe6..64a2d4c7b2 100644
--- a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp
+++ b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp
@@ -15,7 +15,7 @@ namespace armnn
RefTransposeConvolution2dWorkload::RefTransposeConvolution2dWorkload(
const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) :
- BaseWorkload<TransposeConvolution2dQueueDescriptor>(descriptor, info)
+ RefBaseWorkload<TransposeConvolution2dQueueDescriptor>(descriptor, info)
{
// set up weights decoder
m_Weights = std::make_unique<ScopedTensorHandle>(*(descriptor.m_Weight));
diff --git a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp
index aa2546f420..6bcee9a838 100644
--- a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp
@@ -9,12 +9,12 @@
#include "Encoders.hpp"
#include <armnn/backends/TensorHandle.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
namespace armnn
{
-class RefTransposeConvolution2dWorkload : public BaseWorkload<TransposeConvolution2dQueueDescriptor>
+class RefTransposeConvolution2dWorkload : public RefBaseWorkload<TransposeConvolution2dQueueDescriptor>
{
public:
RefTransposeConvolution2dWorkload(const TransposeConvolution2dQueueDescriptor& descriptor,
diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp
index bf59de7813..b8c3649745 100644
--- a/src/backends/reference/workloads/RefTransposeWorkload.hpp
+++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp
@@ -5,7 +5,7 @@
#pragma once
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/TypesUtils.hpp>
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
index 311fa18f91..d447a46b23 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp
@@ -19,7 +19,7 @@ namespace armnn
RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload(
const UnidirectionalSequenceLstmQueueDescriptor& descriptor,
const WorkloadInfo& info)
- : BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
+ : RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>(descriptor, info)
, m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights))
, m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights))
, m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights))
diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp
index d0c000f20d..7a91cee642 100644
--- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp
+++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp
@@ -7,7 +7,7 @@
#include <armnn/TypesUtils.hpp>
-#include <armnn/backends/Workload.hpp>
+#include "RefBaseWorkload.hpp"
#include <armnn/backends/WorkloadData.hpp>
#include "Encoders.hpp"
@@ -16,7 +16,7 @@
namespace armnn
{
-class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>
+class RefUnidirectionalSequenceLstmWorkload : public RefBaseWorkload<UnidirectionalSequenceLstmQueueDescriptor>
{
public:
explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor,