aboutsummaryrefslogtreecommitdiff
path: root/delegate/common/src/DelegateUtils.hpp
diff options
context:
space:
mode:
Diffstat (limited to 'delegate/common/src/DelegateUtils.hpp')
-rw-r--r--delegate/common/src/DelegateUtils.hpp52
1 files changed, 52 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp
index 37fe9b5b84..1671a4c8cf 100644
--- a/delegate/common/src/DelegateUtils.hpp
+++ b/delegate/common/src/DelegateUtils.hpp
@@ -169,4 +169,56 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo,
return kTfLiteOk;
}
+armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims,
+ const armnn::TensorInfo& inputTensorInfo)
+{
+ static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 };
+
+ if (inputTensorInfo.GetNumDimensions() > 4)
+ {
+ std::stringstream ss;
+ ss << "Input tensor has unexpected number of dimensions:"
+ << inputTensorInfo.GetNumDimensions()
+ << " shape:" << inputTensorInfo.GetShape()
+ << " "
+ << CHECK_LOCATION().AsString();
+ throw armnn::ParseException(ss.str());
+ }
+
+ if (squeezeDims.empty())
+ {
+ squeezeDims.assign(dimensionSequence, dimensionSequence + inputTensorInfo.GetNumDimensions());
+ }
+
+ std::vector<uint32_t> outputDims;
+ for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++)
+ {
+ bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end());
+ auto currentDimension = inputTensorInfo.GetShape()[i];
+ if (skipSqueeze || currentDimension != 1)
+ {
+ outputDims.push_back(currentDimension);
+ }
+ }
+
+ if (outputDims.size() > 4)
+ {
+ std::stringstream ss;
+ ss << "Output tensor has unexpected number of dimensions:"
+ << inputTensorInfo.GetNumDimensions()
+ << " shape:" << inputTensorInfo.GetShape()
+ << " "
+ << CHECK_LOCATION().AsString();
+ throw armnn::ParseException(ss.str());
+ }
+
+ armnn::TensorShape outShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data());
+
+ // We need to preserve the tensor type and the quantization data as well
+ armnn::TensorInfo outTensorInfo = inputTensorInfo;
+ outTensorInfo.SetShape(outShape);
+
+ return outTensorInfo;
+}
+
} // namespace anonymous