aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/Mean.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'src/armnnTfParser/test/Mean.cpp')
-rw-r--r--src/armnnTfParser/test/Mean.cpp175
1 files changed, 175 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Mean.cpp b/src/armnnTfParser/test/Mean.cpp
new file mode 100644
index 0000000000..13041629b5
--- /dev/null
+++ b/src/armnnTfParser/test/Mean.cpp
@@ -0,0 +1,175 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnTfParser/ITfParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+struct MeanFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ explicit MeanFixture(const armnn::TensorShape& inputShape, const armnn::TensorShape& outputShape,
+ const std::vector<unsigned int>& axis, bool keepDims)
+ {
+ std::string protobufAxisString;
+ std::vector<unsigned int> protobufAxis(axis);
+
+ // If no axis range is specified, the reduction is applied to
+ // all dimensions of the input tensor
+ if (protobufAxis.size() == 0)
+ {
+ for (unsigned int i = 0; i < inputShape.GetNumDimensions(); ++i)
+ {
+ protobufAxis.push_back(i);
+ }
+ }
+
+ for (unsigned int i = 0; i < protobufAxis.size(); ++i)
+ {
+ protobufAxisString.append(ConvertInt32ToOctalString(static_cast<int>(protobufAxis[i])));
+ }
+
+ m_Prototext = R"(node {
+ name: "input"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+ }
+ }
+ }
+ }
+ node {
+ name: "Const"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value { )";
+
+ if (axis.size() == 1)
+ {
+ m_Prototext.append(R"( tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ }
+ int_val: )").append(std::to_string(protobufAxis[0])).append(R"(
+ } )");
+ }
+ else
+ {
+ m_Prototext.append(R"( tensor {
+ dtype: DT_INT32
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ }
+ tensor_content: ")").append(protobufAxisString).append(R"("
+ } )");
+ }
+
+ m_Prototext.append(R"( }
+ }
+ }
+ node {
+ name: "output"
+ op: "Mean"
+ input: "input"
+ input: "Const"
+ attr {
+ key: "T"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "Tidx"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "keep_dims"
+ value {
+ b: )").append(keepDims ? "true" : "false").append(R"(
+ }
+ }
+ })");
+
+ SetupSingleInputSingleOutput(inputShape, outputShape, "input", "output");
+ }
+};
+
+struct MeanNoAxisNoKeepDimsFixture: MeanFixture
+{
+ MeanNoAxisNoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 1 }, {}, false) {}
+};
+
+struct MeanWithAxis0NoKeepDimsFixture: MeanFixture
+{
+ MeanWithAxis0NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 3 }, { 0 }, false) {}
+};
+
+struct MeanWithAxis1NoKeepDimsFixture: MeanFixture
+{
+ MeanWithAxis1NoKeepDimsFixture() : MeanFixture({ 2, 3 }, { 2 }, { 1 }, false) {}
+};
+
+struct MeanWithAxis0KeepDimsFixture: MeanFixture
+{
+ MeanWithAxis0KeepDimsFixture() : MeanFixture({ 2, 3 }, { 1, 3 }, { 0 }, true) {}
+};
+
+struct MeanWithAxis1KeepDimsFixture: MeanFixture
+{
+ MeanWithAxis1KeepDimsFixture() : MeanFixture({ 2, 3 }, { 2, 1 }, { 1 }, true) {}
+};
+
+
+BOOST_FIXTURE_TEST_CASE(MeanNoAxisNoKeepDims, MeanNoAxisNoKeepDimsFixture)
+{
+ RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.5f } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(MeanWithAxis0NoKeepDims, MeanWithAxis0NoKeepDimsFixture)
+{
+ RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.5f, 1.5f, 1.5f } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(MeanWithAxis1NoKeepDims, MeanWithAxis1NoKeepDimsFixture)
+{
+ RunTest<1>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.f, 2.f } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(MeanWithAxis0KeepDims, MeanWithAxis0KeepDimsFixture)
+{
+ RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.5f, 1.5f, 1.5f } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(MeanWithAxis1KeepDims, MeanWithAxis1KeepDimsFixture)
+{
+ RunTest<2>({ { "input", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f } } },
+ { { "output", { 1.f, 2.f } } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()