aboutsummaryrefslogtreecommitdiff
path: root/delegate/common
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2023-05-03 13:53:02 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2023-05-04 14:14:09 +0000
commit3504e425ac99467c80919768c4a1361c44b30353 (patch)
tree1bef47611d03d61d5ed9083733d0e5654dec3ad8 /delegate/common
parentc833cef6240abb941725a667042b84b936f1e86f (diff)
downloadarmnn-3504e425ac99467c80919768c4a1361c44b30353.tar.gz
IVGCVSW-7605 IVGCVSW-7604 Implement Squeeze and ExpandDims operators for Classic and Opaque Delegate
* Implemented unsupported operators in Classic Delegate. * Added unit tests. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: Ib39eeea53c114b15943e8dc2e796ce64c40cb3a5
Diffstat (limited to 'delegate/common')
-rw-r--r--delegate/common/src/DelegateUtils.hpp52
1 files changed, 52 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index 37fe9b5b84..1671a4c8cf 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -169,4 +169,56 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
return kTfLiteOk;
}
+armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
+ const armnn::TensorInfo& inputTensorInfo)
+{
+ static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
+
+ if (inputTensorInfo.GetNumDimensions() > 4)
+ {
+ std::stringstream ss;
+ ss << "Input tensor has unexpected number of dimensions:"
+ << inputTensorInfo.GetNumDimensions()
+ << " shape:" << inputTensorInfo.GetShape()
+ << " "
+ << CHECK_LOCATION().AsString();
+ throw armnn::ParseException(ss.str());
+ }
+
+ if (squeezeDims.empty())
+ {
+ squeezeDims.assign(dimensionSequence, dimensionSequence + inputTensorInfo.GetNumDimensions());
+ }
+
+ std::vector<uint32_t> outputDims;
+ for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end());
+ auto currentDimension = inputTensorInfo.GetShape()[i];
+ if (skipSqueeze || currentDimension != 1)
+ {
+ outputDims.push_back(currentDimension);
+ }
+ }
+
+ if (outputDims.size() > 4)
+ {
+ std::stringstream ss;
+ ss << "Output tensor has unexpected number of dimensions:"
+ << inputTensorInfo.GetNumDimensions()
+ << " shape:" << inputTensorInfo.GetShape()
+ << " "
+ << CHECK_LOCATION().AsString();
+ throw armnn::ParseException(ss.str());
+ }
+
+ armnn::TensorShape outShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+ // We need to preserve the tensor type and the quantization data as well
+ armnn::TensorInfo outTensorInfo = inputTensorInfo;
+ outTensorInfo.SetShape(outShape);
+
+ return outTensorInfo;
+}
+
} // namespace anonymous