From 840c45d19bff23d64f78a7e466886fb970b4fcc9 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Tue, 15 Feb 2022 20:47:34 +0000 Subject: Refactor Forced Import * Find and replace all workloads associated with imported IO * Only attempt tensorhandle replacement if supported by all workloads * Add new RefBaseWorkload to enable forced input for ref backend * Store imported tensorhandles in preImportedTensorhandles instead of outputHandles * Create pre-imported tensorhandles at network load-time * Front load import workload validation to load network time * Only call ReplaceTensorHandle when needed Change-Id: I3816a71b7f57ae90388bb16462a75d4ef3544fa7 Signed-off-by: Finn Williams --- include/armnn/backends/IWorkload.hpp | 4 + include/armnn/backends/Workload.hpp | 11 +- src/armnn/LoadedNetwork.cpp | 447 +++++++++++---------- src/armnn/LoadedNetwork.hpp | 17 +- src/backends/backendsCommon/WorkloadData.cpp | 1 - .../cl/workloads/ClConvolution2dWorkload.hpp | 2 + src/backends/reference/workloads/CMakeLists.txt | 1 + .../reference/workloads/RefActivationWorkload.hpp | 6 +- .../reference/workloads/RefArgMinMaxWorkload.cpp | 2 +- .../reference/workloads/RefArgMinMaxWorkload.hpp | 4 +- .../reference/workloads/RefBaseWorkload.hpp | 36 ++ .../workloads/RefBatchNormalizationWorkload.cpp | 2 +- .../workloads/RefBatchNormalizationWorkload.hpp | 4 +- .../workloads/RefBatchToSpaceNdWorkload.hpp | 6 +- .../reference/workloads/RefCastWorkload.hpp | 6 +- .../workloads/RefChannelShuffleWorkload.hpp | 6 +- .../reference/workloads/RefComparisonWorkload.cpp | 2 +- .../reference/workloads/RefComparisonWorkload.hpp | 6 +- .../reference/workloads/RefConcatWorkload.hpp | 6 +- .../reference/workloads/RefConstantWorkload.cpp | 2 +- .../reference/workloads/RefConstantWorkload.hpp | 4 +- .../workloads/RefConvertBf16ToFp32Workload.hpp | 2 +- .../workloads/RefConvertFp16ToFp32Workload.hpp | 2 +- .../workloads/RefConvertFp32ToBf16Workload.hpp | 2 +- .../workloads/RefConvertFp32ToFp16Workload.hpp | 2 +- .../workloads/RefConvolution2dWorkload.cpp | 2 +- .../workloads/RefConvolution2dWorkload.hpp | 4 +- .../workloads/RefConvolution3dWorkload.cpp | 2 +- .../workloads/RefConvolution3dWorkload.hpp | 4 +- .../reference/workloads/RefDebugWorkload.hpp | 2 +- .../workloads/RefDepthToSpaceWorkload.hpp | 6 +- .../RefDepthwiseConvolution2dWorkload.cpp | 2 +- .../RefDepthwiseConvolution2dWorkload.hpp | 4 +- .../reference/workloads/RefDequantizeWorkload.hpp | 8 +- .../workloads/RefDetectionPostProcessWorkload.cpp | 2 +- .../workloads/RefDetectionPostProcessWorkload.hpp | 4 +- .../workloads/RefElementwiseUnaryWorkload.cpp | 2 +- .../workloads/RefElementwiseUnaryWorkload.hpp | 6 +- .../reference/workloads/RefElementwiseWorkload.cpp | 2 +- .../reference/workloads/RefElementwiseWorkload.hpp | 6 +- .../RefFakeQuantizationFloat32Workload.hpp | 2 +- .../reference/workloads/RefFillWorkload.hpp | 6 +- .../reference/workloads/RefFloorWorkload.hpp | 6 +- .../workloads/RefFullyConnectedWorkload.cpp | 2 +- .../workloads/RefFullyConnectedWorkload.hpp | 4 +- .../reference/workloads/RefGatherWorkload.hpp | 6 +- .../workloads/RefInstanceNormalizationWorkload.cpp | 2 +- .../workloads/RefInstanceNormalizationWorkload.hpp | 4 +- .../workloads/RefL2NormalizationWorkload.cpp | 2 +- .../workloads/RefL2NormalizationWorkload.hpp | 4 +- .../reference/workloads/RefLogSoftmaxWorkload.hpp | 6 +- .../workloads/RefLogicalBinaryWorkload.cpp | 2 +- .../workloads/RefLogicalBinaryWorkload.hpp | 6 +- .../workloads/RefLogicalUnaryWorkload.cpp | 2 +- .../workloads/RefLogicalUnaryWorkload.hpp | 6 +- .../reference/workloads/RefLstmWorkload.cpp | 2 +- .../reference/workloads/RefLstmWorkload.hpp | 4 +- .../reference/workloads/RefMeanWorkload.cpp | 2 +- .../reference/workloads/RefMeanWorkload.hpp | 4 +- .../workloads/RefNormalizationWorkload.cpp | 2 +- .../workloads/RefNormalizationWorkload.hpp | 4 +- .../reference/workloads/RefPadWorkload.hpp | 6 +- .../reference/workloads/RefPermuteWorkload.hpp | 2 +- .../reference/workloads/RefPooling2dWorkload.hpp | 6 +- .../reference/workloads/RefPooling3dWorkload.hpp | 6 +- .../reference/workloads/RefPreluWorkload.cpp | 2 +- .../reference/workloads/RefPreluWorkload.hpp | 4 +- .../reference/workloads/RefQLstmWorkload.cpp | 2 +- .../reference/workloads/RefQLstmWorkload.hpp | 4 +- .../reference/workloads/RefQuantizeWorkload.cpp | 2 +- .../reference/workloads/RefQuantizeWorkload.hpp | 4 +- .../reference/workloads/RefRankWorkload.hpp | 6 +- .../reference/workloads/RefReduceWorkload.cpp | 2 +- .../reference/workloads/RefReduceWorkload.hpp | 4 +- .../reference/workloads/RefReshapeWorkload.hpp | 6 +- .../reference/workloads/RefResizeWorkload.hpp | 6 +- .../reference/workloads/RefShapeWorkload.hpp | 6 +- .../reference/workloads/RefSliceWorkload.hpp | 6 +- .../reference/workloads/RefSoftmaxWorkload.hpp | 6 +- .../workloads/RefSpaceToBatchNdWorkload.hpp | 6 +- .../workloads/RefSpaceToDepthWorkload.hpp | 6 +- .../reference/workloads/RefSplitterWorkload.hpp | 6 +- .../reference/workloads/RefStackWorkload.cpp | 2 +- .../reference/workloads/RefStackWorkload.hpp | 4 +- .../workloads/RefStridedSliceWorkload.cpp | 2 +- .../workloads/RefStridedSliceWorkload.hpp | 4 +- .../RefTransposeConvolution2dWorkload.cpp | 2 +- .../RefTransposeConvolution2dWorkload.hpp | 4 +- .../reference/workloads/RefTransposeWorkload.hpp | 2 +- .../RefUnidirectionalSequenceLstmWorkload.cpp | 2 +- .../RefUnidirectionalSequenceLstmWorkload.hpp | 4 +- 91 files changed, 466 insertions(+), 375 deletions(-) create mode 100644 src/backends/reference/workloads/RefBaseWorkload.hpp diff --git a/include/armnn/backends/IWorkload.hpp b/include/armnn/backends/IWorkload.hpp index d63e0acc72..ce3914bc5a 100644 --- a/include/armnn/backends/IWorkload.hpp +++ b/include/armnn/backends/IWorkload.hpp @@ -31,6 +31,10 @@ public: virtual profiling::ProfilingGuid GetGuid() const = 0; + // SupportsTensorHandleReplacement signals that a given workload is capable of + // replacing any of its I/O tensors via ReplaceInput/OutputTensorHandle + virtual bool SupportsTensorHandleReplacement() const = 0; + // Replace input tensor handle with the given TensorHandle virtual void ReplaceInputTensorHandle(ITensorHandle* /*input*/, unsigned int /*slot*/) = 0; diff --git a/include/armnn/backends/Workload.hpp b/include/armnn/backends/Workload.hpp index 07e1abb392..21109480dc 100644 --- a/include/armnn/backends/Workload.hpp +++ b/include/armnn/backends/Workload.hpp @@ -54,16 +54,23 @@ public: profiling::ProfilingGuid GetGuid() const final { return m_Guid; } + virtual bool SupportsTensorHandleReplacement() const override + { + return false; + } + // Replace input tensor handle with the given TensorHandle void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override { - m_Data.m_Inputs[slot] = tensorHandle; + armnn::IgnoreUnused(tensorHandle, slot); + throw armnn::UnimplementedException("ReplaceInputTensorHandle not implemented for this workload"); } // Replace output tensor handle with the given TensorHandle void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override { - m_Data.m_Outputs[slot] = tensorHandle; + armnn::IgnoreUnused(tensorHandle, slot); + throw armnn::UnimplementedException("ReplaceOutputTensorHandle not implemented for this workload"); } protected: diff --git a/src/armnn/LoadedNetwork.cpp b/src/armnn/LoadedNetwork.cpp index fd7279a294..bcceaf4a99 100644 --- a/src/armnn/LoadedNetwork.cpp +++ b/src/armnn/LoadedNetwork.cpp @@ -139,6 +139,13 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, bool useExternalMemoryManager = false; bool useInternalMemoryManager = false; Graph& order = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort(); + + if (!networkProperties.m_AsyncEnabled) + { + m_IsInputImported = std::vector(order.GetNumInputs(), false); + m_IsOutputImported = std::vector(order.GetNumOutputs(), false); + } + for (auto&& layer : order) { auto const& backendId = layer->GetBackendId(); @@ -312,44 +319,6 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, } else { - if (layer->GetNumInputSlots() >= 1) - { - unsigned int inputSlotIndex = 0; - for (auto& inputSlot : layer->GetInputSlots()) - { - if (inputSlot.GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Input) - { - auto inputLayer = - PolymorphicDowncast( - &inputSlot.GetConnectedOutputSlot()->GetOwningLayer()); - m_InputWorkloadSlotPairs[inputLayer->GetBindingId()] = - std::make_pair(m_WorkloadQueue.size(), inputSlotIndex); - } - ++inputSlotIndex; - } - } - - if (layer->GetNumOutputSlots() >= 1) - { - unsigned int outputSlotIndex = 0; - for (auto& outputSlot : layer->GetOutputSlots()) - { - for (unsigned int i = 0; i < outputSlot.GetNumConnections(); i++) - { - // If any of the connections on this outputSlot are connected to an Output then - // Add its index within layer->GetOutputSlots() to m_OutputWorkloadSlotPairs - if (outputSlot.GetConnection(i)->GetOwningLayer().GetType() == LayerType::Output) - { - auto outputLayer = PolymorphicDowncast( - &outputSlot.GetConnection(i)->GetOwningLayer()); - m_OutputWorkloadSlotPairs[outputLayer->GetBindingId()] = - std::make_pair(m_WorkloadQueue.size(), outputSlotIndex); - continue; - } - } - ++outputSlotIndex; - } - } m_WorkloadQueue.push_back(std::move(workload)); } @@ -361,6 +330,100 @@ LoadedNetwork::LoadedNetwork(std::unique_ptr net, } } + // Gather information about workloads for inputs & outputs + if (!networkProperties.m_AsyncEnabled && m_WorkloadQueue.size() != 0) + { + const int noOfInputs = armnn::numeric_cast(order.GetNumInputs()); + + // Get indices of all workloads connected to each input and + // check if they support tensor handle replacement + for (const BindableLayer* layer: order.GetInputLayers()) + { + const auto bindingId = layer->GetBindingId(); + + bool supportsReplacement = true; + + for (const auto inputSlot: layer->GetOutputSlot(0).GetConnections()) + { + auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(inputSlot->GetOwningLayer())); + workloadIndex -= noOfInputs; + + m_InputWorkloadSlotPairs[bindingId].emplace_back(WorkloadIndices{ + armnn::numeric_cast(workloadIndex), inputSlot->GetSlotIndex()}); + + auto workload = m_WorkloadQueue[m_InputWorkloadSlotPairs[bindingId].back().m_WorkloadIndex].get(); + supportsReplacement &= workload->SupportsTensorHandleReplacement(); + } + + ITensorHandleFactory::FactoryId factoryId = layer->GetOutputSlot(0).GetTensorHandleFactoryId(); + // Get matching import factory Id + ITensorHandleFactory::FactoryId importFactoryId = + m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId); + + ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId); + + if (supportsReplacement && importFactory) + { + m_PreImportedInputHandles.emplace_back( + bindingId, importFactory->CreateTensorHandle(layer->GetOutputSlot(0).GetTensorInfo(), false)); + } + else + { + m_PreImportedInputHandles.emplace_back(bindingId, nullptr); + } + } + + // Get indices of all workloads connected to each output and + // check if they support tensor handle replacement + for (const BindableLayer* layer: order.GetOutputLayers()) + { + const auto bindingId = layer->GetBindingId(); + + const auto outputSlot = layer->GetInputSlot(0).GetConnectedOutputSlot(); + auto& indices = m_OutputWorkloadSlotPairs[bindingId]; + + auto workloadIndex = std::distance(order.begin(), order.GetPosInGraph(outputSlot->GetOwningLayer())); + workloadIndex -= noOfInputs; + + indices.m_OutputSlotIndices = WorkloadIndices{numeric_cast(workloadIndex), + outputSlot->CalculateIndexOnOwner()}; + + bool supportsReplacement = true; + auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get(); + supportsReplacement &= outputWorkload->SupportsTensorHandleReplacement(); + + for (auto &inputSlot: outputSlot->GetConnections()) + { + if(inputSlot->GetOwningLayer().GetType() != LayerType::Output) + { + auto inWorkloadIndex = std::distance(order.begin(), + order.GetPosInGraph(inputSlot->GetOwningLayer())); + inWorkloadIndex -= noOfInputs; + indices.m_InputSlotIndices.emplace_back(WorkloadIndices{numeric_cast(inWorkloadIndex), + inputSlot->GetSlotIndex()}); + auto inputWorkload = m_WorkloadQueue[indices.m_InputSlotIndices.back().m_WorkloadIndex].get(); + supportsReplacement &= inputWorkload->SupportsTensorHandleReplacement(); + } + } + + ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId(); + // Get matching import factory Id + ITensorHandleFactory::FactoryId importFactoryId = + m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId); + ITensorHandleFactory *importFactory = m_TensorHandleFactoryRegistry.GetFactory(importFactoryId); + + if (supportsReplacement && importFactory) + { + m_PreImportedOutputHandles.emplace_back( + bindingId, importFactory->CreateTensorHandle(outputSlot->GetTensorInfo(), false)); + } + else + { + m_PreImportedOutputHandles.emplace_back(bindingId, nullptr); + } + } + } + for (auto&& workloadFactory : m_WorkloadFactories) { workloadFactory.second->AfterWorkloadsCreated(); @@ -699,77 +762,133 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, m_InputQueue.clear(); m_InputQueue.reserve(graph.GetNumInputs()); + if (preImportedInputIds.size() > graph.GetNumInputs()) + { + throw InvalidArgumentException("Invalid number of preImportedInputIds"); + } + + unsigned int inputIndex = 0; + unsigned int importedInputIdIndex = 0; + std::sort(preImportedInputIds.begin(), preImportedInputIds.end()); for (const BindableLayer* inputLayer : graph.GetInputLayers()) { - if (preImportedInputIds.size() > graph.GetNumInputs()) + if (importedInputIdIndex < preImportedInputIds.size() && + inputIndex == preImportedInputIds[importedInputIdIndex]) { - throw InvalidArgumentException("Invalid number of preImportedInputIds"); + // Only replace tensorhandles if they have not already been replaced + if (!m_IsInputImported[inputIndex]) + { + auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get(); + + for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()]) + { + auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get(); + workload->ReplaceInputTensorHandle(outputTensorHandle, workloadInfo.m_SlotIndex); + } + m_IsInputImported[inputIndex] = true; + } + importedInputIdIndex++; } - auto layerBindingId = inputLayer->GetBindingId(); - auto it = std::find_if(preImportedInputIds.begin(), preImportedInputIds.end(), - [=](auto preImportedInputId) + else { - return m_PreImportedInputHandles[preImportedInputId].m_LayerBindingId == layerBindingId; - }); + if (m_IsInputImported[inputIndex]) + { + OutputHandler& handler = const_cast(inputLayer->GetOutputHandler(0)); + + for (const auto& workloadInfo: m_InputWorkloadSlotPairs[inputLayer->GetBindingId()]) + { + auto workload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get(); + workload->ReplaceInputTensorHandle(handler.GetData(), workloadInfo.m_SlotIndex); + } + + m_IsInputImported[inputIndex] = false; + } - if (it == preImportedInputIds.end()) - { // InputTensorHandle is not imported yet, process to enqueue input const TensorPin& pin = workloadData.GetInputTensorPin(inputLayer->GetBindingId()); EnqueueInput(*inputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); } + inputIndex++; } } - // For each output to the network, call EnqueueOutput with the data passed by the user. { ARMNN_SCOPED_PROFILING_EVENT(Compute::Undefined, "PrepareOutputs"); m_OutputQueue.clear(); m_OutputQueue.reserve(graph.GetNumOutputs()); + if (preImportedOutputIds.size() > graph.GetNumOutputs()) + { + throw InvalidArgumentException("Invalid number of preImportedOutputIds"); + } + + unsigned int outputIndex = 0; + unsigned int importedOutputIdIndex = 0; + std::sort(preImportedOutputIds.begin(), preImportedOutputIds.end()); for (const BindableLayer* outputLayer : graph.GetOutputLayers()) { - if (preImportedOutputIds.size() > graph.GetNumOutputs()) - { - throw InvalidArgumentException("Invalid number of preImportedOutputIds"); - } - auto layerBindingId = outputLayer->GetBindingId(); - auto it = std::find_if(preImportedOutputIds.begin(), preImportedOutputIds.end(), - [=](auto preImportedOutputId) + if (importedOutputIdIndex < preImportedOutputIds.size() && + outputIndex == preImportedOutputIds[importedOutputIdIndex]) { - return m_PreImportedOutputHandles[preImportedOutputId].m_LayerBindingId == layerBindingId; - }); + // Only replace tensorhandles if they have not already been replaced + ITensorHandle* inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get(); - const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId()); + if (!m_IsOutputImported[outputIndex]) + { + const auto bindingId = outputLayer->GetBindingId(); + const auto& indices = m_OutputWorkloadSlotPairs[bindingId]; - if (it == preImportedOutputIds.end()) - { - // OutputTensorHandle is not imported yet, process to enqueue Output - EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); - } - else - { - // Insert synchronization workload for the imported output - OutputQueueDescriptor outputQueueDescriptor; - WorkloadInfo info; + auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get(); - outputQueueDescriptor.m_Outputs.push_back(pin.GetTensorHandle()); - info.m_OutputTensorInfos.push_back(pin.GetTensorInfo()); + outputWorkload->ReplaceOutputTensorHandle(inputTensorHandle, + indices.m_OutputSlotIndices.m_SlotIndex); - // Gets the output handler from the previous node. - const OutputHandler& outputHandler = - outputLayer->GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler(); + for (const auto& workloadInfo: indices.m_InputSlotIndices) + { + auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get(); + inputWorkload->ReplaceInputTensorHandle(inputTensorHandle, workloadInfo.m_SlotIndex); + } + m_IsOutputImported[outputIndex] = true; + } - const TensorInfo& inputTensorInfo = outputHandler.GetTensorInfo(); - ITensorHandle* inputTensorHandle = outputHandler.GetData(); ARMNN_ASSERT_MSG(inputTensorHandle != nullptr, "Data should have been allocated."); MemSyncQueueDescriptor syncDesc; syncDesc.m_Inputs.push_back(inputTensorHandle); - info.m_InputTensorInfos.push_back(inputTensorInfo); + WorkloadInfo info; + info.m_InputTensorInfos.push_back( + outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo()); auto syncWorkload = std::make_unique(syncDesc, info); ARMNN_ASSERT_MSG(syncWorkload, "No sync workload created"); m_OutputQueue.push_back(move(syncWorkload)); + importedOutputIdIndex++; } + else + { + if (m_IsOutputImported[outputIndex]) + { + const auto bindingId = outputLayer->GetBindingId(); + const auto& indices = m_OutputWorkloadSlotPairs[bindingId]; + + auto outputWorkload = m_WorkloadQueue[indices.m_OutputSlotIndices.m_WorkloadIndex].get(); + const OutputHandler& outputHandler = + outputLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOutputHandler(); + + outputWorkload->ReplaceOutputTensorHandle( + outputHandler.GetData(), indices.m_OutputSlotIndices.m_SlotIndex); + + for (const auto& workloadInfo: indices.m_InputSlotIndices) + { + auto inputWorkload = m_WorkloadQueue[workloadInfo.m_WorkloadIndex].get(); + inputWorkload->ReplaceInputTensorHandle(outputHandler.GetData(), workloadInfo.m_SlotIndex); + } + m_IsOutputImported[outputIndex] = false; + } + + const TensorPin& pin = workloadData.GetOutputTensorPin(outputLayer->GetBindingId()); + // OutputTensorHandle is not imported yet, process to enqueue Output + EnqueueOutput(*outputLayer, pin.GetTensorHandle(), pin.GetTensorInfo()); + } + outputIndex++; } } @@ -806,6 +925,7 @@ Status LoadedNetwork::EnqueueWorkload(const InputTensors& inputTensors, timelineUtils->RecordEvent(inferenceGuid, LabelsAndEventClasses::ARMNN_PROFILING_EOL_EVENT_CLASS); timelineUtils->Commit(); } + return executionSucceeded ? Status::Success : Status::Failure; } @@ -1186,14 +1306,13 @@ const armnn::Tensor GetOutputTensor(const LayerBindingId layerId, const OutputTe std::vector LoadedNetwork::ImportInputs(const InputTensors& inputTensors, MemorySource forceImportMemorySource) { - if (!m_NetworkProperties.m_ImportEnabled) + if (!m_NetworkProperties.m_AsyncEnabled) { // Cannot import if import is not enabled and forceImportMemorySource is undefined if (forceImportMemorySource == MemorySource::Undefined) { throw MemoryImportException("ImportInputs: Memory Import failed, NetworkProperties.m_ImportEnabled"); } - // If forceImportMemorySource is defined, try import if memory is aligned if (inputTensors.size() != m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().GetNumInputs()) { throw MemoryImportException("ImportInputs: Force Import failed, incorrect number of tensors"); @@ -1201,85 +1320,42 @@ std::vector LoadedNetwork::ImportInputs(const InputTensors& inp std::vector importedInputs; Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort(); - for (auto inputTensor : inputTensors) + unsigned int inputIndex = 0; + for (const BindableLayer* inputLayer : graph.GetInputLayers()) { - auto layerBindingId = inputTensor.first; - auto it = std::find_if(graph.GetInputLayers().begin(), graph.GetInputLayers().end(), [=](auto* layer) - { - return layer->GetBindingId() == layerBindingId; - }); + auto outputTensorHandle = m_PreImportedInputHandles[inputIndex].m_TensorHandle.get(); - if (it == graph.GetInputLayers().end()) + if (!outputTensorHandle) { - throw MemoryImportException(fmt::format( - "ImportInputs: Memory Import failed, unknown LayerBindingId: {}", layerBindingId)); + inputIndex++; + continue; } - const Layer* layer = *it; - if (layer->GetType() != LayerType::Input) + auto layerBindingId = inputLayer->GetBindingId(); + auto it = std::find_if(inputTensors.begin(), inputTensors.end(), [=](const auto& inputTensor) { - throw InvalidArgumentException("ImportInputs: given layer not an InputLayer"); - } - const OutputSlot& outputSlot = layer->GetOutputSlots()[0]; - ITensorHandleFactory::FactoryId factoryId = outputSlot.GetTensorHandleFactoryId(); - // Get matching import factory Id - ITensorHandleFactory::FactoryId importFactoryId = - m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId); - ITensorHandleFactory* importFactory = - m_TensorHandleFactoryRegistry.GetFactory(importFactoryId, forceImportMemorySource); - if (!importFactory) + return inputTensor.first == layerBindingId; + }); + + if (it == inputTensors.end()) { - throw MemoryImportException("ImportInputs: Force Import failed, cannot find matching Import Factory"); + inputIndex++; + continue; } - OutputHandler& handler = const_cast(layer->GetOutputHandler(0)); - handler.SetAllocatedData(); - handler.CreateTensorHandles(*importFactory, false); - ITensorHandle* outputTensorHandle = handler.GetData(); + const auto& inputTensor = *it; std::unique_ptr passThroughTensorHandle = std::make_unique(inputTensor.second.GetInfo(), inputTensor.second.GetMemoryArea()); - // Check if the input memory can be imported - if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource)) - { - passThroughTensorHandle->Unmap(); - if (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource)) - { - passThroughTensorHandle->Unmap(); - try - { - m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle( - outputTensorHandle, m_InputWorkloadSlotPairs[layerBindingId].second); - importedInputs.push_back(m_CurImportedInputId++); - // For force import, we want OutputHandler to own the TensorHandle, - // so we do not move the TensorHandle to m_PreImportedInputHandles as in AsyncEnabled networks - ImportedTensorHandlePin importedTensorHandlePin{layerBindingId, nullptr}; - m_PreImportedInputHandles.push_back(std::move(importedTensorHandlePin)); - } - catch(armnn::UnimplementedException& e) - { - IgnoreUnused(e); - // Method not implement, cannot use import tensor and have to use allocated data instead - handler.UseAllocatedData(); - } - } - } - else + + if (outputTensorHandle->CanBeImported(passThroughTensorHandle->Map(), forceImportMemorySource) + && (outputTensorHandle->Import(passThroughTensorHandle->Map(), forceImportMemorySource))) { - // Cannot import, use allocated data - handler.UseAllocatedData(); - // Ensure that the workload get correct tensor - try - { - m_WorkloadQueue[m_InputWorkloadSlotPairs[layerBindingId].first].get()->ReplaceInputTensorHandle( - handler.GetData(), m_InputWorkloadSlotPairs[layerBindingId].second); - } - catch(armnn::UnimplementedException& e) - { - IgnoreUnused(e); - } + importedInputs.push_back(inputIndex); } + passThroughTensorHandle->Unmap(); + inputIndex++; } return importedInputs; @@ -1363,7 +1439,7 @@ std::vector LoadedNetwork::ImportInputs(const InputTensors& inp std::vector LoadedNetwork::ImportOutputs(const OutputTensors& outputTensors, MemorySource forceImportMemorySource) { - if (!m_NetworkProperties.m_ExportEnabled) + if (!m_NetworkProperties.m_AsyncEnabled) { // Cannot import if import is not enabled and forceImportMemorySource is undefined if (forceImportMemorySource == MemorySource::Undefined) @@ -1377,85 +1453,38 @@ std::vector LoadedNetwork::ImportOutputs(const OutputTensors& } std::vector importedOutputs; Graph& graph = m_OptimizedNetwork->pOptimizedNetworkImpl->GetGraph().TopologicalSort(); - for (auto outputTensor : outputTensors) + + unsigned int outputIndex = 0; + for (const BindableLayer* const outputLayer : graph.GetOutputLayers()) { - auto layerBindingId = outputTensor.first; - auto it = std::find_if(graph.GetOutputLayers().begin(), graph.GetOutputLayers().end(), [=](auto* layer) - { - return layer->GetBindingId() == layerBindingId; - }); + auto inputTensorHandle = m_PreImportedOutputHandles[outputIndex].m_TensorHandle.get(); - if (it == graph.GetOutputLayers().end()) + if (!inputTensorHandle) { - throw MemoryImportException(fmt::format("ImportOutputs: Memory Import failed, " - "unknown LayerBindingId: {}", - layerBindingId)); + outputIndex++; + continue; } - const Layer* layer = *it; - if (layer->GetType() != LayerType::Output) + auto layerBindingId = outputLayer->GetBindingId(); + auto it = std::find_if(outputTensors.begin(), outputTensors.end(), [=] (const auto& outputTensor) { - throw InvalidArgumentException("ImportOutputs: given layer not an OutputLayer"); - } + return outputTensor.first == layerBindingId; + }); - const OutputSlot* outputSlot = layer->GetInputSlots()[0].GetConnectedOutputSlot(); - ITensorHandleFactory::FactoryId factoryId = outputSlot->GetTensorHandleFactoryId(); - ITensorHandleFactory::FactoryId importFactoryId = - m_TensorHandleFactoryRegistry.GetMatchingImportFactoryId(factoryId); - ITensorHandleFactory* importFactory = - m_TensorHandleFactoryRegistry.GetFactory(importFactoryId, forceImportMemorySource); - if (!importFactory) + if (it == outputTensors.end()) { - throw MemoryImportException("ImportOutputs: Force Import failed, cannot find matching Import Factory"); + outputIndex++; + continue; } - OutputHandler& outputHandler = - const_cast(layer->GetInputSlots()[0].GetConnectedOutputSlot()->GetOutputHandler()); - outputHandler.SetAllocatedData(); - ITensorHandle* inputTensorHandle = outputHandler.GetData(); - outputHandler.CreateTensorHandles(*importFactory, false); - inputTensorHandle = outputHandler.GetData(); - + const auto outputTensor = *it; // Check if the output memory can be imported - if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource)) - { - if (inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource)) - { - try - { - m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()-> - ReplaceOutputTensorHandle(inputTensorHandle, - m_OutputWorkloadSlotPairs[layerBindingId].second); - importedOutputs.push_back(m_CurImportedOutputId++); - // For force import, we want OutputHandler to own the TensorHandle, - // so we do not move the TensorHandle to m_PreImportedOutputHandles as in AsyncEnabled networks - ImportedTensorHandlePin importedTensorHandlePin{layerBindingId, nullptr}; - m_PreImportedOutputHandles.push_back(std::move(importedTensorHandlePin)); - } - catch(armnn::UnimplementedException& e) - { - IgnoreUnused(e); - // Method not implement, cannot use import tensor and have to use allocated data instead - outputHandler.UseAllocatedData(); - } - } - } - else + if (inputTensorHandle->CanBeImported(outputTensor.second.GetMemoryArea(), forceImportMemorySource) + && inputTensorHandle->Import(outputTensor.second.GetMemoryArea(), forceImportMemorySource)) { - // Cannot import, use allocated memory - outputHandler.UseAllocatedData(); - // Ensure that the workload get correct tensor - try - { - m_WorkloadQueue[m_OutputWorkloadSlotPairs[layerBindingId].first].get()-> - ReplaceOutputTensorHandle(outputHandler.GetData(), - m_OutputWorkloadSlotPairs[layerBindingId].second); - } - catch(armnn::UnimplementedException& e) - { - IgnoreUnused(e); - } + importedOutputs.push_back(outputIndex); } + outputIndex++; } return importedOutputs; } diff --git a/src/armnn/LoadedNetwork.hpp b/src/armnn/LoadedNetwork.hpp index f637dec8eb..dc2f4dc10f 100644 --- a/src/armnn/LoadedNetwork.hpp +++ b/src/armnn/LoadedNetwork.hpp @@ -204,8 +204,21 @@ private: // A set of vectors to record the workload queue indexes and their corresponding Input/Output Slot indexes // which are connected to Inputs and Outputs for the network. - std::unordered_map> m_InputWorkloadSlotPairs; - std::unordered_map> m_OutputWorkloadSlotPairs; + struct WorkloadIndices + { + unsigned int m_WorkloadIndex; + unsigned int m_SlotIndex; + }; + + struct OutputWorkloadIndices + { + WorkloadIndices m_OutputSlotIndices; + std::vector m_InputSlotIndices; + }; + std::unordered_map> m_InputWorkloadSlotPairs; + std::unordered_map m_OutputWorkloadSlotPairs; + std::vector m_IsInputImported; + std::vector m_IsOutputImported; }; diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 385affa5fa..fc48ffce28 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -583,7 +583,6 @@ void MemImportQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const void MemSyncQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { ValidateNumInputs(workloadInfo, "MemSyncQueueDescriptor", 1); - ValidateNumOutputs(workloadInfo, "MemSyncQueueDescriptor" , 1); if (m_Inputs.size() != 1) { diff --git a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp index 891d5096cd..e4177e4327 100644 --- a/src/backends/cl/workloads/ClConvolution2dWorkload.hpp +++ b/src/backends/cl/workloads/ClConvolution2dWorkload.hpp @@ -40,6 +40,8 @@ public: arm_compute::ConvolutionMethod GetConvolutionMethod() const; + bool SupportsTensorHandleReplacement() const override { return true;}; + protected: void Reconfigure() override; diff --git a/src/backends/reference/workloads/CMakeLists.txt b/src/backends/reference/workloads/CMakeLists.txt index 60d8255454..46c2706742 100644 --- a/src/backends/reference/workloads/CMakeLists.txt +++ b/src/backends/reference/workloads/CMakeLists.txt @@ -68,6 +68,7 @@ list(APPEND armnnRefBackendWorkloads_sources RefActivationWorkload.hpp RefArgMinMaxWorkload.cpp RefArgMinMaxWorkload.hpp + RefBaseWorkload.hpp RefBatchNormalizationWorkload.cpp RefBatchNormalizationWorkload.hpp RefBatchToSpaceNdWorkload.cpp diff --git a/src/backends/reference/workloads/RefActivationWorkload.hpp b/src/backends/reference/workloads/RefActivationWorkload.hpp index 9814ac172b..8dc2d52d9b 100644 --- a/src/backends/reference/workloads/RefActivationWorkload.hpp +++ b/src/backends/reference/workloads/RefActivationWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefActivationWorkload : public BaseWorkload +class RefActivationWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp index 2d635bf6c2..d724273287 100644 --- a/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp +++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.cpp @@ -16,7 +16,7 @@ namespace armnn RefArgMinMaxWorkload::RefArgMinMaxWorkload( const ArgMinMaxQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) {} + : RefBaseWorkload(descriptor, info) {} void RefArgMinMaxWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp b/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp index f3c264469b..97c4b45d60 100644 --- a/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp +++ b/src/backends/reference/workloads/RefArgMinMaxWorkload.hpp @@ -5,12 +5,12 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefArgMinMaxWorkload : public BaseWorkload +class RefArgMinMaxWorkload : public RefBaseWorkload { public: explicit RefArgMinMaxWorkload(const ArgMinMaxQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefBaseWorkload.hpp b/src/backends/reference/workloads/RefBaseWorkload.hpp new file mode 100644 index 0000000000..824b4ccc67 --- /dev/null +++ b/src/backends/reference/workloads/RefBaseWorkload.hpp @@ -0,0 +1,36 @@ +// +// Copyright © 2022 Arm Ltd and Contributors. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#pragma once + +#include + +namespace armnn +{ + template + class RefBaseWorkload : public BaseWorkload + { + public: + RefBaseWorkload(const QueueDescriptor& descriptor, const WorkloadInfo& info) + : BaseWorkload(descriptor, info) + {} + + virtual bool SupportsTensorHandleReplacement() const override + { + return true; + } + // Replace input tensor handle with the given TensorHandle + void ReplaceInputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override + { + this->m_Data.m_Inputs[slot] = tensorHandle; + } + + // Replace output tensor handle with the given TensorHandle + void ReplaceOutputTensorHandle(ITensorHandle* tensorHandle, unsigned int slot) override + { + this->m_Data.m_Outputs[slot] = tensorHandle; + } + }; +} //namespace armnn \ No newline at end of file diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp index 282374d89b..a6bd986f1d 100644 --- a/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn RefBatchNormalizationWorkload::RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) , m_Mean (std::make_unique(*(descriptor.m_Mean))) , m_Variance(std::make_unique(*(descriptor.m_Variance))) , m_Beta (std::make_unique(*(descriptor.m_Beta))) diff --git a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp index 305c0ce573..60dd2a927c 100644 --- a/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp +++ b/src/backends/reference/workloads/RefBatchNormalizationWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefBatchNormalizationWorkload : public BaseWorkload +class RefBatchNormalizationWorkload : public RefBaseWorkload { public: explicit RefBatchNormalizationWorkload(const BatchNormalizationQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp b/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp index 7d18c12476..d7ee6fc81c 100644 --- a/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp +++ b/src/backends/reference/workloads/RefBatchToSpaceNdWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefBatchToSpaceNdWorkload : public BaseWorkload +class RefBatchToSpaceNdWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefCastWorkload.hpp b/src/backends/reference/workloads/RefCastWorkload.hpp index ccafaafac9..6f7e56a6b6 100644 --- a/src/backends/reference/workloads/RefCastWorkload.hpp +++ b/src/backends/reference/workloads/RefCastWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "RefWorkloadUtils.hpp" @@ -13,10 +13,10 @@ namespace armnn { -class RefCastWorkload : public BaseWorkload +class RefCastWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp b/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp index 0c8037823a..b459b87592 100644 --- a/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp +++ b/src/backends/reference/workloads/RefChannelShuffleWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefChannelShuffleWorkload : public BaseWorkload +class RefChannelShuffleWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefComparisonWorkload.cpp b/src/backends/reference/workloads/RefComparisonWorkload.cpp index 03df7a4c4a..433e3e8ad8 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.cpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.cpp @@ -21,7 +21,7 @@ namespace armnn RefComparisonWorkload::RefComparisonWorkload(const ComparisonQueueDescriptor& desc, const WorkloadInfo& info) - : BaseWorkload(desc, info) + : RefBaseWorkload(desc, info) {} void RefComparisonWorkload::PostAllocationConfigure() diff --git a/src/backends/reference/workloads/RefComparisonWorkload.hpp b/src/backends/reference/workloads/RefComparisonWorkload.hpp index f2780c7ae5..93cfd1f2b1 100644 --- a/src/backends/reference/workloads/RefComparisonWorkload.hpp +++ b/src/backends/reference/workloads/RefComparisonWorkload.hpp @@ -7,16 +7,16 @@ #include "BaseIterator.hpp" -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefComparisonWorkload : public BaseWorkload +class RefComparisonWorkload : public RefBaseWorkload { public: - using BaseWorkload::m_Data; + using RefBaseWorkload::m_Data; RefComparisonWorkload(const ComparisonQueueDescriptor& descriptor, const WorkloadInfo& info); void PostAllocationConfigure() override; diff --git a/src/backends/reference/workloads/RefConcatWorkload.hpp b/src/backends/reference/workloads/RefConcatWorkload.hpp index cb1ecf06a7..11d6d016ed 100644 --- a/src/backends/reference/workloads/RefConcatWorkload.hpp +++ b/src/backends/reference/workloads/RefConcatWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefConcatWorkload : public BaseWorkload +class RefConcatWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefConstantWorkload.cpp b/src/backends/reference/workloads/RefConstantWorkload.cpp index 6290237d69..571dbb219a 100644 --- a/src/backends/reference/workloads/RefConstantWorkload.cpp +++ b/src/backends/reference/workloads/RefConstantWorkload.cpp @@ -18,7 +18,7 @@ namespace armnn RefConstantWorkload::RefConstantWorkload( const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) {} + : RefBaseWorkload(descriptor, info) {} void RefConstantWorkload::Execute() const { diff --git a/src/backends/reference/workloads/RefConstantWorkload.hpp b/src/backends/reference/workloads/RefConstantWorkload.hpp index c158983d7a..181d79d320 100644 --- a/src/backends/reference/workloads/RefConstantWorkload.hpp +++ b/src/backends/reference/workloads/RefConstantWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include @@ -14,7 +14,7 @@ namespace armnn { // Base class template providing an implementation of the Constant layer common to all data types. -class RefConstantWorkload : public BaseWorkload +class RefConstantWorkload : public RefBaseWorkload { public: RefConstantWorkload(const ConstantQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp b/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp index b3af111fa3..8b5c6d56c2 100644 --- a/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp +++ b/src/backends/reference/workloads/RefConvertBf16ToFp32Workload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn diff --git a/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp b/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp index acb1995b9f..feb442ef5a 100644 --- a/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp +++ b/src/backends/reference/workloads/RefConvertFp16ToFp32Workload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn diff --git a/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp b/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp index 97a138f49c..cd3cfa4cf3 100644 --- a/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp +++ b/src/backends/reference/workloads/RefConvertFp32ToBf16Workload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn diff --git a/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp b/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp index 8cc822e7d8..fe137ed62f 100644 --- a/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp +++ b/src/backends/reference/workloads/RefConvertFp32ToFp16Workload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn diff --git a/src/backends/reference/workloads/RefConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefConvolution2dWorkload.cpp index 20c5c08b17..d57040eaec 100644 --- a/src/backends/reference/workloads/RefConvolution2dWorkload.cpp +++ b/src/backends/reference/workloads/RefConvolution2dWorkload.cpp @@ -14,7 +14,7 @@ namespace armnn { RefConvolution2dWorkload::RefConvolution2dWorkload( const Convolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) { WorkloadInfo detailsInfo; detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos; diff --git a/src/backends/reference/workloads/RefConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefConvolution2dWorkload.hpp index 880547dc33..3335782f78 100644 --- a/src/backends/reference/workloads/RefConvolution2dWorkload.hpp +++ b/src/backends/reference/workloads/RefConvolution2dWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" #include "Encoders.hpp" @@ -13,7 +13,7 @@ namespace armnn { -class RefConvolution2dWorkload : public BaseWorkload +class RefConvolution2dWorkload : public RefBaseWorkload { public: explicit RefConvolution2dWorkload(const Convolution2dQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp index afab88f0a8..5f542807ed 100644 --- a/src/backends/reference/workloads/RefConvolution3dWorkload.cpp +++ b/src/backends/reference/workloads/RefConvolution3dWorkload.cpp @@ -14,7 +14,7 @@ namespace armnn { RefConvolution3dWorkload::RefConvolution3dWorkload( const Convolution3dQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) { WorkloadInfo detailsInfo; detailsInfo.m_InputTensorInfos = info.m_InputTensorInfos; diff --git a/src/backends/reference/workloads/RefConvolution3dWorkload.hpp b/src/backends/reference/workloads/RefConvolution3dWorkload.hpp index 53ce309eb8..6c74675eec 100644 --- a/src/backends/reference/workloads/RefConvolution3dWorkload.hpp +++ b/src/backends/reference/workloads/RefConvolution3dWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" #include "Encoders.hpp" @@ -13,7 +13,7 @@ namespace armnn { -class RefConvolution3dWorkload : public BaseWorkload +class RefConvolution3dWorkload : public RefBaseWorkload { public: explicit RefConvolution3dWorkload(const Convolution3dQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefDebugWorkload.hpp b/src/backends/reference/workloads/RefDebugWorkload.hpp index 66af9a0b0f..a1579599f4 100644 --- a/src/backends/reference/workloads/RefDebugWorkload.hpp +++ b/src/backends/reference/workloads/RefDebugWorkload.hpp @@ -7,7 +7,7 @@ #include -#include +#include "RefBaseWorkload.hpp" namespace armnn { diff --git a/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp b/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp index 854a564062..bd179d3b9c 100644 --- a/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp +++ b/src/backends/reference/workloads/RefDepthToSpaceWorkload.hpp @@ -5,15 +5,15 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" namespace armnn { -class RefDepthToSpaceWorkload : public BaseWorkload +class RefDepthToSpaceWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp index b447d1a441..ad5edde7e6 100644 --- a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp +++ b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.cpp @@ -17,7 +17,7 @@ namespace armnn RefDepthwiseConvolution2dWorkload::RefDepthwiseConvolution2dWorkload( const DepthwiseConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) { m_Weight = std::make_unique(*(descriptor.m_Weight)); const TensorInfo& rFilterInfo = m_Weight->GetTensorInfo(); diff --git a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp index ae93d03656..5d4b483fa7 100644 --- a/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp +++ b/src/backends/reference/workloads/RefDepthwiseConvolution2dWorkload.hpp @@ -2,7 +2,7 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" #include "Encoders.hpp" @@ -12,7 +12,7 @@ namespace armnn { -class RefDepthwiseConvolution2dWorkload : public BaseWorkload { +class RefDepthwiseConvolution2dWorkload : public RefBaseWorkload { public: explicit RefDepthwiseConvolution2dWorkload(const DepthwiseConvolution2dQueueDescriptor &descriptor, const WorkloadInfo &info); diff --git a/src/backends/reference/workloads/RefDequantizeWorkload.hpp b/src/backends/reference/workloads/RefDequantizeWorkload.hpp index 285c6496bb..8fa8951677 100644 --- a/src/backends/reference/workloads/RefDequantizeWorkload.hpp +++ b/src/backends/reference/workloads/RefDequantizeWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" namespace armnn { -class RefDequantizeWorkload : public BaseWorkload +class RefDequantizeWorkload : public RefBaseWorkload { public: - using BaseWorkload::m_Data; - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::m_Data; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp index 4bc9eb1704..5f01db3280 100644 --- a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp +++ b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn RefDetectionPostProcessWorkload::RefDetectionPostProcessWorkload( const DetectionPostProcessQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info), + : RefBaseWorkload(descriptor, info), m_Anchors(std::make_unique(*(descriptor.m_Anchors))) {} void RefDetectionPostProcessWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp index 4c3ad42b0f..53b2971063 100644 --- a/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp +++ b/src/backends/reference/workloads/RefDetectionPostProcessWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefDetectionPostProcessWorkload : public BaseWorkload +class RefDetectionPostProcessWorkload : public RefBaseWorkload { public: explicit RefDetectionPostProcessWorkload(const DetectionPostProcessQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp index be153636f9..3ea51b9f69 100644 --- a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.cpp @@ -27,7 +27,7 @@ namespace armnn RefElementwiseUnaryWorkload::RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc, const WorkloadInfo& info) - : BaseWorkload(desc, info) + : RefBaseWorkload(desc, info) {} void RefElementwiseUnaryWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp index e055fd012c..91229b3c58 100644 --- a/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseUnaryWorkload.hpp @@ -7,16 +7,16 @@ #include "BaseIterator.hpp" -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefElementwiseUnaryWorkload : public BaseWorkload +class RefElementwiseUnaryWorkload : public RefBaseWorkload { public: - using BaseWorkload::m_Data; + using RefBaseWorkload::m_Data; RefElementwiseUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.cpp b/src/backends/reference/workloads/RefElementwiseWorkload.cpp index dd7d325ca5..d14ce075b0 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.cpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.cpp @@ -21,7 +21,7 @@ template ::RefElementwiseWorkload( const ParentDescriptor& desc, const WorkloadInfo& info) - : BaseWorkload(desc, info) + : RefBaseWorkload(desc, info) { } diff --git a/src/backends/reference/workloads/RefElementwiseWorkload.hpp b/src/backends/reference/workloads/RefElementwiseWorkload.hpp index 4b108e4363..065a7833d7 100644 --- a/src/backends/reference/workloads/RefElementwiseWorkload.hpp +++ b/src/backends/reference/workloads/RefElementwiseWorkload.hpp @@ -6,7 +6,7 @@ #pragma once #include -#include +#include "RefBaseWorkload.hpp" #include #include "BaseIterator.hpp" #include "ElementwiseFunction.hpp" @@ -18,12 +18,12 @@ namespace armnn { template -class RefElementwiseWorkload : public BaseWorkload +class RefElementwiseWorkload : public RefBaseWorkload { public: using InType = typename ElementwiseBinaryFunction::InType; using OutType = typename ElementwiseBinaryFunction::OutType; - using BaseWorkload::m_Data; + using RefBaseWorkload::m_Data; RefElementwiseWorkload(const ParentDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; diff --git a/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp b/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp index 53b3375a50..85dc6af326 100644 --- a/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp +++ b/src/backends/reference/workloads/RefFakeQuantizationFloat32Workload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn diff --git a/src/backends/reference/workloads/RefFillWorkload.hpp b/src/backends/reference/workloads/RefFillWorkload.hpp index 56d44b85f7..d1e00581cd 100644 --- a/src/backends/reference/workloads/RefFillWorkload.hpp +++ b/src/backends/reference/workloads/RefFillWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefFillWorkload : public BaseWorkload +class RefFillWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefFloorWorkload.hpp b/src/backends/reference/workloads/RefFloorWorkload.hpp index 1a532f7a49..6237ff0c61 100644 --- a/src/backends/reference/workloads/RefFloorWorkload.hpp +++ b/src/backends/reference/workloads/RefFloorWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefFloorWorkload : public BaseWorkload +class RefFloorWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp index 5a7951ec48..c6ea147043 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.cpp @@ -14,7 +14,7 @@ namespace armnn { RefFullyConnectedWorkload::RefFullyConnectedWorkload( const FullyConnectedQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) { } diff --git a/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp index 3ee4a4a83c..432a8879a0 100644 --- a/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp +++ b/src/backends/reference/workloads/RefFullyConnectedWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "BaseIterator.hpp" #include "Decoders.hpp" @@ -15,7 +15,7 @@ namespace armnn { -class RefFullyConnectedWorkload : public BaseWorkload +class RefFullyConnectedWorkload : public RefBaseWorkload { public: explicit RefFullyConnectedWorkload(const FullyConnectedQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefGatherWorkload.hpp b/src/backends/reference/workloads/RefGatherWorkload.hpp index a2698e3a25..ec880a5109 100644 --- a/src/backends/reference/workloads/RefGatherWorkload.hpp +++ b/src/backends/reference/workloads/RefGatherWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include @@ -16,10 +16,10 @@ namespace armnn { -class RefGatherWorkload : public BaseWorkload +class RefGatherWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp index e642dc9b9a..c103a6b9d3 100644 --- a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.cpp @@ -16,7 +16,7 @@ namespace armnn RefInstanceNormalizationWorkload::RefInstanceNormalizationWorkload( const InstanceNormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) {} + : RefBaseWorkload(descriptor, info) {} void RefInstanceNormalizationWorkload::Execute() const { diff --git a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp index 3283c444d2..a4b2dd39cb 100644 --- a/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp +++ b/src/backends/reference/workloads/RefInstanceNormalizationWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefInstanceNormalizationWorkload : public BaseWorkload +class RefInstanceNormalizationWorkload : public RefBaseWorkload { public: explicit RefInstanceNormalizationWorkload(const InstanceNormalizationQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp index ca31503620..f6fcff3cc5 100644 --- a/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.cpp @@ -22,7 +22,7 @@ namespace armnn RefL2NormalizationWorkload::RefL2NormalizationWorkload( const L2NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) {} + : RefBaseWorkload(descriptor, info) {} void RefL2NormalizationWorkload::Execute() const { diff --git a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp index dd129c663e..c64e2ea0fd 100644 --- a/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp +++ b/src/backends/reference/workloads/RefL2NormalizationWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefL2NormalizationWorkload : public BaseWorkload +class RefL2NormalizationWorkload : public RefBaseWorkload { public: explicit RefL2NormalizationWorkload(const L2NormalizationQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp index 9f87def1bd..91ad5f6c36 100644 --- a/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp +++ b/src/backends/reference/workloads/RefLogSoftmaxWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefLogSoftmaxWorkload : public BaseWorkload +class RefLogSoftmaxWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp index f187e0ca31..f0cb846acf 100644 --- a/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp +++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.cpp @@ -19,7 +19,7 @@ namespace armnn RefLogicalBinaryWorkload::RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& desc, const WorkloadInfo& info) - : BaseWorkload(desc, info) + : RefBaseWorkload(desc, info) {} void RefLogicalBinaryWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp index 053de7daf9..797d937d80 100644 --- a/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp +++ b/src/backends/reference/workloads/RefLogicalBinaryWorkload.hpp @@ -7,16 +7,16 @@ #include "BaseIterator.hpp" -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefLogicalBinaryWorkload : public BaseWorkload +class RefLogicalBinaryWorkload : public RefBaseWorkload { public: - using BaseWorkload::m_Data; + using RefBaseWorkload::m_Data; RefLogicalBinaryWorkload(const LogicalBinaryQueueDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp index bef2bdc668..ec0aa0e454 100644 --- a/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp +++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.cpp @@ -19,7 +19,7 @@ namespace armnn RefLogicalUnaryWorkload::RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& desc, const WorkloadInfo& info) - : BaseWorkload(desc, info) + : RefBaseWorkload(desc, info) {} void RefLogicalUnaryWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp index 008d24fef8..ebd5826cc5 100644 --- a/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp +++ b/src/backends/reference/workloads/RefLogicalUnaryWorkload.hpp @@ -7,16 +7,16 @@ #include "BaseIterator.hpp" -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefLogicalUnaryWorkload : public BaseWorkload +class RefLogicalUnaryWorkload : public RefBaseWorkload { public: - using BaseWorkload::m_Data; + using RefBaseWorkload::m_Data; RefLogicalUnaryWorkload(const ElementwiseUnaryQueueDescriptor& descriptor, const WorkloadInfo& info); void Execute() const override; diff --git a/src/backends/reference/workloads/RefLstmWorkload.cpp b/src/backends/reference/workloads/RefLstmWorkload.cpp index 1ff6f50ed5..8609811253 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefLstmWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn { RefLstmWorkload::RefLstmWorkload(const LstmQueueDescriptor &descriptor, const WorkloadInfo &info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) diff --git a/src/backends/reference/workloads/RefLstmWorkload.hpp b/src/backends/reference/workloads/RefLstmWorkload.hpp index 72f6360281..57526c9ba2 100644 --- a/src/backends/reference/workloads/RefLstmWorkload.hpp +++ b/src/backends/reference/workloads/RefLstmWorkload.hpp @@ -7,13 +7,13 @@ #include -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefLstmWorkload : public BaseWorkload +class RefLstmWorkload : public RefBaseWorkload { public: explicit RefLstmWorkload(const LstmQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/reference/workloads/RefMeanWorkload.cpp b/src/backends/reference/workloads/RefMeanWorkload.cpp index 7941ce2c36..23abaf8ff4 100644 --- a/src/backends/reference/workloads/RefMeanWorkload.cpp +++ b/src/backends/reference/workloads/RefMeanWorkload.cpp @@ -16,7 +16,7 @@ namespace armnn { RefMeanWorkload::RefMeanWorkload(const MeanQueueDescriptor& descriptor, const WorkloadInfo& info) - :BaseWorkload(descriptor, info) {} + :RefBaseWorkload(descriptor, info) {} void RefMeanWorkload::Execute() const { diff --git a/src/backends/reference/workloads/RefMeanWorkload.hpp b/src/backends/reference/workloads/RefMeanWorkload.hpp index 2825d669c4..c4c6a1261c 100644 --- a/src/backends/reference/workloads/RefMeanWorkload.hpp +++ b/src/backends/reference/workloads/RefMeanWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" @@ -14,7 +14,7 @@ namespace armnn { -class RefMeanWorkload : public BaseWorkload +class RefMeanWorkload : public RefBaseWorkload { public: explicit RefMeanWorkload (const MeanQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/reference/workloads/RefNormalizationWorkload.cpp b/src/backends/reference/workloads/RefNormalizationWorkload.cpp index 36828acfb3..613868de57 100644 --- a/src/backends/reference/workloads/RefNormalizationWorkload.cpp +++ b/src/backends/reference/workloads/RefNormalizationWorkload.cpp @@ -158,7 +158,7 @@ namespace armnn RefNormalizationWorkload::RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) {} void RefNormalizationWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefNormalizationWorkload.hpp b/src/backends/reference/workloads/RefNormalizationWorkload.hpp index b152072496..5218e1e43a 100644 --- a/src/backends/reference/workloads/RefNormalizationWorkload.hpp +++ b/src/backends/reference/workloads/RefNormalizationWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefNormalizationWorkload : public BaseWorkload +class RefNormalizationWorkload : public RefBaseWorkload { public: explicit RefNormalizationWorkload(const NormalizationQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefPadWorkload.hpp b/src/backends/reference/workloads/RefPadWorkload.hpp index 18c406a4de..c5871059cc 100644 --- a/src/backends/reference/workloads/RefPadWorkload.hpp +++ b/src/backends/reference/workloads/RefPadWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefPadWorkload : public BaseWorkload +class RefPadWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefPermuteWorkload.hpp b/src/backends/reference/workloads/RefPermuteWorkload.hpp index 9424441c37..d1e44520a1 100644 --- a/src/backends/reference/workloads/RefPermuteWorkload.hpp +++ b/src/backends/reference/workloads/RefPermuteWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include diff --git a/src/backends/reference/workloads/RefPooling2dWorkload.hpp b/src/backends/reference/workloads/RefPooling2dWorkload.hpp index 125fea8d4e..a073e3921b 100644 --- a/src/backends/reference/workloads/RefPooling2dWorkload.hpp +++ b/src/backends/reference/workloads/RefPooling2dWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" @@ -13,10 +13,10 @@ namespace armnn { -class RefPooling2dWorkload : public BaseWorkload +class RefPooling2dWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefPooling3dWorkload.hpp b/src/backends/reference/workloads/RefPooling3dWorkload.hpp index 911c438627..92bc4766cf 100644 --- a/src/backends/reference/workloads/RefPooling3dWorkload.hpp +++ b/src/backends/reference/workloads/RefPooling3dWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" @@ -13,10 +13,10 @@ namespace armnn { -class RefPooling3dWorkload : public BaseWorkload +class RefPooling3dWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefPreluWorkload.cpp b/src/backends/reference/workloads/RefPreluWorkload.cpp index c1d8de2d01..94eeea1884 100644 --- a/src/backends/reference/workloads/RefPreluWorkload.cpp +++ b/src/backends/reference/workloads/RefPreluWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn RefPreluWorkload::RefPreluWorkload(const PreluQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) {} void RefPreluWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefPreluWorkload.hpp b/src/backends/reference/workloads/RefPreluWorkload.hpp index b5c97dfa90..51ba2c15a7 100644 --- a/src/backends/reference/workloads/RefPreluWorkload.hpp +++ b/src/backends/reference/workloads/RefPreluWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefPreluWorkload : public BaseWorkload +class RefPreluWorkload : public RefBaseWorkload { public: explicit RefPreluWorkload(const PreluQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefQLstmWorkload.cpp b/src/backends/reference/workloads/RefQLstmWorkload.cpp index dc29d0b92d..74f5f1ef4c 100644 --- a/src/backends/reference/workloads/RefQLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefQLstmWorkload.cpp @@ -14,7 +14,7 @@ namespace armnn { RefQLstmWorkload::RefQLstmWorkload(const QLstmQueueDescriptor &descriptor, const WorkloadInfo &info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) diff --git a/src/backends/reference/workloads/RefQLstmWorkload.hpp b/src/backends/reference/workloads/RefQLstmWorkload.hpp index 093cfd16af..0e64a38ac9 100644 --- a/src/backends/reference/workloads/RefQLstmWorkload.hpp +++ b/src/backends/reference/workloads/RefQLstmWorkload.hpp @@ -7,13 +7,13 @@ #include -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefQLstmWorkload : public BaseWorkload +class RefQLstmWorkload : public RefBaseWorkload { public: explicit RefQLstmWorkload(const QLstmQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.cpp b/src/backends/reference/workloads/RefQuantizeWorkload.cpp index 35791e65fb..10ef0e5e15 100644 --- a/src/backends/reference/workloads/RefQuantizeWorkload.cpp +++ b/src/backends/reference/workloads/RefQuantizeWorkload.cpp @@ -29,7 +29,7 @@ void QuantizeImpl(Decoder& in, Encoder& out, size_t numValues) } //namespace RefQuantizeWorkload::RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) , m_NumElements(info.m_InputTensorInfos[0].GetNumElements()) { } diff --git a/src/backends/reference/workloads/RefQuantizeWorkload.hpp b/src/backends/reference/workloads/RefQuantizeWorkload.hpp index a32efa7dd7..e38241067d 100644 --- a/src/backends/reference/workloads/RefQuantizeWorkload.hpp +++ b/src/backends/reference/workloads/RefQuantizeWorkload.hpp @@ -5,14 +5,14 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" #include "Encoders.hpp" namespace armnn { -class RefQuantizeWorkload : public BaseWorkload +class RefQuantizeWorkload : public RefBaseWorkload { public: RefQuantizeWorkload(const QuantizeQueueDescriptor& descriptor, const WorkloadInfo &info); diff --git a/src/backends/reference/workloads/RefRankWorkload.hpp b/src/backends/reference/workloads/RefRankWorkload.hpp index e1f30c5ba5..000828f9e4 100644 --- a/src/backends/reference/workloads/RefRankWorkload.hpp +++ b/src/backends/reference/workloads/RefRankWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "RefWorkloadUtils.hpp" @@ -13,10 +13,10 @@ namespace armnn { -struct RefRankWorkload : public BaseWorkload +struct RefRankWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; virtual void Execute() const override { Execute(m_Data.m_Inputs, m_Data.m_Outputs); diff --git a/src/backends/reference/workloads/RefReduceWorkload.cpp b/src/backends/reference/workloads/RefReduceWorkload.cpp index 821e828b6e..62881daaf7 100644 --- a/src/backends/reference/workloads/RefReduceWorkload.cpp +++ b/src/backends/reference/workloads/RefReduceWorkload.cpp @@ -16,7 +16,7 @@ namespace armnn RefReduceWorkload::RefReduceWorkload( const ReduceQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) {} + : RefBaseWorkload(descriptor, info) {} void RefReduceWorkload::Execute() const { diff --git a/src/backends/reference/workloads/RefReduceWorkload.hpp b/src/backends/reference/workloads/RefReduceWorkload.hpp index d2280cc660..d759bc2ef1 100644 --- a/src/backends/reference/workloads/RefReduceWorkload.hpp +++ b/src/backends/reference/workloads/RefReduceWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefReduceWorkload : public BaseWorkload +class RefReduceWorkload : public RefBaseWorkload { public: explicit RefReduceWorkload(const ReduceQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefReshapeWorkload.hpp b/src/backends/reference/workloads/RefReshapeWorkload.hpp index 26a86c1d11..7596685336 100644 --- a/src/backends/reference/workloads/RefReshapeWorkload.hpp +++ b/src/backends/reference/workloads/RefReshapeWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefReshapeWorkload : public BaseWorkload +class RefReshapeWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefResizeWorkload.hpp b/src/backends/reference/workloads/RefResizeWorkload.hpp index 82949ed639..f7747193ec 100644 --- a/src/backends/reference/workloads/RefResizeWorkload.hpp +++ b/src/backends/reference/workloads/RefResizeWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefResizeWorkload : public BaseWorkload +class RefResizeWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefShapeWorkload.hpp b/src/backends/reference/workloads/RefShapeWorkload.hpp index 209cccda68..b7ed761e0c 100644 --- a/src/backends/reference/workloads/RefShapeWorkload.hpp +++ b/src/backends/reference/workloads/RefShapeWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "RefWorkloadUtils.hpp" @@ -13,10 +13,10 @@ namespace armnn { -struct RefShapeWorkload : public BaseWorkload +struct RefShapeWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; virtual void Execute() const override { Execute(m_Data.m_Inputs, m_Data.m_Outputs); diff --git a/src/backends/reference/workloads/RefSliceWorkload.hpp b/src/backends/reference/workloads/RefSliceWorkload.hpp index 69dae5a1aa..b9dca86c4e 100644 --- a/src/backends/reference/workloads/RefSliceWorkload.hpp +++ b/src/backends/reference/workloads/RefSliceWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefSliceWorkload : public BaseWorkload +class RefSliceWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; diff --git a/src/backends/reference/workloads/RefSoftmaxWorkload.hpp b/src/backends/reference/workloads/RefSoftmaxWorkload.hpp index 42dbb53373..cac102a2bb 100644 --- a/src/backends/reference/workloads/RefSoftmaxWorkload.hpp +++ b/src/backends/reference/workloads/RefSoftmaxWorkload.hpp @@ -5,16 +5,16 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefSoftmaxWorkload : public BaseWorkload +class RefSoftmaxWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp b/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp index ec764c75bb..eb2d93fb86 100644 --- a/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp +++ b/src/backends/reference/workloads/RefSpaceToBatchNdWorkload.hpp @@ -4,17 +4,17 @@ // #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefSpaceToBatchNdWorkload : public BaseWorkload +class RefSpaceToBatchNdWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp b/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp index bc71fde20d..17f8d2f61e 100644 --- a/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp +++ b/src/backends/reference/workloads/RefSpaceToDepthWorkload.hpp @@ -4,17 +4,17 @@ // #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefSpaceToDepthWorkload : public BaseWorkload +class RefSpaceToDepthWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefSplitterWorkload.hpp b/src/backends/reference/workloads/RefSplitterWorkload.hpp index 28dc83db36..0b72bb9fdc 100644 --- a/src/backends/reference/workloads/RefSplitterWorkload.hpp +++ b/src/backends/reference/workloads/RefSplitterWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include #include "Decoders.hpp" #include "Encoders.hpp" @@ -13,10 +13,10 @@ namespace armnn { -class RefSplitterWorkload : public BaseWorkload +class RefSplitterWorkload : public RefBaseWorkload { public: - using BaseWorkload::BaseWorkload; + using RefBaseWorkload::RefBaseWorkload; void Execute() const override; void ExecuteAsync(WorkingMemDescriptor& workingMemDescriptor) override; private: diff --git a/src/backends/reference/workloads/RefStackWorkload.cpp b/src/backends/reference/workloads/RefStackWorkload.cpp index 3f7fd7bda2..f57e6e0f1e 100644 --- a/src/backends/reference/workloads/RefStackWorkload.cpp +++ b/src/backends/reference/workloads/RefStackWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn RefStackWorkload::RefStackWorkload(const StackQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) {} void RefStackWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefStackWorkload.hpp b/src/backends/reference/workloads/RefStackWorkload.hpp index fbca11b2fa..19f4a7be67 100644 --- a/src/backends/reference/workloads/RefStackWorkload.hpp +++ b/src/backends/reference/workloads/RefStackWorkload.hpp @@ -5,13 +5,13 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include namespace armnn { -class RefStackWorkload : public BaseWorkload +class RefStackWorkload : public RefBaseWorkload { public: explicit RefStackWorkload(const StackQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp index 336a687d5c..41fe4c3a1c 100644 --- a/src/backends/reference/workloads/RefStridedSliceWorkload.cpp +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.cpp @@ -12,7 +12,7 @@ namespace armnn RefStridedSliceWorkload::RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) {} void RefStridedSliceWorkload::Execute() const diff --git a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp index d2ffca7414..ea443cf80d 100644 --- a/src/backends/reference/workloads/RefStridedSliceWorkload.hpp +++ b/src/backends/reference/workloads/RefStridedSliceWorkload.hpp @@ -5,12 +5,12 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" namespace armnn { -class RefStridedSliceWorkload : public BaseWorkload +class RefStridedSliceWorkload : public RefBaseWorkload { public: RefStridedSliceWorkload(const StridedSliceQueueDescriptor& descriptor, const WorkloadInfo& info); diff --git a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp index 8665648fe6..64a2d4c7b2 100644 --- a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp +++ b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.cpp @@ -15,7 +15,7 @@ namespace armnn RefTransposeConvolution2dWorkload::RefTransposeConvolution2dWorkload( const TransposeConvolution2dQueueDescriptor& descriptor, const WorkloadInfo& info) : - BaseWorkload(descriptor, info) + RefBaseWorkload(descriptor, info) { // set up weights decoder m_Weights = std::make_unique(*(descriptor.m_Weight)); diff --git a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp index aa2546f420..6bcee9a838 100644 --- a/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp +++ b/src/backends/reference/workloads/RefTransposeConvolution2dWorkload.hpp @@ -9,12 +9,12 @@ #include "Encoders.hpp" #include -#include +#include "RefBaseWorkload.hpp" namespace armnn { -class RefTransposeConvolution2dWorkload : public BaseWorkload +class RefTransposeConvolution2dWorkload : public RefBaseWorkload { public: RefTransposeConvolution2dWorkload(const TransposeConvolution2dQueueDescriptor& descriptor, diff --git a/src/backends/reference/workloads/RefTransposeWorkload.hpp b/src/backends/reference/workloads/RefTransposeWorkload.hpp index bf59de7813..b8c3649745 100644 --- a/src/backends/reference/workloads/RefTransposeWorkload.hpp +++ b/src/backends/reference/workloads/RefTransposeWorkload.hpp @@ -5,7 +5,7 @@ #pragma once -#include +#include "RefBaseWorkload.hpp" #include diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp index 311fa18f91..d447a46b23 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.cpp @@ -19,7 +19,7 @@ namespace armnn RefUnidirectionalSequenceLstmWorkload::RefUnidirectionalSequenceLstmWorkload( const UnidirectionalSequenceLstmQueueDescriptor& descriptor, const WorkloadInfo& info) - : BaseWorkload(descriptor, info) + : RefBaseWorkload(descriptor, info) , m_InputToInputWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToInputWeights)) , m_InputToForgetWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToForgetWeights)) , m_InputToCellWeightsTensor (AssignScopedTensorHandle(descriptor.m_InputToCellWeights)) diff --git a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp index d0c000f20d..7a91cee642 100644 --- a/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp +++ b/src/backends/reference/workloads/RefUnidirectionalSequenceLstmWorkload.hpp @@ -7,7 +7,7 @@ #include -#include +#include "RefBaseWorkload.hpp" #include #include "Encoders.hpp" @@ -16,7 +16,7 @@ namespace armnn { -class RefUnidirectionalSequenceLstmWorkload : public BaseWorkload +class RefUnidirectionalSequenceLstmWorkload : public RefBaseWorkload { public: explicit RefUnidirectionalSequenceLstmWorkload(const UnidirectionalSequenceLstmQueueDescriptor& descriptor, -- cgit v1.2.1