diff options
author | Tianle Cheng <tianle.cheng@arm.com> | 2024-02-23 17:56:54 +0000 |
---|---|---|
committer | Kevin May <kevin.may@arm.com> | 2024-02-28 16:12:34 +0000 |
commit | 282881877522d3e94752dfc0839de9bfa0aa5a81 (patch) | |
tree | 9cd11c96eb4c179e76f2e586d5a9d9b416dd85a0 /src/backends/backendsCommon/WorkloadData.cpp | |
parent | 2883a86c5a167aea3c736529bff5921ab6cbc99c (diff) | |
download | armnn-282881877522d3e94752dfc0839de9bfa0aa5a81.tar.gz |
IVGCVSW-8229 & IVGCVSW-8237 ScatterNd: Front end and reference implementation
(scatter_nd, scatter_nd_add, and scatter_nd_update, scatter_nd_sub, scatter_nd_min, scatter_nd_max, scatter_nd_mul)
* Front end support for ScatterNd added.
* Reference implementation for ScatterNd added.
* Unit tests added.
Signed-off-by: Tianle Cheng <tianle.cheng@arm.com>
Change-Id: I30da9056d9b03ca9b5fb8d09987341128badbcf4
Diffstat (limited to 'src/backends/backendsCommon/WorkloadData.cpp')
-rw-r--r-- | src/backends/backendsCommon/WorkloadData.cpp | 42 |
1 files changed, 42 insertions, 0 deletions
diff --git a/src/backends/backendsCommon/WorkloadData.cpp b/src/backends/backendsCommon/WorkloadData.cpp index 0ddb4291f1..de985ec28d 100644 --- a/src/backends/backendsCommon/WorkloadData.cpp +++ b/src/backends/backendsCommon/WorkloadData.cpp @@ -4443,4 +4443,46 @@ void BroadcastToQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) cons ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); } +void ScatterNdQueueDescriptor::Validate(const WorkloadInfo& workloadInfo) const +{ + const std::string& descriptorName{"ScatterQueueDescriptor"}; + + ValidateNumInputs(workloadInfo, descriptorName, 3); + ValidateNumOutputs(workloadInfo, descriptorName, 1); + + const TensorInfo& inputTensorInfo0 = workloadInfo.m_InputTensorInfos[0]; + const TensorInfo& inputTensorInfo1 = workloadInfo.m_InputTensorInfos[1]; + const TensorInfo& inputTensorInfo2 = workloadInfo.m_InputTensorInfos[2]; + const TensorInfo& outputTensorInfo = workloadInfo.m_OutputTensorInfos[0]; + + std::vector<DataType> supportedTypes = + { + DataType::Float32, + DataType::Float16, + DataType::QAsymmS8, + DataType::QAsymmU8, + DataType::QSymmS8, + DataType::QSymmS16, + DataType::Signed32 + }; + + std::vector<DataType> indicesSupportedTypes = + { + DataType::Signed32 + }; + + if (m_Parameters.m_InputEnabled) + { + ValidateDataTypes(inputTensorInfo0, supportedTypes, descriptorName); + } + else + { + ValidateDataTypes(inputTensorInfo0, indicesSupportedTypes, descriptorName); + } + + ValidateDataTypes(inputTensorInfo1, indicesSupportedTypes, descriptorName); + ValidateDataTypes(inputTensorInfo2, supportedTypes, descriptorName); + ValidateDataTypes(outputTensorInfo, supportedTypes, descriptorName); +} + } // namespace armnn
\ No newline at end of file |