From 3504e425ac99467c80919768c4a1361c44b30353 Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Wed, 3 May 2023 13:53:02 +0100 Subject: IVGCVSW-7605 IVGCVSW-7604 Implement Squeeze and ExpandDims operators for Classic and Opaque Delegate * Implemented unsupported operators in Classic Delegate. * Added unit tests. Signed-off-by: Matthew Sloyan Change-Id: Ib39eeea53c114b15943e8dc2e796ce64c40cb3a5 --- delegate/common/src/DelegateUtils.hpp | 52 +++++++++++++++++++++++++++++++++++ 1 file changed, 52 insertions(+) (limited to 'delegate/common') diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index 37fe9b5b84..1671a4c8cf 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -169,4 +169,56 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo, return kTfLiteOk; } +armnn::TensorInfo OutputShapeOfSqueeze(std::vector squeezeDims, + const armnn::TensorInfo& inputTensorInfo) +{ + static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 }; + + if (inputTensorInfo.GetNumDimensions() > 4) + { + std::stringstream ss; + ss << "Input tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + if (squeezeDims.empty()) + { + squeezeDims.assign(dimensionSequence, dimensionSequence + inputTensorInfo.GetNumDimensions()); + } + + std::vector outputDims; + for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end()); + auto currentDimension = inputTensorInfo.GetShape()[i]; + if (skipSqueeze || currentDimension != 1) + { + outputDims.push_back(currentDimension); + } + } + + if (outputDims.size() > 4) + { + std::stringstream ss; + ss << "Output tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + armnn::TensorShape outShape = armnn::TensorShape(static_cast(outputDims.size()), outputDims.data()); + + // We need to preserve the tensor type and the quantization data as well + armnn::TensorInfo outTensorInfo = inputTensorInfo; + outTensorInfo.SetShape(outShape); + + return outTensorInfo; +} + } // namespace anonymous -- cgit v1.2.1