aboutsummaryrefslogtreecommitdiff
path: root/src/armnnTfParser/test/RealDiv.cpp
blob: 952590e0018a34786d2b96c631af815a6d3f8e4e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include <boost/test/unit_test.hpp>
#include "armnnTfParser/ITfParser.hpp"
#include "ParserPrototxtFixture.hpp"

#include <limits>

// Every test case below is registered in the TensorflowParser Boost.Test suite.
BOOST_AUTO_TEST_SUITE(TensorflowParser)

// Single-input graph: graphInput feeds two identical Softmax nodes (softmax1, softmax2),
// and "division" is RealDiv(softmax1, softmax2). Because both operands are computed from
// the same input by the same op, the division divides a tensor by itself.
struct DivisionFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
    DivisionFixture()
    {
        // The graph is written as a raw string literal, like the other fixture in this file,
        // so the prototxt reads naturally without escaped quotes or "\n" terminators.
        // Text-format parsing is whitespace-insensitive, so the layout here is cosmetic.
        m_Prototext = R"(
                 node {
                   name: "graphInput"
                   op: "Placeholder"
                   attr {
                     key: "dtype"
                     value {
                       type: DT_FLOAT
                     }
                   }
                   attr {
                     key: "shape"
                     value {
                       shape {
                       }
                     }
                   }
                 }
                 node {
                   name: "softmax1"
                   op: "Softmax"
                   input: "graphInput"
                   attr {
                     key: "T"
                     value {
                       type: DT_FLOAT
                     }
                   }
                 }
                 node {
                   name: "softmax2"
                   op: "Softmax"
                   input: "graphInput"
                   attr {
                     key: "T"
                     value {
                       type: DT_FLOAT
                     }
                   }
                 }
                 node {
                   name: "division"
                   op: "RealDiv"
                   input: "softmax1"
                   input: "softmax2"
                   attr {
                     key: "T"
                     value {
                       type: DT_FLOAT
                     }
                   }
                 }
                 )";

        // The placeholder's shape is left empty above; the concrete { 4, 1 } input shape is
        // supplied here, with "graphInput" as the input and "division" as the output.
        SetupSingleInputSingleOutput({ 4, 1 }, "graphInput", "division");
    }
};

BOOST_FIXTURE_TEST_CASE(ParseDivision, DivisionFixture)
{
    // Dividend and divisor are the same Softmax of the same input, so every element
    // of the RealDiv result is expected to be exactly 1 whatever the input values are.
    RunTest<2>({ 2.0f, 1.0f, 3.0f, 1.0f },
               { 1.0f, 1.0f, 1.0f, 1.0f });
}

// Two-input graph computing output = RealDiv(input0, input1).
// Both placeholders declare an empty shape in the prototxt; the real shapes come from the
// constructor arguments and are handed to Setup, so derived fixtures can pick any pair of
// (possibly broadcast-compatible) input shapes.
struct DivisionBroadcastFixture : public armnnUtils::ParserPrototxtFixture<armnnTfParser::ITfParser>
{
    DivisionBroadcastFixture(const armnn::TensorShape& inputShape0, const armnn::TensorShape& inputShape1)
    {
        m_Prototext = R"pbtxt(
                 node {
                   name: "input0"
                   op: "Placeholder"
                   attr {
                     key: "dtype"
                     value {
                       type: DT_FLOAT
                     }
                   }
                   attr {
                     key: "shape"
                     value {
                       shape {
                       }
                     }
                   }
                 }
                 node {
                   name: "input1"
                   op: "Placeholder"
                   attr {
                     key: "dtype"
                     value {
                       type: DT_FLOAT
                     }
                   }
                   attr {
                     key: "shape"
                     value {
                       shape {
                       }
                     }
                   }
                 }
                 node {
                   name: "output"
                   op: "RealDiv"
                   input: "input0"
                   input: "input1"
                   attr {
                     key: "T"
                     value {
                       type: DT_FLOAT
                     }
                   }
                 }
                 )pbtxt";

        // Bind each named input to its shape; "output" is the only requested output.
        Setup({ { "input0", inputShape0 }, { "input1", inputShape1 } }, { "output" });
    }
};
struct DivisionBroadcastFixture4D1D : public DivisionBroadcastFixture
{
    DivisionBroadcastFixture4D1D() : DivisionBroadcastFixture({ 1, 2, 2, 3 }, { 1 }) {}
};

BOOST_FIXTURE_TEST_CASE(ParseDivisionBroadcast4D1D, DivisionBroadcastFixture4D1D)
{
    // input1 holds a single value (5) that is broadcast over all twelve elements of input0,
    // so each expected output element is the matching input0 element divided by 5.
    RunTest<4>({ { "input0", {  0.0f, 100.0f,  2.0f,
                                3.0f, 250.0f, 15.0f,
                               33.0f,  60.0f,  5.0f,
                               35.0f,  10.0f, 55.0f } },
                 { "input1", { 5.0f } } },
               { { "output", { 0.0f, 20.0f,  0.4f,
                               0.6f, 50.0f,  3.0f,
                               6.6f, 12.0f,  1.0f,
                               7.0f,  2.0f, 11.0f } } });
}

BOOST_FIXTURE_TEST_CASE(ParseDivideByZeroBroadcast4D1D, DivisionBroadcastFixture4D1D)
{
    // IEEE-754 division by +0: a nonzero dividend gives an infinity with the dividend's
    // sign, and a zero dividend (of either sign) gives NaN.
    const float Inf = std::numeric_limits<float>::infinity();
    const float NaN = std::numeric_limits<float>::quiet_NaN();

    // Note: -0.0f is a real negative zero. The integer literal -0 would convert to +0.0f
    // and never exercise a negative-zero dividend. The result is NaN in both cases.
    RunTest<4>({ { "input0", {  0.0f, -100.0f,  2.0f,
                                3.0f, -250.0f, 15.0f,
                               33.0f,   -0.0f,  5.0f,
                               35.0f,  -10.0f, 55.0f } },
                 { "input1", { 0.0f } } },
               { { "output", { NaN, -Inf, Inf,
                               Inf, -Inf, Inf,
                               Inf,  NaN, Inf,
                               Inf, -Inf, Inf } } });
}

BOOST_AUTO_TEST_SUITE_END() // closes the TensorflowParser suite opened above