aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp32
1 files changed, 19 insertions, 13 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index 62cbd05c13..7c02947b32 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -749,21 +749,14 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input");
ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output");
- if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements())
- {
- throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " +
- to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " +
- to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
- }
-
if (m_Parameters.m_BlockShape.size() != 2)
{
- throw InvalidArgumentException("Block Shape must contains 2 spatial dimensions");
+ throw InvalidArgumentException("Block Shape must contain 2 spatial dimensions");
}
if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size())
{
- throw InvalidArgumentException("Pad List must contains the same number of dimensions as Block Shape.");
+ throw InvalidArgumentException("Pad List must contain the same number of dimensions as Block Shape.");
}
const TensorShape inputShape = workloadInfo.m_InputTensorInfos[0].GetShape();
@@ -771,10 +764,23 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0];
std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1];
- if ((inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + heightPad.first + heightPad.second)
- % m_Parameters.m_BlockShape[0] != 0 ||
- (inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + widthPad.first + widthPad.second)
- % m_Parameters.m_BlockShape[1] != 0)
+ unsigned int inputHeight = inputShape[m_Parameters.m_DataLayout.GetHeightIndex()]
+ + heightPad.first + heightPad.second;
+
+ unsigned int inputWidth = inputShape[m_Parameters.m_DataLayout.GetWidthIndex()]
+ + widthPad.first + widthPad.second;
+
+ unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth
+ * inputShape[m_Parameters.m_DataLayout.GetChannelsIndex()];
+
+ if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements)
+ {
+ throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " +
+ to_string(numInputElements) + " after padding but output tensor has " +
+ to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements.");
+ }
+
+ if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0)
{
throw InvalidArgumentException(
"Input shape after padding must be divisible by Block Shape in all spatial dimensions");