diff options
Diffstat (limited to 'src/backends/WorkloadData.cpp')
-rw-r--r-- | src/backends/WorkloadData.cpp | 40 |
1 files changed, 40 insertions, 0 deletions
diff --git a/src/backends/WorkloadData.cpp b/src/backends/WorkloadData.cpp index ef31fbd1fb..495d4ecde9 100644 --- a/src/backends/WorkloadData.cpp +++ b/src/backends/WorkloadData.cpp @@ -741,6 +741,46 @@ void ReshapeQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const } } +void SpaceToBatchNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + ValidateSingleInput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); + ValidateSingleOutput(workloadInfo, "SpaceToBatchNdQueueDescriptor"); + + ValidateTensorNumDimensions(workloadInfo.m_InputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "input"); + ValidateTensorNumDimensions(workloadInfo.m_OutputTensorInfos[0], "SpaceToBatchNdQueueDescriptor", 4, "output"); + + if (workloadInfo.m_InputTensorInfos[0].GetNumElements() != workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + { + throw InvalidArgumentException("SpaceToBatchNdQueueDescriptor: Input tensor has " + + to_string(workloadInfo.m_InputTensorInfos[0].GetNumElements()) + " but output tensor has " + + to_string(workloadInfo.m_OutputTensorInfos[0].GetNumElements()) + " elements."); + } + + if (m_Parameters.m_BlockShape.size() != 2) + { + throw InvalidArgumentException("Block Shape must contains 2 spatial dimensions"); + } + + if (m_Parameters.m_BlockShape.size() != m_Parameters.m_PadList.size()) + { + throw InvalidArgumentException("Pad List must contains the same number of dimensions as Block Shape."); + } + + const TensorShape inputShape = workloadInfo.m_InputTensorInfos[0].GetShape(); + + std::pair<unsigned int, unsigned int> heightPad = m_Parameters.m_PadList[0]; + std::pair<unsigned int, unsigned int> widthPad = m_Parameters.m_PadList[1]; + + if ((inputShape[m_Parameters.m_DataLayout.GetHeightIndex()] + heightPad.first + heightPad.second) + % m_Parameters.m_BlockShape[0] != 0 || + (inputShape[m_Parameters.m_DataLayout.GetWidthIndex()] + widthPad.first + widthPad.second) + % m_Parameters.m_BlockShape[1] != 0) + { + throw InvalidArgumentException( + "Input shape after padding must be divisible by Block Shape in all spatial dimensions"); + } +} + void FloorQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const { ValidateSingleInput(workloadInfo, "FloorQueueDescriptor"); |