aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/layers/GatherLayer.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/layers/GatherLayer.hpp')
-rw-r--r--src/armnn/layers/GatherLayer.hpp11
1 files changed, 6 insertions, 5 deletions
diff --git a/src/armnn/layers/GatherLayer.hpp b/src/armnn/layers/GatherLayer.hpp
index 598ca44dc4..d8737adbee 100644
--- a/src/armnn/layers/GatherLayer.hpp
+++ b/src/armnn/layers/GatherLayer.hpp
@@ -1,17 +1,17 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
-#include "Layer.hpp"
+#include "LayerWithParameters.hpp"
namespace armnn
{
/// This layer represents a Gather operator.
-class GatherLayer : public Layer
+class GatherLayer : public LayerWithParameters<GatherDescriptor>
{
public:
/// Makes a workload for the Gather type.
@@ -24,7 +24,7 @@ public:
/// @param [in] graph The graph into which this layer is being cloned.
GatherLayer* Clone(Graph& graph) const override;
- /// Check if the input tensor shape(s)
+ /// Check if the input tensor shape(s).
/// will lead to a valid configuration of @ref GatherLayer.
/// @param [in] shapeInferenceMethod Indicates if output shape shall be overwritten or just validate.
void ValidateTensorShapesFromInputs(
@@ -34,8 +34,9 @@ public:
protected:
/// Constructor to create a GatherLayer.
+ /// @param [in] param GatherDescriptor to configure the stack operation.
/// @param [in] name Optional name for the layer.
- GatherLayer(const char* name);
+ GatherLayer(const GatherDescriptor& param, const char* name);
/// Default destructor
~GatherLayer() = default;