diff options
author | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-03 17:48:18 +0100 |
---|---|---|
committer | Sadik Armagan <sadik.armagan@arm.com> | 2019-04-08 15:48:28 +0000 |
commit | 2e6dc3a1c5d47825535db7993ba77eb1596ae99b (patch) | |
tree | 48e73fa1862d17534804d1699bedb76120e88c9f /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 0324f48e64edb99a5c8d819394545d97e0c2ae97 (diff) | |
download | armnn-2e6dc3a1c5d47825535db7993ba77eb1596ae99b.tar.gz |
IVGCVSW-2861 Refactor the Reference Elementwise workload
* Refactor Reference Comparison workload
* Removed templating based on the DataType
* Implemented BaseIterator to do decode/encode
Change-Id: I18f299f47ee23772f90152c1146b42f07465e105
Signed-off-by: Sadik Armagan <sadik.armagan@arm.com>
Signed-off-by: Kevin May <kevin.may@arm.com>
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 103 |
1 files changed, 102 insertions, 1 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index b850a65acf..1360ac5d0c 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -491,13 +491,29 @@ void AdditionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "AdditionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "AdditionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "AdditionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "AdditionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], "AdditionQueueDescriptor", "first input", "second input"); - } //--------------------------------------------------------------- @@ -506,6 +522,23 @@ void MultiplicationQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) c ValidateNumInputs(workloadInfo, "MultiplicationQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MultiplicationQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MultiplicationQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MultiplicationQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -857,6 +890,23 @@ void DivisionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "DivisionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "DivisionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "DivisionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "DivisionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -870,6 +920,23 @@ void SubtractionQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons ValidateNumInputs(workloadInfo, "SubtractionQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "SubtractionQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "SubtractionQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "SubtractionQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -883,6 +950,23 @@ void MaximumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MaximumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MaximumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MaximumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MaximumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], @@ -1008,6 +1092,23 @@ void MinimumQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const ValidateNumInputs(workloadInfo, "MinimumQueueDescriptor", 2); ValidateNumOutputs(workloadInfo, "MinimumQueueDescriptor", 1); + std::vector<DataType> supportedTypes = { + DataType::Float32, + DataType::QuantisedAsymm8 + }; + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_InputTensorInfos[1], + supportedTypes, + "MinimumQueueDescriptor"); + + ValidateDataTypes(workloadInfo.m_OutputTensorInfos[0], + supportedTypes, + "MinimumQueueDescriptor"); + ValidateBroadcastTensorShapesMatch(workloadInfo.m_InputTensorInfos[0], workloadInfo.m_InputTensorInfos[1], workloadInfo.m_OutputTensorInfos[0], |