diff options
Diffstat (limited to 'src/armnn/Layer.hpp')
-rw-r--r-- | src/armnn/Layer.hpp | 14 |
1 files changed, 11 insertions, 3 deletions
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index f1954b9d07..d4a24e4925 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -228,6 +228,8 @@ public: return const_cast<OutputHandler&>(const_cast<const Layer*>(this)->GetOutputHandler(i)); } + ShapeInferenceMethod GetShapeInferenceMethod() const { return m_ShapeInferenceMethod; }; + const std::vector<InputSlot>& GetInputSlots() const { return m_InputSlots; } const std::vector<OutputSlot>& GetOutputSlots() const { return m_OutputSlots; } @@ -277,8 +279,7 @@ public: void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation& location) const; - virtual void ValidateTensorShapesFromInputs( - ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly) = 0; + virtual void ValidateTensorShapesFromInputs() = 0; std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override; @@ -328,6 +329,11 @@ public: } Optional<BackendId> GetBackendHint() const { return m_BackendHint; } + void SetShapeInferenceMethod(ShapeInferenceMethod shapeInferenceMethod) + { + m_ShapeInferenceMethod = shapeInferenceMethod; + } + protected: // Graph needs access to the virtual destructor. friend class Graph; @@ -378,6 +384,7 @@ private: protected: std::vector<OutputHandler> m_OutputHandlers; + ShapeInferenceMethod m_ShapeInferenceMethod; private: const std::string m_LayerName; @@ -396,6 +403,7 @@ private: LayerGuid m_Guid; std::list<std::string> m_RelatedLayerNames; + }; // A layer user-provided data can be bound to (e.g. inputs, outputs). |