aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/Gather.cpp
diff options
context:
space:
mode:
authorFrancisMurtagh <francis.murtagh@arm.com>2019-01-24 10:53:39 +0000
committerFrancis Murtagh <francis.murtagh@arm.com>2019-01-24 12:43:25 +0000
commit94412aff782472be54dce4328e2ecee0225b3e97 (patch)
tree42de97582e1948ecd91bdb24de00a966d180f6e7 /src/armnnTfParser/test/Gather.cpp
parent5f4e41ff596170dad9c073b007b3f53783a9e1f3 (diff)
downloadarmnn-94412aff782472be54dce4328e2ecee0225b3e97.tar.gz
IVGCVSW-2512 Add Gather operator parser to TfParser
* Add ParseGather to TFParser * Add Unit tests for Gather Operator !armnn:562 Change-Id: Idff45c2d3d8d683aa9eb2c4a63123c8d6054609e
Diffstat (limited to 'src/armnnTfParser/test/Gather.cpp')
-rw-r--r--src/armnnTfParser/test/Gather.cpp167
1 files changed, 167 insertions, 0 deletions
diff --git a/src/armnnTfParser/test/Gather.cpp b/src/armnnTfParser/test/Gather.cpp
new file mode 100644
index 0000000000..f40dc57556
--- /dev/null
+++ b/src/armnnTfParser/test/Gather.cpp
@@ -0,0 +1,167 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// SPDX-License-Identifier: MIT
+//
+
+#include "armnnTfParser/ITfParser.hpp"
+
+#include "ParserPrototxtFixture.hpp"
+#include <PrototxtConversions.hpp>
+
+#include <boost/test/unit_test.hpp>
+
+BOOST_AUTO_TEST_SUITE(TensorflowParser)
+
+// helper for setting the dimensions in prototxt
+void dimsHelper(const std::vector<int>& dims, std::string& text){
+ for(u_int i=0; i<dims.size(); ++i){
+ text.append(R"(dim {
+ size: )");
+ text.append(std::to_string(dims[i]));
+ text.append(R"(
+ })");
+ }
+}
+
+// helper for converting from integer to octal representation
+void octalHelper(const std::vector<int>& indicesContent, std::string& text){
+ for (unsigned int i = 0; i < indicesContent.size(); ++i)
+ {
+ text.append(armnnUtils::ConvertInt32ToOctalString(static_cast<int>(indicesContent[i])));
+ }
+}
+
+struct GatherFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
+{
+ GatherFixture(const armnn::TensorShape& inputShape0,
+ const armnn::TensorShape& inputShape1,
+ const std::vector<int>& input1Content,
+ const std::vector<int>& input0Dims,
+ const std::vector<int>& input1Dims)
+ {
+ m_Prototext = R"(
+node {
+ name: "input0"
+ op: "Placeholder"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "shape"
+ value {
+ shape {
+)";
+ dimsHelper(input0Dims, m_Prototext);
+ m_Prototext.append(R"(
+ }
+ }
+ }
+}
+node {
+ name: "input1"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_INT32
+ tensor_shape {
+)");
+ dimsHelper(input1Dims, m_Prototext);
+ m_Prototext.append(R"(
+ }
+ tensor_content: ")");
+ octalHelper(input1Content, m_Prototext);
+ m_Prototext.append(R"("
+ }
+ }
+ }
+}
+node {
+ name: "output"
+ op: "Gather"
+ input: "input0"
+ input: "input1"
+ attr {
+ key: "Tindices"
+ value {
+ type: DT_INT32
+ }
+ }
+ attr {
+ key: "Tparams"
+ value {
+ type: DT_FLOAT
+ }
+ }
+}
+ )");
+ Setup({ { "input0", inputShape0 },
+ { "input1", inputShape1 } },
+ { "output" });
+
+ }
+};
+
+
+struct GatherFixture1DParams1DIndices : public GatherFixture
+{
+ GatherFixture1DParams1DIndices() : GatherFixture(
+ { 4, 1, 1, 1 },
+ { 4, 0, 0, 0 },
+ { 0, 2, 1, 3 },
+ { 4 },
+ { 4 }) {}
+};
+
+struct GatherFixture1DParamsMultiDimIndices : public GatherFixture
+{
+ GatherFixture1DParamsMultiDimIndices() : GatherFixture(
+ { 4, 1, 1 },
+ { 2, 2, 1, 1 },
+ { 0, 1, 1, 3 },
+ { 4 },
+ { 2, 2 }) {}
+};
+
+struct GatherFixtureMultiDimParamMultiDimIndices : public GatherFixture
+{
+ GatherFixtureMultiDimParamMultiDimIndices() : GatherFixture(
+ { 5, 2, 1 },
+ { 2, 1, 4 },
+ { 1, 3, 0, 2 },
+ { 5, 2 },
+ { 2, 2 }) {}
+};
+
+BOOST_FIXTURE_TEST_CASE(ParseGather1DParams1DIndices, GatherFixture1DParams1DIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
+
+ { { "output", { 1, 3, 2, 4 } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseGather1DParamsMultiDimIndices, GatherFixture1DParamsMultiDimIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4 } } },
+
+ { { "output", { 1, 2, 2, 4 } } });
+}
+
+BOOST_FIXTURE_TEST_CASE(ParseGatherMultiDimParamMultiDimIndices, GatherFixtureMultiDimParamMultiDimIndices)
+{
+ RunTest<4>({ { "input0", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 } } },
+
+ { { "output", { 3, 4, 7, 8, 1, 2, 5, 6} } });
+}
+
+BOOST_AUTO_TEST_SUITE_END()