diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 32 |
1 files changed, 19 insertions, 13 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 62cbd05c13..7c02947b32 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -749,21 +749,14 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input"); ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output"); - if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements()) - { - throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " + - to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " + - to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements."); - } - if (m_Parameters.m_BlockShape.size() != 2) { - throw InvalidArgumentException("Block Shape must contains 2 spatial dimensions"); + throw InvalidArgumentException("Block Shape must contain 2 spatial dimensions"); } if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size()) { - throw InvalidArgumentException("Pad List must contains the same number of dimensions as Block Shape."); + throw InvalidArgumentException("Pad List must contain the same number of dimensions as Block Shape."); } const TensorShape inputShape = workloadInfo.m_InputTensorInfos[0].GetShape(); @@ -771,10 +764,23 @@ void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0]; std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1]; - if ((inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + heightPad.first + heightPad.second) - % m_Parameters.m_BlockShape[0] != 0 || - (inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + widthPad.first + widthPad.second) - % m_Parameters.m_BlockShape[1] != 0) + unsigned int inputHeight = inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + + heightPad.first + heightPad.second; + + unsigned int inputWidth = inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + + widthPad.first + widthPad.second; + + unsigned int numInputElements = inputShape[0] * inputHeight * inputWidth + * inputShape[m_Parameters.m_DataLayout.GetChannelsIndex()]; + + if (workloadInfo.m_OutputTensorInfos[0].GetNumElements() != numInputElements) + { + throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " + + to_string(numInputElements) + " after padding but output tensor has " + + to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements."); + } + + if (inputHeight % m_Parameters.m_BlockShape[0] != 0 || inputWidth % m_Parameters.m_BlockShape[1] != 0) { throw InvalidArgumentException( "Input shape after padding must be divisible by Block Shape in all spatial dimensions"); |