diff options
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 103 | ||||
-rw-r--r-- | src/backends/backendsCommon/test/WorkloadDataValidation.cpp | 14 |
2 files changed, 109 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b850a65acf..1360ac5d0c 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], "AdditionQueueDescriptor", "first input", "second input"); - } //--------------------------------------------------------------- @@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp index 3664d56c28..d37cc74c66 100644 --- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp +++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp @@ -312,17 +312,17 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputNumbers) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // Too few inputs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr); // Correct. - BOOST_CHECK_NO_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo)); + BOOST_CHECK_NO_THROW(RefAdditionWorkload(invalidData, invalidInfo)); AddInputToWorkload(invalidData, invalidInfo, input3TensorInfo, nullptr); // Too many inputs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) @@ -347,7 +347,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr); AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } // Output size not compatible with input sizes. @@ -364,7 +364,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes) AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr); // Output differs. - BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } } @@ -399,7 +399,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr); AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr); - BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } // Checks dimension consistency for input and output tensors. @@ -424,7 +424,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr); AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr); - BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException); + BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException); } } |