diff options
Diffstat (limited to 'delegate/common/src/DelegateUtils.hpp')
-rw-r--r-- | delegate/common/src/DelegateUtils.hpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index 37fe9b5b84..1671a4c8cf 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -169,4 +169,56 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo, return kTfLiteOk; } +armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, + const armnn::TensorInfo& inputTensorInfo) +{ + static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 }; + + if (inputTensorInfo.GetNumDimensions() > 4) + { + std::stringstream ss; + ss << "Input tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + if (squeezeDims.empty()) + { + squeezeDims.assign(dimensionSequence, dimensionSequence + inputTensorInfo.GetNumDimensions()); + } + + std::vector<uint32_t> outputDims; + for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end()); + auto currentDimension = inputTensorInfo.GetShape()[i]; + if (skipSqueeze || currentDimension != 1) + { + outputDims.push_back(currentDimension); + } + } + + if (outputDims.size() > 4) + { + std::stringstream ss; + ss << "Output tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + armnn::TensorShape outShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); + + // We need to preserve the tensor type and the quantization data as well + armnn::TensorInfo outTensorInfo = inputTensorInfo; + outTensorInfo.SetShape(outShape); + + return outTensorInfo; +} + } // namespace anonymous |