aboutsummaryrefslogtreecommitdiff
path: root/src/backends/backendsCommon/WorkloadData.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r--src/backends/backendsCommon/WorkloadData.cpp103
1 files changed, 102 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp
index b850a65acf..1360ac5d0c 100644
--- a/src/backends/backendsCommon/WorkloadData.cpp
+++ b/src/backends/backendsCommon/WorkloadData.cpp
@@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "AdditionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
"AdditionQueueDescriptor",
"first input",
"second input");
-
}
//---------------------------------------------------------------
@@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c
ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MultiplicationQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "DivisionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons
ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "SubtractionQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MaximumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],
@@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const
ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2);
ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1);
+ std::vector<DataType> supportedTypes = {
+ DataType::Float32,
+ DataType::QuantisedAsymm8
+ };
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_InputTensorInfos[1],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
+ ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0],
+ supportedTypes,
+ "MinimumQueueDescriptor");
+
ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0],
workloadInfo.m_InputTensorInfos[1],
workloadInfo.m_OutputTensorInfos[0],