aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/Layer.hpp
diff options
context:
space:
mode:
authorFinn Williams <Finn.Williams@arm.com>2020-07-03 10:12:03 +0100
committerJim Flynn <jim.flynn@arm.com>2020-07-26 15:42:26 +0000
commitf24effa4995ea4c3dd91e33d4a2787e02decf8b4 (patch)
tree56e0f22cab0fd8544693b9240bd8d74426eaa454 /src/armnn/Layer.hpp
parent8398edcfb933b638ddf4b88d84d6e188c49b1e0d (diff)
downloadarmnn-f24effa4995ea4c3dd91e33d4a2787e02decf8b4.tar.gz
IVGCVSW-5155 Update Arm NN API to allow for call to shape inference
Signed-off-by: Finn Williams <Finn.Williams@arm.com> Change-Id: I0a2babe5b5b09eb81c9900dc3a05071034a0440b
Diffstat (limited to 'src/armnn/Layer.hpp')
-rw-r--r--src/armnn/Layer.hpp14
1 files changed, 11 insertions, 3 deletions
diff --git a/src/armnn/Layer.hpp b/src/armnn/Layer.hpp
index f1954b9d07..d4a24e4925 100644
--- a/src/armnn/Layer.hpp
+++ b/src/armnn/Layer.hpp
@@ -1,5 +1,5 @@
//
-// Copyright © 2017 Arm Ltd. All rights reserved.
+// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
@@ -228,6 +228,8 @@ public:
return const_cast<OutputHandler&>(const_cast<const Layer*>(this)->GetOutputHandler(i));
}
+ ShapeInferenceMethod GetShapeInferenceMethod() const { return m_ShapeInferenceMethod; };
+
const std::vector<InputSlot>& GetInputSlots() const { return m_InputSlots; }
const std::vector<OutputSlot>& GetOutputSlots() const { return m_OutputSlots; }
@@ -277,8 +279,7 @@ public:
void VerifyLayerConnections(unsigned int expectedConnections, const CheckLocation& location) const;
- virtual void ValidateTensorShapesFromInputs(
- ShapeInferenceMethod shapeInferenceMethod = ShapeInferenceMethod::ValidateOnly) = 0;
+ virtual void ValidateTensorShapesFromInputs() = 0;
std::vector<TensorShape> InferOutputShapes(const std::vector<TensorShape>& inputShapes) const override;
@@ -328,6 +329,11 @@ public:
}
Optional<BackendId> GetBackendHint() const { return m_BackendHint; }
+ void SetShapeInferenceMethod(ShapeInferenceMethod shapeInferenceMethod)
+ {
+ m_ShapeInferenceMethod = shapeInferenceMethod;
+ }
+
protected:
// Graph needs access to the virtual destructor.
friend class Graph;
@@ -378,6 +384,7 @@ private:
protected:
std::vector<OutputHandler> m_OutputHandlers;
+ ShapeInferenceMethod m_ShapeInferenceMethod;
private:
const std::string m_LayerName;
@@ -396,6 +403,7 @@ private:
LayerGuid m_Guid;
std::list<std::string> m_RelatedLayerNames;
+
};
// A layer user-provided data can be bound to (e.g. inputs, outputs).