aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/armnn/layers/SpaceToDepthLayer.cpp4
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp26
2 files changed, 15 insertions, 15 deletions
diff --git a/src/armnn/layers/SpaceToDepthLayer.cpp b/src/armnn/layers/SpaceToDepthLayer.cpp
index b24490f82f..8a9f1c296c 100644
--- a/src/armnn/layers/SpaceToDepthLayer.cpp
+++ b/src/armnn/layers/SpaceToDepthLayer.cpp
@@ -47,8 +47,6 @@ std::vector<TensorShape> SpaceToDepthLayer::InferOutputShapes(const std::vector<
TensorShape inputShape = inputShapes[0];
TensorShape outputShape(inputShape);
- outputShape[0] = inputShape[0];
-
DataLayoutIndexed dimensionIndices{m_Param.m_DataLayout};
unsigned int hIndex = dimensionIndices.GetHeightIndex();
unsigned int wIndex = dimensionIndices.GetWidthIndex();
@@ -82,4 +80,4 @@ void SpaceToDepthLayer::Accept(ILayerVisitor& visitor) const
visitor.VisitSpaceToDepthLayer(this, GetParameters(), GetName());
}
-} // namespace armnn \ No newline at end of file
+} // namespace armnn
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 52d14097af..3fbdec7bf9 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -1408,29 +1408,31 @@ void SpaceToDepthQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con
ValidateDataTypes(inputTensorInfo, supportedTypes, descriptorName);
ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName);
+ ValidateTensorNumElementsMatch(inputTensorInfo, outputTensorInfo, descriptorName, "input", "output");
+
+ if (m_Parameters.m_BlockSize == 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": Block size cannot be 0.");
+ }
+
DataLayoutIndexed dimensionIndices(m_Parameters.m_DataLayout);
const unsigned int wIndex = dimensionIndices.GetWidthIndex();
const unsigned int hIndex = dimensionIndices.GetHeightIndex();
const unsigned int cIndex = dimensionIndices.GetChannelsIndex();
const TensorShape& inputShape = inputTensorInfo.GetShape();
-
- const unsigned int numInputElements =
- inputShape[0] * inputShape[wIndex] * inputShape[hIndex] * inputShape[cIndex];
- const unsigned int numOutputElements = outputTensorInfo.GetNumElements();
-
- if (numOutputElements != numInputElements)
- {
- throw InvalidArgumentException(descriptorName + ": Input tensor has " +
- std::to_string(numInputElements) + " but output tensor has " +
- std::to_string(numOutputElements) + " elements.");
- }
-
if (inputShape[hIndex] % m_Parameters.m_BlockSize != 0 || inputShape[wIndex] % m_Parameters.m_BlockSize != 0)
{
throw InvalidArgumentException(descriptorName + ": Input shape must be divisible "
"by block size in all spatial dimensions");
}
+
+ const TensorShape& outputShape = outputTensorInfo.GetShape();
+ if (outputShape[cIndex] % (m_Parameters.m_BlockSize * m_Parameters.m_BlockSize) != 0)
+ {
+ throw InvalidArgumentException(descriptorName + ": The depth of the output tensor"
+ "must be divisible by the square of block size." );
+ }
}
void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const