aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/ConcatLayer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/ConcatLayer.hpp')
-rw-r--r--src/armnn/layers/ConcatLayer.hpp11
1 files changed, 9 insertions, 2 deletions
diff --git a/src/armnn/layers/ConcatLayer.hpp b/src/armnn/layers/ConcatLayer.hpp
index 4268291916..eb7d93ce14 100644
--- a/src/armnn/layers/ConcatLayer.hpp
+++ b/src/armnn/layers/ConcatLayer.hpp
@@ -22,9 +22,11 @@ public:
/// Set the outputs to be appropriate sub tensors of the input if sub tensors are supported
/// otherwise creates tensor handlers.
- /// @param [in] graph The graph where this layer can be found.
+ /// @param [in] registry Contains all the registered tensor handle factories available for use.
/// @param [in] factory The workload factory which will create the workload.
- virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) override;
+// virtual void CreateTensorHandles(Graph& graph, const IWorkloadFactory& factory) override;
+ virtual void CreateTensorHandles(const TensorHandleFactoryRegistry& registry,
+ const IWorkloadFactory& factory) override;
/// Creates a dynamically-allocated copy of this layer.
/// @param [in] graph The graph into which this layer is being cloned.
@@ -50,6 +52,11 @@ protected:
/// Default destructor
~ConcatLayer() = default;
+
+private:
+ template <typename FactoryType>
+ void CreateTensors(const FactoryType& factory);
+
};
} // namespace