diff options
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b8d4f0dfff..cfb38b4820 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -2779,4 +2779,28 @@ void DepthToSpaceQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) con } } +void ComparisonQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string descriptorName{"ComparisonQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 2); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + ValidateBroadcastTensorShapesMatch(inputTensorInfo0, + inputTensorInfo1, + outputTensorInfo, + descriptorName, + "input_0", + "input_1"); + + if (outputTensorInfo.GetDataType() != DataType::Boolean) + { + throw InvalidArgumentException(descriptorName + ": Output tensor type must be Boolean."); + } +} + } // namespace armnn |