diff options
Diffstat (limited to 'src/armnn/Layer.hpp')
-rw-r--r-- | src/armnn/Layer.hpp | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index 3f00a20e65..e0a1ad66f2 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -265,7 +265,7 @@ public: // Virtuals - virtual std::unique_ptr<IWorkload> CreateWorkload(const Graph& graph, const IWorkloadFactory& factory) const = 0; + virtual std::unique_ptr<IWorkload> CreateWorkload(const IWorkloadFactory& factory) const = 0; virtual void CreateTensorHandles(const TensorHandleFactoryRegistry& registry, const IWorkloadFactory& factory, @@ -326,26 +326,26 @@ protected: virtual ~Layer() = default; template <typename QueueDescriptor> - void CollectQueueDescriptorInputs(QueueDescriptor& descriptor, WorkloadInfo& info, const Graph& graph) const + void CollectQueueDescriptorInputs(QueueDescriptor& descriptor, WorkloadInfo& info) const { WorkloadDataCollector dataCollector(descriptor.m_Inputs, info.m_InputTensorInfos); - CollectWorkloadInputs(dataCollector, graph); + CollectWorkloadInputs(dataCollector); } template <typename QueueDescriptor> - void CollectQueueDescriptorOutputs(QueueDescriptor& descriptor, WorkloadInfo& info, const Graph& graph) const + void CollectQueueDescriptorOutputs(QueueDescriptor& descriptor, WorkloadInfo& info) const { WorkloadDataCollector dataCollector(descriptor.m_Outputs, info.m_OutputTensorInfos); - CollectWorkloadOutputs(dataCollector, graph); + CollectWorkloadOutputs(dataCollector); } /// Helper function to reduce duplication in *Layer::CreateWorkload. template <typename QueueDescriptor> - WorkloadInfo PrepInfoAndDesc(QueueDescriptor& descriptor, const Graph& graph) const + WorkloadInfo PrepInfoAndDesc(QueueDescriptor& descriptor) const { WorkloadInfo info; - CollectQueueDescriptorInputs(descriptor, info, graph); - CollectQueueDescriptorOutputs(descriptor, info, graph); + CollectQueueDescriptorInputs(descriptor, info); + CollectQueueDescriptorOutputs(descriptor, info); return info; } @@ -357,8 +357,8 @@ protected: virtual ConstantTensors GetConstantTensorsByRef() {return ConstantTensors(); }; private: - void CollectWorkloadInputs(WorkloadDataCollector& dataCollector, const Graph& graph) const; - void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector, const Graph& graph) const; + void CollectWorkloadInputs(WorkloadDataCollector& dataCollector) const; + void CollectWorkloadOutputs(WorkloadDataCollector& dataCollector) const; protected: std::vector<OutputHandler> m_OutputHandlers; |