aboutsummaryrefslogtreecommitdiff
path: root/src/armnnUtils/ParserHelper.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnUtils/ParserHelper.cpp')
-rw-r--r--src/armnnUtils/ParserHelper.cpp49
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp
index bf5ffdf0ad..9d633cfc42 100644
--- a/src/armnnUtils/ParserHelper.cpp
+++ b/src/armnnUtils/ParserHelper.cpp
@@ -61,4 +61,53 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori
}
}
+void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo,
+ const std::set<unsigned int>& axisSet, bool keepDims,
+ armnn::TensorInfo& outputTensorInfo)
+{
+ std::vector<unsigned int> outputShapeVector;
+ bool dimensionFound = false;
+ unsigned int size = 1;
+
+ for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i)
+ {
+ dimensionFound = false;
+ for (unsigned int axis: axisSet)
+ {
+ if (axis == i)
+ {
+ dimensionFound = true;
+ break;
+ }
+ }
+
+ if (!dimensionFound)
+ {
+ size *= inputTensorInfo.GetShape()[i];
+
+ if (keepDims)
+ {
+ outputShapeVector.push_back(inputTensorInfo.GetShape()[i]);
+ }
+ }
+ else
+ {
+ if (keepDims)
+ {
+ outputShapeVector.push_back(1);
+ }
+ }
+ }
+
+ if (keepDims)
+ {
+ armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]);
+ outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType());
+ }
+ else
+ {
+ outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType());
+ }
+}
+
} // namespace armnnUtils