From 51dd62f5725e8a97f3f6957fbc2b899493eb7bb3 Mon Sep 17 00:00:00 2001 From: Ferran Balaguer Date: Fri, 11 Jan 2019 19:29:18 +0000 Subject: IVGCVSW-1656 Add Mean support to Tf Parser Change-Id: I3d31d6b72be1984acdb51fd9e7b5488a7aa5d832 --- CMakeLists.txt | 9 +- src/armnnTfParser/TfParser.cpp | 81 +++++++++-- src/armnnTfParser/TfParser.hpp | 2 + src/armnnTfParser/test/Mean.cpp | 175 +++++++++++++++++++++++ src/armnnUtils/ParserHelper.cpp | 49 +++++++ src/armnnUtils/ParserHelper.hpp | 5 + src/armnnUtils/ParserPrototxtFixture.hpp | 88 +++++++++++- src/armnnUtils/test/ParsePrototxtFixtureTest.cpp | 42 ++++++ src/armnnUtils/test/ParserHelperTest.cpp | 98 +++++++++++++ 9 files changed, 526 insertions(+), 23 deletions(-) create mode 100644 src/armnnTfParser/test/Mean.cpp create mode 100644 src/armnnUtils/test/ParsePrototxtFixtureTest.cpp create mode 100644 src/armnnUtils/test/ParserHelperTest.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 407a51de88..616f616474 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -386,7 +386,10 @@ if(BUILD_UNIT_TESTS) src/armnn/test/TensorTest.cpp src/armnn/test/UnitTests.cpp src/armnn/test/UnitTests.hpp - src/armnn/test/UtilsTests.cpp) + src/armnn/test/UtilsTests.cpp + src/armnnUtils/test/ParsePrototxtFixtureTest.cpp + src/armnnUtils/test/ParserHelperTest.cpp + ) if(BUILD_TF_PARSER) list(APPEND unittest_sources @@ -408,6 +411,7 @@ if(BUILD_UNIT_TESTS) src/armnnTfParser/test/LocalResponseNormalization.cpp src/armnnTfParser/test/Maximum.cpp src/armnnTfParser/test/MaximumForLeakyRelu.cpp + src/armnnTfParser/test/Mean.cpp src/armnnTfParser/test/Minimum.cpp src/armnnTfParser/test/Multiplication.cpp src/armnnTfParser/test/Pad.cpp @@ -426,7 +430,8 @@ if(BUILD_UNIT_TESTS) src/armnnTfParser/test/TestMultiInputsOutputs.cpp src/armnnTfParser/test/Split.cpp src/armnnTfParser/test/Squeeze.cpp - src/armnnTfParser/test/Sub.cpp) + src/armnnTfParser/test/Sub.cpp + ) endif() if(BUILD_TF_LITE_PARSER) diff --git a/src/armnnTfParser/TfParser.cpp b/src/armnnTfParser/TfParser.cpp index 90bd992a2b..0087ef83bf 100755 --- a/src/armnnTfParser/TfParser.cpp +++ b/src/armnnTfParser/TfParser.cpp @@ -2,40 +2,27 @@ // Copyright © 2017 Arm Ltd. All rights reserved. // SPDX-License-Identifier: MIT // + #include "TfParser.hpp" -#include -#include #include -#include #include #include #include #include -#include #include #include #include #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/framework/tensor_shape.pb.h" -#include #include #include -#include -#include #include -#include -#include #include -#include using namespace armnnUtils; using namespace armnn; @@ -141,6 +128,17 @@ int32_t ReadMandatoryNodeInt32Attribute(const tensorflow::NodeDef& nodeDef, cons return attribValue; } +bool ReadMandatoryNodeBoolAttribute(const tensorflow::NodeDef& nodeDef, const std::string& name) +{ + bool attribValue = false; + ReadMandatoryNodeAttributeImpl(nodeDef, name, tensorflow::AttrValue::kB, + [&attribValue](const tensorflow::AttrValue& attrValue) + { + attribValue = static_cast(attrValue.b()); + }); + return attribValue; +} + uint32_t ReadMandatoryNodeUint32Attribute(const tensorflow::NodeDef& nodeDef, const std::string& name) { uint32_t attribValue = 0u; @@ -338,6 +336,7 @@ const std::map TfParser::ms_Ope { "ConcatV2", &TfParser::ParseConcat }, { "LRN", &TfParser::ParseLrn }, { "MatMul", &TfParser::ParseMatMul }, + { "Mean", &TfParser::ParseMean }, { "Mul", &TfParser::ParseMul }, { "Placeholder", &TfParser::ParsePlaceholder }, { "RealDiv", &TfParser::ParseRealDiv }, @@ -2349,6 +2348,60 @@ ParsedTfOperationPtr TfParser::ParseMatMul(const tensorflow::NodeDef& nodeDef, c return std::make_unique(this, nodeDef); } +ParsedTfOperationPtr TfParser::ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef) +{ + std::vector inputs = GetInputParsedTfOperationsChecked(nodeDef, 2); + IOutputSlot& inputSlot = inputs[0].m_IndexedValue->ResolveArmnnOutputSlot(inputs[0].m_Index); + TensorInfo inputTensorInfo = inputSlot.GetTensorInfo(); + + if (inputs.size() != 2) + { + throw ParseException( + boost::str(boost::format("Mean expects two inputs!. Got %1% for Node %2% %3%") + % inputs.size() + % nodeDef.name() + % CHECK_LOCATION().AsString())); + } + + bool keepDims = ReadMandatoryNodeBoolAttribute(nodeDef, "keep_dims"); + + ParsedConstTfOperation* axisNode = + boost::polymorphic_downcast*>(inputs[1].m_IndexedValue); + + const TensorInfo& axisTensorInfo = axisNode->GetTensorInfo(); + + ConstTensor axisTensor(axisTensorInfo, axisNode->GetStorage()); + const int* axisData = static_cast(axisTensor.GetMemoryArea()); + + TensorInfo outputTensorInfo; + MeanDescriptor meanDescriptor; + meanDescriptor.m_KeepDims = keepDims; + + // Negative axis values are supported so that the process requires + // to convert them into the corresponding positive ones. + // Duplicate values are also removed. + std::vector rawAxisVector(axisData, axisData + axisTensorInfo.GetNumElements()); + std::set positiveAxisSet; + int rank = static_cast(inputTensorInfo.GetNumDimensions()); + + std::transform(rawAxisVector.begin(), rawAxisVector.end(), + std::inserter(positiveAxisSet, positiveAxisSet.begin()), + [rank](int i) -> unsigned int { return static_cast((i + rank) % rank); }); + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo, positiveAxisSet, keepDims, outputTensorInfo); + + if (inputTensorInfo.GetNumDimensions() > positiveAxisSet.size()) + { + meanDescriptor.m_Axis.assign(positiveAxisSet.begin(), positiveAxisSet.end()); + } + + IConnectableLayer* layer = m_Network->AddMeanLayer(meanDescriptor, nodeDef.name().c_str()); + layer->GetOutputSlot(0).SetTensorInfo(outputTensorInfo); + inputSlot.Connect(layer->GetInputSlot(0)); + + return std::make_unique(this, nodeDef, layer); +} + /// An ParsedTfOperation for a Mul node. /// Creation of the armnn Mul layer is deferred until it is actually needed, because Mul nodes /// are also used for the first part of a leaky relu activation function (Mul followed by Maximum) diff --git a/src/armnnTfParser/TfParser.hpp b/src/armnnTfParser/TfParser.hpp index 4421768fc5..f1b7205ff1 100644 --- a/src/armnnTfParser/TfParser.hpp +++ b/src/armnnTfParser/TfParser.hpp @@ -140,6 +140,7 @@ private: ParsedTfOperationPtr ParseIdentity(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseLrn(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseMatMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); + ParsedTfOperationPtr ParseMean(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseMul(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParsePlaceholder(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); ParsedTfOperationPtr ParseRealDiv(const tensorflow::NodeDef& nodeDef, const tensorflow::GraphDef& graphDef); @@ -260,4 +261,5 @@ private: /// Maps output layer names to their corresponding ids and tensor info. std::unordered_map m_NetworkOutputsBindingInfo; }; + } diff --git a/src/armnnTfParser/test/Mean.cpp b/src/armnnTfParser/test/Mean.cpp new file mode 100644 index 0000000000..13041629b5 --- /dev/null +++ b/src/armnnTfParser/test/Mean.cpp @@ -0,0 +1,175 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct MeanFixture : public armnnUtils::ParserPrototxtFixture +{ + explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape, + const std::vector& axis, bool keepDims) + { + std::string protobufAxisString; + std::vector protobufAxis(axis); + + // If no axis range is specified, the reduction is applied to + // all dimensions of the input tensor + if (protobufAxis.size() == 0) + { + for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i) + { + protobufAxis.push_back(i); + } + } + + for (unsigned int i = 0; i < protobufAxis.size(); ++i) + { + protobufAxisString.append(ConvertInt32ToOctalString(static_cast(protobufAxis[i]))); + } + + m_Prototext = R"(node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { )"; + + if (axis.size() == 1) + { + m_Prototext.append(R"( tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )").append(std::to_string(protobufAxis[0])).append(R"( + } )"); + } + else + { + m_Prototext.append(R"( tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: ")").append(protobufAxisString).append(R"(" + } )"); + } + + m_Prototext.append(R"( } + } + } + node { + name: "output" + op: "Mean" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: )").append(keepDims ? "true" : "false").append(R"( + } + } + })"); + + SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output"); + } +}; + +struct MeanNoAxisNoKeepDimsFixture: MeanFixture +{ + MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {} +}; + +struct MeanWithAxis0NoKeepDimsFixture: MeanFixture +{ + MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {} +}; + +struct MeanWithAxis1NoKeepDimsFixture: MeanFixture +{ + MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {} +}; + +struct MeanWithAxis0KeepDimsFixture: MeanFixture +{ + MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {} +}; + +struct MeanWithAxis1KeepDimsFixture: MeanFixture +{ + MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {} +}; + + +BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f, 1.5f, 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.f, 2.f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture) +{ + RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f, 1.5f, 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture) +{ + RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.f, 2.f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnUtils/ParserHelper.cpp b/src/armnnUtils/ParserHelper.cpp index bf5ffdf0ad..9d633cfc42 100644 --- a/src/armnnUtils/ParserHelper.cpp +++ b/src/armnnUtils/ParserHelper.cpp @@ -61,4 +61,53 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori } } +void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo, + const std::set& axisSet, bool keepDims, + armnn::TensorInfo& outputTensorInfo) +{ + std::vector outputShapeVector; + bool dimensionFound = false; + unsigned int size = 1; + + for (unsigned int i = 0; i < inputTensorInfo.GetNumDimensions(); ++i) + { + dimensionFound = false; + for (unsigned int axis: axisSet) + { + if (axis == i) + { + dimensionFound = true; + break; + } + } + + if (!dimensionFound) + { + size *= inputTensorInfo.GetShape()[i]; + + if (keepDims) + { + outputShapeVector.push_back(inputTensorInfo.GetShape()[i]); + } + } + else + { + if (keepDims) + { + outputShapeVector.push_back(1); + } + } + } + + if (keepDims) + { + armnn::TensorShape outputTensorShape(inputTensorInfo.GetNumDimensions(), &outputShapeVector[0]); + outputTensorInfo = armnn::TensorInfo(outputTensorShape, inputTensorInfo.GetDataType()); + } + else + { + outputTensorInfo = armnn::TensorInfo({size}, inputTensorInfo.GetDataType()); + } +} + } // namespace armnnUtils diff --git a/src/armnnUtils/ParserHelper.hpp b/src/armnnUtils/ParserHelper.hpp index 93dfbf9360..24369dc521 100644 --- a/src/armnnUtils/ParserHelper.hpp +++ b/src/armnnUtils/ParserHelper.hpp @@ -14,4 +14,9 @@ void ProcessConcatInputTensorInfo(armnn::TensorInfo& inputTensorInfo, armnn::Ori const unsigned int& concatAxis, unsigned int inputIndex, std::vector& mergeDimSizes, unsigned int& mergeDim); +/// Creates a tensor info after reducing the dimensions mentioned in axisData. +void CalculateReducedOutputTensoInfo(const armnn::TensorInfo& inputTensorInfo, const armnn::TensorInfo& axisTensorInfo, + const std::set& axisSet, bool keepDims, + armnn::TensorInfo& outputTensorInfo); + } // namespace armnnUtils diff --git a/src/armnnUtils/ParserPrototxtFixture.hpp b/src/armnnUtils/ParserPrototxtFixture.hpp index acb8f82c4d..154f6bec2a 100644 --- a/src/armnnUtils/ParserPrototxtFixture.hpp +++ b/src/armnnUtils/ParserPrototxtFixture.hpp @@ -16,6 +16,7 @@ #include +#include #include namespace armnnUtils @@ -37,6 +38,10 @@ struct ParserPrototxtFixture void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, const std::string& inputName, const std::string& outputName); + void SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, + const armnn::TensorShape& outputTensorShape, + const std::string& inputName, + const std::string& outputName); void Setup(const std::map& inputShapes, const std::vector& requestedOutputs); void Setup(); @@ -56,6 +61,9 @@ struct ParserPrototxtFixture void RunTest(const std::map>& inputData, const std::map>& expectedOutputData); + /// Converts an int value into the Protobuf octal representation + std::string ConvertInt32ToOctalString(int value); + std::string m_Prototext; std::unique_ptr m_Parser; armnn::IRuntimePtr m_Runtime; @@ -67,6 +75,10 @@ struct ParserPrototxtFixture std::string m_SingleInputName; std::string m_SingleOutputName; /// @} + + /// This will store the output shape so it don't need to be passed to the single-input-single-output overload + /// of RunTest(). + armnn::TensorShape m_SingleOutputShape; }; template @@ -90,6 +102,20 @@ void ParserPrototxtFixture::SetupSingleInputSingleOutput(const armnn::T Setup({ { inputName, inputTensorShape } }, { outputName }); } +template +void ParserPrototxtFixture::SetupSingleInputSingleOutput(const armnn::TensorShape& inputTensorShape, + const armnn::TensorShape& outputTensorShape, + const std::string& inputName, + const std::string& outputName) +{ + // Stores the input name, the output name and the output tensor shape + // so they don't need to be passed to the single-input-single-output RunTest(). + m_SingleInputName = inputName; + m_SingleOutputName = outputName; + m_SingleOutputShape = outputTensorShape; + Setup({ { inputName, inputTensorShape } }, { outputName }); +} + template void ParserPrototxtFixture::Setup(const std::map& inputShapes, const std::vector& requestedOutputs) @@ -181,17 +207,65 @@ void ParserPrototxtFixture::RunTest(const std::map(bindingInfo.second, it.second); BOOST_TEST(CompareTensors(outputExpected, outputStorage[it.first])); } } +template +std::string ParserPrototxtFixture::ConvertInt32ToOctalString(int value) +{ + std::stringstream ss; + std::string returnString; + for (int i = 0; i < 4; ++i) + { + ss << "\\"; + ss << std::setw(3) << std::setfill('0') << std::oct << ((value >> (i * 8)) & 0xFF); + } + + ss >> returnString; + return returnString; +} + } // namespace armnnUtils diff --git a/src/armnnUtils/test/ParsePrototxtFixtureTest.cpp b/src/armnnUtils/test/ParsePrototxtFixtureTest.cpp new file mode 100644 index 0000000000..926658ed0c --- /dev/null +++ b/src/armnnUtils/test/ParsePrototxtFixtureTest.cpp @@ -0,0 +1,42 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include + +#include +#include "armnnTfParser/ITfParser.hpp" + +BOOST_AUTO_TEST_SUITE(ParsePrototxtFixtureSuite) + +using Fixture = armnnUtils::ParserPrototxtFixture; + +BOOST_FIXTURE_TEST_CASE(ConvertInt32ToOctalStringTest, Fixture) +{ + std::string octalString = ConvertInt32ToOctalString(1); + BOOST_ASSERT(octalString.compare("\\\\001\\\\000\\\\000\\\\000")); + + octalString = ConvertInt32ToOctalString(256); + BOOST_ASSERT(octalString.compare("\\\\000\\\\100\\\\000\\\\000")); + + octalString = ConvertInt32ToOctalString(65536); + BOOST_ASSERT(octalString.compare("\\\\000\\\\000\\\\100\\\\000")); + + octalString = ConvertInt32ToOctalString(16777216); + BOOST_ASSERT(octalString.compare("\\\\000\\\\000\\\\000\\\\100")); + + octalString = ConvertInt32ToOctalString(-1); + BOOST_ASSERT(octalString.compare("\\\\377\\\\377\\\\377\\\\377")); + + octalString = ConvertInt32ToOctalString(-256); + BOOST_ASSERT(octalString.compare("\\\\000\\\\377\\\\377\\\\377")); + + octalString = ConvertInt32ToOctalString(-65536); + BOOST_ASSERT(octalString.compare("\\\\000\\\\000\\\\377\\\\377")); + + octalString = ConvertInt32ToOctalString(-16777216); + BOOST_ASSERT(octalString.compare("\\\\000\\\\000\\\\000\\\\377")); +} + +BOOST_AUTO_TEST_SUITE_END() diff --git a/src/armnnUtils/test/ParserHelperTest.cpp b/src/armnnUtils/test/ParserHelperTest.cpp new file mode 100644 index 0000000000..122ad7649e --- /dev/null +++ b/src/armnnUtils/test/ParserHelperTest.cpp @@ -0,0 +1,98 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include "../ParserHelper.hpp" + +#include + +#include "armnn/Types.hpp" + +using namespace armnn; +using namespace armnnUtils; + +BOOST_AUTO_TEST_SUITE(ParserHelperSuite) + +BOOST_AUTO_TEST_CASE(CalculateReducedOutputTensoInfoTest) +{ + bool keepDims = false; + + unsigned int inputShape[] = { 2, 3, 4 }; + TensorInfo inputTensorInfo(3, &inputShape[0], DataType::Float32); + + // Reducing all dimensions results in one single output value (one dimension) + unsigned int axisShape1[] = { 3 }; + std::set axisData1 = { 0, 1, 2 }; + TensorInfo axisTensorInfo1(1, &axisShape1[0], DataType::Signed32); + + TensorInfo outputTensorInfo1; + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo1, axisData1, + keepDims, outputTensorInfo1); + + BOOST_ASSERT(outputTensorInfo1.GetNumDimensions() == 1); + BOOST_ASSERT(outputTensorInfo1.GetShape()[0] == 1); + + // Reducing dimension 0 results in a 3x4 size tensor (one dimension) + unsigned int axisShape2[] = { 1 }; + std::set axisData2 = { 0 }; + TensorInfo axisTensorInfo2(1, &axisShape2[0], DataType::Signed32); + + TensorInfo outputTensorInfo2; + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo2, axisData2, + keepDims, outputTensorInfo2); + + BOOST_ASSERT(outputTensorInfo2.GetNumDimensions() == 1); + BOOST_ASSERT(outputTensorInfo2.GetShape()[0] == 12); + + // Reducing dimensions 0,1 results in a 4 size tensor (one dimension) + unsigned int axisShape3[] = { 2 }; + std::set axisData3 = { 0, 1 }; + TensorInfo axisTensorInfo3(1, &axisShape3[0], DataType::Signed32); + + TensorInfo outputTensorInfo3; + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo3, axisData3, + keepDims, outputTensorInfo3); + + BOOST_ASSERT(outputTensorInfo3.GetNumDimensions() == 1); + BOOST_ASSERT(outputTensorInfo3.GetShape()[0] == 4); + + // Reducing dimension 0 results in a { 1, 3, 4 } dimension tensor + keepDims = true; + unsigned int axisShape4[] = { 1 }; + std::set axisData4 = { 0 }; + TensorInfo axisTensorInfo4(1, &axisShape4[0], DataType::Signed32); + + TensorInfo outputTensorInfo4; + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo4, axisData4, + keepDims, outputTensorInfo4); + + BOOST_ASSERT(outputTensorInfo4.GetNumDimensions() == 3); + BOOST_ASSERT(outputTensorInfo4.GetShape()[0] == 1); + BOOST_ASSERT(outputTensorInfo4.GetShape()[1] == 3); + BOOST_ASSERT(outputTensorInfo4.GetShape()[2] == 4); + + // Reducing dimension 1, 2 results in a { 2, 1, 1 } dimension tensor + keepDims = true; + unsigned int axisShape5[] = { 2 }; + std::set axisData5 = { 1, 2 }; + TensorInfo axisTensorInfo5(1, &axisShape5[0], DataType::Signed32); + + TensorInfo outputTensorInfo5; + + CalculateReducedOutputTensoInfo(inputTensorInfo, axisTensorInfo5, axisData5, + keepDims, outputTensorInfo5); + + BOOST_ASSERT(outputTensorInfo5.GetNumDimensions() == 3); + BOOST_ASSERT(outputTensorInfo5.GetShape()[0] == 2); + BOOST_ASSERT(outputTensorInfo5.GetShape()[1] == 1); + BOOST_ASSERT(outputTensorInfo5.GetShape()[2] == 1); + +} + +BOOST_AUTO_TEST_SUITE_END() + -- cgit v1.2.1