aboutsummaryrefslogtreecommitdiff
path: root/src/armnn/backends/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnn/backends/WorkloadData.cpp')
-rw-r--r--src/armnn/backends/WorkloadData.cpp21
1 files changed, 3 insertions, 18 deletions
diff --git a/src/armnn/backends/WorkloadData.cpp b/src/armnn/backends/WorkloadData.cpp
index 3ed77dacdb..25144a4753 100644
--- a/src/armnn/backends/WorkloadData.cpp
+++ b/src/armnn/backends/WorkloadData.cpp
@@ -129,18 +129,6 @@ void ValidateTensorNumDimensions(const TensorInfo& tensor,
}
}
-void ValidateTensorMaxNumElements(const TensorInfo& tensor,
- std::string const& descName,
- unsigned int maxNumElements,
- std::string const& tensorName)
-{
- if (tensor.GetNumElements() > maxNumElements)
- {
- throw InvalidArgumentException(descName + ": Expected maximum of " + to_string(maxNumElements) + " but got " +
- to_string(tensor.GetNumElements()) + " elements for " + tensorName + " tensor.");
- }
-}
-
//---------------------------------------------------------------
void ValidateTensorDataType(const TensorInfo& tensor, DataType dataType,
const std::string& descName, std::string const& tensorName)
@@ -844,20 +832,17 @@ void MeanQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
const TensorInfo& input = workloadInfo.m_InputTensorInfos[0];
const TensorInfo& output = workloadInfo.m_OutputTensorInfos[0];
- if (m_Keepdims)
+ if (m_Parameters.m_KeepDims)
{
ValidateTensorNumDimensions(output, "MeanQueueDescriptor", input.GetNumDimensions(), "output");
}
- else if (m_Axis == nullptr)
+ else if (m_Parameters.m_Axis.empty())
{
ValidateTensorNumDimensions(output, "MeanQueueDescriptor", 1, "output");
}
else
{
- const TensorInfo& axis = m_Axis->GetTensorInfo();
- ValidateTensorNumDimensions(axis, "MeanQueueDescriptor", 1, "axis");
- ValidateTensorMaxNumElements(axis, "MeanQueueDescriptor", input.GetNumDimensions(), "axis");
- unsigned int outputDim = input.GetNumDimensions() - axis.GetNumElements();
+ auto outputDim = input.GetNumDimensions() - boost::numeric_cast<unsigned int>(m_Parameters.m_Axis.size());
ValidateTensorNumDimensions(output,
"MeanQueueDescriptor",
outputDim > 0 ? outputDim : 1,