aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/layers/ReduceLayer.cpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/src/armnn/layers/ReduceLayer.cpp b/src/armnn/layers/ReduceLayer.cpp
index b68cd2eabc..31a2dfa479 100644
--- a/src/armnn/layers/ReduceLayer.cpp
+++ b/src/armnn/layers/ReduceLayer.cpp
@@ -22,12 +22,22 @@ ReduceLayer::ReduceLayer(const ReduceDescriptor& param, const char* name)
std::unique_ptr<IWorkload> ReduceLayer::CreateWorkload(const IWorkloadFactory& factory) const
{
ReduceQueueDescriptor descriptor;
+ descriptor.m_Parameters.m_vAxis = m_Param.m_vAxis;
+ descriptor.m_Parameters.m_KeepDims = m_Param.m_KeepDims;
+ descriptor.m_Parameters.m_ReduceOperation = m_Param.m_ReduceOperation;
+ SetAdditionalInfo(descriptor);
+
return factory.CreateReduce(descriptor, PrepInfoAndDesc(descriptor));
}
ReduceLayer* ReduceLayer::Clone(Graph& graph) const
{
- return CloneBase<ReduceLayer>(graph, m_Param, GetName());
+ auto layer = CloneBase<ReduceLayer>(graph, m_Param, GetName());
+ layer->m_Param.m_vAxis = m_Param.m_vAxis;
+ layer->m_Param.m_KeepDims = m_Param.m_KeepDims;
+ layer->m_Param.m_ReduceOperation = m_Param.m_ReduceOperation;
+
+ return std::move(layer);
}
void ReduceLayer::ValidateTensorShapesFromInputs()