From f24effa4995ea4c3dd91e33d4a2787e02decf8b4 Mon Sep 17 00:00:00 2001 From: Finn Williams Date: Fri, 3 Jul 2020 10:12:03 +0100 Subject: IVGCVSW-5155 Update Arm NN API to allow for call to shape inference Signed-off-by: Finn Williams Change-Id: I0a2babe5b5b09eb81c9900dc3a05071034a0440b --- src/armnn/Layer.hpp | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) (limited to 'src/armnn/Layer.hpp') diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp index f1954b9d07..d4a24e4925 100644 --- a/src/armnn/Layer.hpp +++ b/src/armnn/Layer.hpp @@ -1,5 +1,5 @@ // -// Copyright © 2017 Arm Ltd. All rights reserved. +// Copyright © 2017 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once @@ -228,6 +228,8 @@ public: return const_cast(const_cast(this)->GetOutputHandler(i)); } + ShapeInferenceMethod GetShapeInferenceMethod() const { return m_ShapeInferenceMethod; }; + const std::vector& GetInputSlots() const { return m_InputSlots; } const std::vector& GetOutputSlots() const { return m_OutputSlots; } @@ -277,8 +279,7 @@ public: void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation& location) const; - virtual void ValidateTensorShapesFromInputs( - ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly) = 0; + virtual void ValidateTensorShapesFromInputs() = 0; std::vector InferOutputShapes(const std::vector& inputShapes) const override; @@ -328,6 +329,11 @@ public: } Optional GetBackendHint() const { return m_BackendHint; } + void SetShapeInferenceMethod(ShapeInferenceMethod shapeInferenceMethod) + { + m_ShapeInferenceMethod = shapeInferenceMethod; + } + protected: // Graph needs access to the virtual destructor. friend class Graph; @@ -378,6 +384,7 @@ private: protected: std::vector m_OutputHandlers; + ShapeInferenceMethod m_ShapeInferenceMethod; private: const std::string m_LayerName; @@ -396,6 +403,7 @@ private: LayerGuid m_Guid; std::list m_RelatedLayerNames; + }; // A layer user-provided data can be bound to (e.g. inputs, outputs). -- cgit v1.2.1