diff options
author | Ferran Balaguer <ferran.balaguer@arm.com> | 2019-01-11 19:29:18 +0000 |
---|---|---|
committer | Ferran Balaguer Arm <ferran.balaguer@arm.com> | 2019-01-14 09:45:15 +0000 |
commit | 51dd62f5725e8a97f3f6957fbc2b899493eb7bb3 (patch) | |
tree | f8cce612850d49d798686cce5ad2ab7545b6e0b7 /src/armnnUtils/ParserHelper.cpp | |
parent | 992d6dc57d8463729910b688f0fb5825d0d3ccf2 (diff) | |
download | armnn-51dd62f5725e8a97f3f6957fbc2b899493eb7bb3.tar.gz |
IVGCVSW-1656 Add Mean support to Tf Parser
Change-Id: I3d31d6b72be1984acdb51fd9e7b5488a7aa5d832
Diffstat (limited to 'src/armnnUtils/ParserHelper.cpp')
-rw-r--r-- | src/armnnUtils/ParserHelper.cpp | 49 |
1 files changed, 49 insertions, 0 deletions
diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index bf5ffdf0ad..9d633cfc42 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -61,4 +61,53 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori } } +void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo, + const std::set<unsigned int>& axisSet, bool keepDims, + armnn::TensorInfo& outputTensorInfo) +{ + std::vector<unsigned int> outputShapeVector; + bool dimensionFound = false; + unsigned int size = 1; + + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i) + { + dimensionFound = false; + for (unsigned int axis: axisSet) + { + if (axis == i) + { + dimensionFound = true; + break; + } + } + + if (!dimensionFound) + { + size *= inputTensorInfo.GetShape()[i]; + + if (keepDims) + { + outputShapeVector.push_back(inputTensorInfo.GetShape()[i]); + } + } + else + { + if (keepDims) + { + outputShapeVector.push_back(1); + } + } + } + + if (keepDims) + { + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType()); + } + else + { + outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType()); + } +} + } // namespace armnnUtils |