diff options
Diffstat (limited to 'src/armnnTfParser/test')
-rw-r--r-- | src/armnnTfParser/test/Mean.cpp | 175 |
1 files changed, 175 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Mean.cpp b/src/armnnTfParser/test/Mean.cpp new file mode 100644 index 0000000000..13041629b5 --- /dev/null +++ b/src/armnnTfParser/test/Mean.cpp @@ -0,0 +1,175 @@ +// +// Copyright © 2017 Arm Ltd. All rights reserved. +// SPDX-License-Identifier: MIT +// + +#include <boost/test/unit_test.hpp> +#include "armnnTfParser/ITfParser.hpp" +#include "ParserPrototxtFixture.hpp" + +BOOST_AUTO_TEST_SUITE(TensorflowParser) + +struct MeanFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser> +{ + explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape, + const std::vector<unsigned int>& axis, bool keepDims) + { + std::string protobufAxisString; + std::vector<unsigned int> protobufAxis(axis); + + // If no axis range is specified, the reduction is applied to + // all dimensions of the input tensor + if (protobufAxis.size() == 0) + { + for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i) + { + protobufAxis.push_back(i); + } + } + + for (unsigned int i = 0; i < protobufAxis.size(); ++i) + { + protobufAxisString.append(ConvertInt32ToOctalString(static_cast<int>(protobufAxis[i]))); + } + + m_Prototext = R"(node { + name: "input" + op: "Placeholder" + attr { + key: "dtype" + value { + type: DT_FLOAT + } + } + attr { + key: "shape" + value { + shape { + } + } + } + } + node { + name: "Const" + op: "Const" + attr { + key: "dtype" + value { + type: DT_INT32 + } + } + attr { + key: "value" + value { )"; + + if (axis.size() == 1) + { + m_Prototext.append(R"( tensor { + dtype: DT_INT32 + tensor_shape { + } + int_val: )").append(std::to_string(protobufAxis[0])).append(R"( + } )"); + } + else + { + m_Prototext.append(R"( tensor { + dtype: DT_INT32 + tensor_shape { + dim { + size: 2 + } + } + tensor_content: ")").append(protobufAxisString).append(R"(" + } )"); + } + + m_Prototext.append(R"( } + } + } + node { + name: "output" + op: "Mean" + input: "input" + input: "Const" + attr { + key: "T" + value { + type: DT_FLOAT + } + } + attr { + key: "Tidx" + value { + type: DT_INT32 + } + } + attr { + key: "keep_dims" + value { + b: )").append(keepDims ? "true" : "false").append(R"( + } + } + })"); + + SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output"); + } +}; + +struct MeanNoAxisNoKeepDimsFixture: MeanFixture +{ + MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {} +}; + +struct MeanWithAxis0NoKeepDimsFixture: MeanFixture +{ + MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {} +}; + +struct MeanWithAxis1NoKeepDimsFixture: MeanFixture +{ + MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {} +}; + +struct MeanWithAxis0KeepDimsFixture: MeanFixture +{ + MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {} +}; + +struct MeanWithAxis1KeepDimsFixture: MeanFixture +{ + MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {} +}; + + +BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f, 1.5f, 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture) +{ + RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.f, 2.f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture) +{ + RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.5f, 1.5f, 1.5f } } }); +} + +BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture) +{ + RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } }, + { { "output", { 1.f, 2.f } } }); +} + +BOOST_AUTO_TEST_SUITE_END() |