aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/SplitterLayer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/SplitterLayer.hpp')
-rw-r--r--src/armnn/layers/SplitterLayer.hpp10
1 files changed, 8 insertions, 2 deletions
diff --git a/src/armnn/layers/SplitterLayer.hpp b/src/armnn/layers/SplitterLayer.hpp
index 19b05562e8..9c684d479f 100644
--- a/src/armnn/layers/SplitterLayer.hpp
+++ b/src/armnn/layers/SplitterLayer.hpp
@@ -22,9 +22,11 @@ public:
/// Set the outputs to be appropriate sub tensors of the input if sub tensors are supported
/// otherwise creates tensor handlers.
- /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] registry Contains all the registered tensor handle factories available for use.
/// @param [in] factory The workload factory which will create the workload.
- virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) override;
+ //virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) override;
+ virtual void CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
+ const IWorkloadFactory& factory) override;
/// Creates a dynamically-allocated copy of this layer.
/// @param [in] graph The graph into which this layer is being cloned.
@@ -50,6 +52,10 @@ protected:
/// Default destructor
~SplitterLayer() = default;
+
+private:
+ template <typename FactoryType>
+ void CreateTensors(const FactoryType& factory);
};
} // namespace