diff options
-rw-r--r-- | src/armnn/layers/ReduceLayer.cpp | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/src/armnn/layers/ReduceLayer.cpp b/src/armnn/layers/ReduceLayer.cpp index b68cd2eabc..31a2dfa479 100644 --- a/src/armnn/layers/ReduceLayer.cpp +++ b/src/armnn/layers/ReduceLayer.cpp @@ -22,12 +22,22 @@ ReduceLayer::ReduceLayer(const ReduceDescriptor& param, const char* name) std::unique_ptr<IWorkload> ReduceLayer::CreateWorkload(const IWorkloadFactory& factory) const { ReduceQueueDescriptor descriptor; + descriptor.m_Parameters.m_vAxis = m_Param.m_vAxis; + descriptor.m_Parameters.m_KeepDims = m_Param.m_KeepDims; + descriptor.m_Parameters.m_ReduceOperation = m_Param.m_ReduceOperation; + SetAdditionalInfo(descriptor); + return factory.CreateReduce(descriptor, PrepInfoAndDesc(descriptor)); } ReduceLayer* ReduceLayer::Clone(Graph& graph) const { - return CloneBase<ReduceLayer>(graph, m_Param, GetName()); + auto layer = CloneBase<ReduceLayer>(graph, m_Param, GetName()); + layer->m_Param.m_vAxis = m_Param.m_vAxis; + layer->m_Param.m_KeepDims = m_Param.m_KeepDims; + layer->m_Param.m_ReduceOperation = m_Param.m_ReduceOperation; + + return std::move(layer); } void ReduceLayer::ValidateTensorShapesFromInputs() |