diff options
Diffstat (limited to 'src/armnn/layers/ConstantLayer.cpp')
-rw-r--r-- | src/armnn/layers/ConstantLayer.cpp | 7 |
1 files changed, 3 insertions, 4 deletions
diff --git a/src/armnn/layers/ConstantLayer.cpp b/src/armnn/layers/ConstantLayer.cpp index 31e1549e0e..136616c204 100644 --- a/src/armnn/layers/ConstantLayer.cpp +++ b/src/armnn/layers/ConstantLayer.cpp @@ -18,12 +18,11 @@ ConstantLayer::ConstantLayer(const char* name) { } -std::unique_ptr<IWorkload> ConstantLayer::CreateWorkload(const Graph& graph, - const IWorkloadFactory& factory) const +std::unique_ptr<IWorkload> ConstantLayer::CreateWorkload(const IWorkloadFactory& factory) const { ConstantQueueDescriptor descriptor; descriptor.m_LayerOutput = m_LayerOutput.get(); - return factory.CreateConstant(descriptor, PrepInfoAndDesc(descriptor, graph)); + return factory.CreateConstant(descriptor, PrepInfoAndDesc(descriptor)); } ConstantLayer* ConstantLayer::Clone(Graph& graph) const @@ -38,7 +37,7 @@ ConstantLayer* ConstantLayer::Clone(Graph& graph) const std::vector<TensorShape> ConstantLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const { - return std::vector<TensorShape>({ m_LayerOutput->GetTensorInfo().GetShape() }); + return std::vector<TensorShape>({ inputShapes[0] }); } void ConstantLayer::ValidateTensorShapesFromInputs() |