aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon
diff options
context:
space:
mode:
authorSadik Armagan <sadik.armagan@arm.com>2019-04-03 17:48:18 +0100
committerSadik Armagan <sadik.armagan@arm.com>2019-04-08 15:48:28 +0000
commit2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch)
tree48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/backendsCommon
parent0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff)
downloadarmnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload * Removed templating based on the DataType * Implemented BaseIterator to do decode/encode Change-Id: I18f299f47ee23772f90152c1146b42f07465e105 Signed-off-by: Sadik Armagan <sadik.armagan@arm.com> Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/backendsCommon')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp103
-rw-r--r--src/backends/backendsCommon/test/WorkloadDataValidation.cpp14
2 files changed, 109 insertions, 8 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b850a65acf..1360ac5d0c 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
"AdditionQueueDescriptor",
"first input",
"second input");
-
}
//---------------------------------------------------------------
@@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
diff --git a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
index 3664d56c28..d37cc74c66 100644
--- a/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
+++ b/src/backends/backendsCommon/test/WorkloadDataValidation.cpp
@@ -312,17 +312,17 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputNumbers)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Too few inputs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
// Correct.
- BOOST_CHECK_NO_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo));
+ BOOST_CHECK_NO_THROW(RefAdditionWorkload(invalidData, invalidInfo));
AddInputToWorkload(invalidData, invalidInfo, input3TensorInfo, nullptr);
// Too many inputs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
@@ -347,7 +347,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddInputToWorkload(invalidData, invalidInfo, input2TensorInfo, nullptr);
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Output size not compatible with input sizes.
@@ -364,7 +364,7 @@ BOOST_AUTO_TEST_CASE(AdditionQueueDescriptor_Validate_InputShapes)
AddOutputToWorkload(invalidData, invalidInfo, outputTensorInfo, nullptr);
// Output differs.
- BOOST_CHECK_THROW(RefAdditionFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefAdditionWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}
@@ -399,7 +399,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
// Checks dimension consistency for input and output tensors.
@@ -424,7 +424,7 @@ BOOST_AUTO_TEST_CASE(MultiplicationQueueDescriptor_Validate_InputTensorDimension
AddInputToWorkload(invalidData, invalidInfo, input0TensorInfo, nullptr);
AddInputToWorkload(invalidData, invalidInfo, input1TensorInfo, nullptr);
- BOOST_CHECK_THROW(RefMultiplicationFloat32Workload(invalidData, invalidInfo), armnn::InvalidArgumentException);
+ BOOST_CHECK_THROW(RefMultiplicationWorkload(invalidData, invalidInfo), armnn::InvalidArgumentException);
}
}