diff options
author | Patryk Kaiser <patryk.kaiser@arm.com> | 2024-06-27 12:29:48 +0100 |
---|---|---|
committer | Colm Donelan <colm.donelan@arm.com> | 2024-07-01 16:55:55 +0000 |
commit | 656285153399d96ead5925db907d0ec1961dfd76 (patch) | |
tree | 74ae6dbd3199c0297c0321d5998cfd546537af26 /delegate/classic/src | |
parent | 443804adee542d4330713e8dda6357b9495856fa (diff) | |
download | armnn-656285153399d96ead5925db907d0ec1961dfd76.tar.gz |
IVGCVSW-8139 Fixing Broadcast OP DTS tests
* In ref broadcast layer added broadcast support to expand tensor shapes
* Added function to check for zero dimension tensors
* Added check for unsupported zero dimension tensors during broadcast
* Added DelegateUtils unit test file with unit tests for the new function
Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com>
Change-Id: If4e786f7ba580399e781c48335888e8da8458019
Diffstat (limited to 'delegate/classic/src')
-rw-r--r-- | delegate/classic/src/BroadcastTo.hpp | 12 |
1 files changed, 11 insertions, 1 deletions
diff --git a/delegate/classic/src/BroadcastTo.hpp b/delegate/classic/src/BroadcastTo.hpp index 92aed79982..2e2b3ab155 100644 --- a/delegate/classic/src/BroadcastTo.hpp +++ b/delegate/classic/src/BroadcastTo.hpp @@ -1,11 +1,12 @@ // -// Copyright © 2023 Arm Ltd and Contributors. All rights reserved. +// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved. // SPDX-License-Identifier: MIT // #pragma once #include <armnn/utility/IgnoreUnused.hpp> +#include <DelegateUtils.hpp> #include <tensorflow/lite/builtin_ops.h> #include <tensorflow/lite/c/builtin_op_data.h> @@ -83,6 +84,15 @@ namespace armnnDelegate const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor); const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor); + if (ZeroDimPresent({inputTensorInfo, outputTensorInfo})) + { + TF_LITE_MAYBE_KERNEL_LOG( + tfLiteContext, + "TfLiteArmnnDelegate: Zero dimension tensors are not supported in operator #%d node #%d: ", + broadcastToOperatorCode, nodeIndex); + return kTfLiteError; + } + auto* shapeData = tflite::GetTensorData<int32_t>(&tfLiteShapeTensor); auto shapeTensorNum = tfLiteShapeTensor.dims->data[0]; |