diff options
author | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
---|---|---|
committer | telsoa01 <telmo.soares@arm.com> | 2018-08-31 09:22:23 +0100 |
commit | c577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch) | |
tree | bd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnn/layers/ConstantLayer.cpp | |
parent | 4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff) | |
download | armnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz |
Release 18.08
Diffstat (limited to 'src/armnn/layers/ConstantLayer.cpp')
-rw-r--r-- | src/armnn/layers/ConstantLayer.cpp | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/src/armnn/layers/ConstantLayer.cpp b/src/armnn/layers/ConstantLayer.cpp index 937d38a31d..2abc595605 100644 --- a/src/armnn/layers/ConstantLayer.cpp +++ b/src/armnn/layers/ConstantLayer.cpp @@ -13,9 +13,8 @@ namespace armnn { -ConstantLayer::ConstantLayer(const std::shared_ptr<ScopedCpuTensorHandle>& input, const char* name) +ConstantLayer::ConstantLayer(const char* name) : Layer(0, 1, LayerType::Constant, name) - , m_LayerOutput(input) { } @@ -29,13 +28,22 @@ std::unique_ptr<IWorkload> ConstantLayer::CreateWorkload(const Graph& graph, ConstantLayer* ConstantLayer::Clone(Graph& graph) const { - // Cloned layers share the same layer output object - return CloneBase<ConstantLayer>(graph, m_LayerOutput, GetName()); + // Cloned layers share the same layer output object. + auto layer = CloneBase<ConstantLayer>(graph, GetName()); + + layer->m_LayerOutput = m_LayerOutput ? std::make_unique<ScopedCpuTensorHandle>(*m_LayerOutput) : nullptr; + + return std::move(layer); +} + +std::vector<TensorShape> ConstantLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const +{ + return std::vector<TensorShape>({ m_LayerOutput->GetTensorInfo().GetShape() }); } void ConstantLayer::ValidateTensorShapesFromInputs() { - // get the output shape from the value of the constant layer + // Get the output shape from the value of the constant layer. TensorShape const& outShape = m_LayerOutput->GetTensorInfo().GetShape(); ConditionalThrowIfNotEqual<LayerValidationException>( "ConstantLayer: TensorShape set on OutputSlot[0] does not match the inferred shape.", |