aboutsummaryrefslogtreecommitdiff
path: root/delegate/classic/src
diff options
context:
space:
mode:
authorPatryk Kaiser <patryk.kaiser@arm.com>2024-06-27 12:29:48 +0100
committerColm Donelan <colm.donelan@arm.com>2024-07-01 16:55:55 +0000
commit656285153399d96ead5925db907d0ec1961dfd76 (patch)
tree74ae6dbd3199c0297c0321d5998cfd546537af26 /delegate/classic/src
parent443804adee542d4330713e8dda6357b9495856fa (diff)
downloadarmnn-656285153399d96ead5925db907d0ec1961dfd76.tar.gz
IVGCVSW-8139 Fixing Broadcast OP DTS tests
* In ref broadcast layer added broadcast support to expand tensor shapes * Added function to check for zero dimension tensors * Added check for unsupported zero dimension tensors during broadcast * Added DelegateUtils unit test file with unit tests for the new function Signed-off-by: Patryk Kaiser <patryk.kaiser@arm.com> Change-Id: If4e786f7ba580399e781c48335888e8da8458019
Diffstat (limited to 'delegate/classic/src')
-rw-r--r--delegate/classic/src/BroadcastTo.hpp12
1 files changed, 11 insertions, 1 deletions
diff --git a/delegate/classic/src/BroadcastTo.hpp b/delegate/classic/src/BroadcastTo.hpp
index 92aed79982..2e2b3ab155 100644
--- a/delegate/classic/src/BroadcastTo.hpp
+++ b/delegate/classic/src/BroadcastTo.hpp
@@ -1,11 +1,12 @@
//
-// Copyright © 2023 Arm Ltd and Contributors. All rights reserved.
+// Copyright © 2023-2024 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#pragma once
#include <armnn/utility/IgnoreUnused.hpp>
+#include <DelegateUtils.hpp>
#include <tensorflow/lite/builtin_ops.h>
#include <tensorflow/lite/c/builtin_op_data.h>
@@ -83,6 +84,15 @@ namespace armnnDelegate
const armnn::TensorInfo& inputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteInputTensor);
const armnn::TensorInfo& outputTensorInfo = GetTensorInfoForTfLiteTensor(tfLiteOutputTensor);
+ if (ZeroDimPresent({inputTensorInfo, outputTensorInfo}))
+ {
+ TF_LITE_MAYBE_KERNEL_LOG(
+ tfLiteContext,
+ "TfLiteArmnnDelegate: Zero dimension tensors are not supported in operator #%d node #%d: ",
+ broadcastToOperatorCode, nodeIndex);
+ return kTfLiteError;
+ }
+
auto* shapeData = tflite::GetTensorData<int32_t>(&tfLiteShapeTensor);
auto shapeTensorNum = tfLiteShapeTensor.dims->data[0];