From 32b9046ea74d2387a08819cf5e67c183e03f6d3f Mon Sep 17 00:00:00 2001 From: narpra01 Date: Thu, 13 Sep 2018 11:07:48 +0100 Subject: IVGCVSW-1813 - Add MeanLayer * add MeanLayer functionalities * modify MeanQueueDescriptor to use parameter * add IsMeanSupported placeholder for all backends Change-Id: Ic69a34a61df667849977aad9b38f9a01eef565b5 --- src/armnn/backends/WorkloadData.cpp | 21 +++------------------ 1 file changed, 3 insertions(+), 18 deletions(-) (limited to 'src/armnn/backends/WorkloadData.cpp') diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp index 3ed77dacdb..25144a4753 100644 --- a/src/armnn/backends/WorkloadData.cpp +++ b/src/armnn/backends/WorkloadData.cpp @@ -129,18 +129,6 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor, } } -void ValidateTensorMaxNumElements(const TensorInfo& tensor, - std::string const& descName, - unsigned int maxNumElements, - std::string const& tensorName) -{ - if (tensor.GetNumElements() > maxNumElements) - { - throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " + - to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor."); - } -} - //--------------------------------------------------------------- void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType, const std::string& descName, std::string const& tensorName) @@ -844,20 +832,17 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const const TensorInfo& input = workloadInfo.m_InputTensorInfos[0]; const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0]; - if (m_Keepdims) + if (m_Parameters.m_KeepDims) { ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output"); } - else if (m_Axis == nullptr) + else if (m_Parameters.m_Axis.empty()) { ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output"); } else { - const TensorInfo& axis = m_Axis->GetTensorInfo(); - ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis"); - ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis"); - unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements(); + auto outputDim = input.GetNumDimensions() - boost::numeric_cast(m_Parameters.m_Axis.size()); ValidateTensorNumDimensions(output, "MeanQueueDescriptor", outputDim > 0 ? outputDim : 1, -- cgit v1.2.1