diff options
author | Matthew Sloyan <matthew.sloyan@arm.com> | 2023-05-03 13:53:02 +0100 |
---|---|---|
committer | Matthew Sloyan <matthew.sloyan@arm.com> | 2023-05-04 14:14:09 +0000 |
commit | 3504e425ac99467c80919768c4a1361c44b30353 (patch) | |
tree | 1bef47611d03d61d5ed9083733d0e5654dec3ad8 /delegate/common | |
parent | c833cef6240abb941725a667042b84b936f1e86f (diff) | |
download | armnn-3504e425ac99467c80919768c4a1361c44b30353.tar.gz |
IVGCVSW-7605 IVGCVSW-7604 Implement Squeeze and ExpandDims operators for Classic and Opaque Delegate
* Implemented unsupported operators in Classic Delegate.
* Added unit tests.
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: Ib39eeea53c114b15943e8dc2e796ce64c40cb3a5
Diffstat (limited to 'delegate/common')
-rw-r--r-- | delegate/common/src/DelegateUtils.hpp | 52 |
1 files changed, 52 insertions, 0 deletions
diff --git a/delegate/common/src/DelegateUtils.hpp b/delegate/common/src/DelegateUtils.hpp index 37fe9b5b84..1671a4c8cf 100644 --- a/delegate/common/src/DelegateUtils.hpp +++ b/delegate/common/src/DelegateUtils.hpp @@ -169,4 +169,56 @@ TfLiteStatus CreateOutputTensorShape(const armnn::TensorInfo& inputTensorInfo, return kTfLiteOk; } +armnn::TensorInfo OutputShapeOfSqueeze(std::vector<uint32_t> squeezeDims, + const armnn::TensorInfo& inputTensorInfo) +{ + static const uint32_t dimensionSequence[] = { 0, 1, 2, 3 }; + + if (inputTensorInfo.GetNumDimensions() > 4) + { + std::stringstream ss; + ss << "Input tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + if (squeezeDims.empty()) + { + squeezeDims.assign(dimensionSequence, dimensionSequence + inputTensorInfo.GetNumDimensions()); + } + + std::vector<uint32_t> outputDims; + for(unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); i++) + { + bool skipSqueeze = (std::find(squeezeDims.begin(), squeezeDims.end(), i) == squeezeDims.end()); + auto currentDimension = inputTensorInfo.GetShape()[i]; + if (skipSqueeze || currentDimension != 1) + { + outputDims.push_back(currentDimension); + } + } + + if (outputDims.size() > 4) + { + std::stringstream ss; + ss << "Output tensor has unexpected number of dimensions:" + << inputTensorInfo.GetNumDimensions() + << " shape:" << inputTensorInfo.GetShape() + << " " + << CHECK_LOCATION().AsString(); + throw armnn::ParseException(ss.str()); + } + + armnn::TensorShape outShape = armnn::TensorShape(static_cast<unsigned int>(outputDims.size()), outputDims.data()); + + // We need to preserve the tensor type and the quantization data as well + armnn::TensorInfo outTensorInfo = inputTensorInfo; + outTensorInfo.SetShape(outShape); + + return outTensorInfo; +} + } // namespace anonymous |