aboutsummaryrefslogtreecommitdiff
path: root/src/armnnOnnxParser/test/BatchNorm.cpp
diff options
context:
space:
mode:
authortelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
committertelsoa01 <telmo.soares@arm.com>2018-08-31 09:22:23 +0100
commitc577f2c6a3b4ddb6ba87a882723c53a248afbeba (patch)
treebd7d4c148df27f8be6649d313efb24f536b7cf34 /src/armnnOnnxParser/test/BatchNorm.cpp
parent4c7098bfeab1ffe1cdc77f6c15548d3e73274746 (diff)
downloadarmnn-c577f2c6a3b4ddb6ba87a882723c53a248afbeba.tar.gz
Release 18.08
Diffstat (limited to 'src/armnnOnnxParser/test/BatchNorm.cpp')
-rw-r--r--src/armnnOnnxParser/test/BatchNorm.cpp342
1 files changed, 342 insertions, 0 deletions
diff --git a/src/armnnOnnxParser/test/BatchNorm.cpp b/src/armnnOnnxParser/test/BatchNorm.cpp
new file mode 100644
index 0000000000..b708770895
--- /dev/null
+++ b/src/armnnOnnxParser/test/BatchNorm.cpp
@@ -0,0 +1,342 @@
+//
+// Copyright © 2017 Arm Ltd. All rights reserved.
+// See LICENSE file in the project root for full license information.
+//
+
+#include <boost/test/unit_test.hpp>
+#include "armnnOnnxParser/IOnnxParser.hpp"
+#include "ParserPrototxtFixture.hpp"
+
+BOOST_AUTO_TEST_SUITE(OnnxParser)
+
+struct BatchNormalizationMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ BatchNormalizationMainFixture()
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "CNTK"
+ producer_version: "2.5.1"
+ domain: "ai.cntk"
+ model_version: 1
+ graph {
+ name: "CNTKGraph"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "mean"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "var"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "scale"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "bias"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "scale"
+ input: "bias"
+ input: "mean"
+ input: "var"
+ output: "Output"
+ name: "batchNorm"
+ op_type: "BatchNormalization"
+ attribute {
+ name: "epsilon"
+ f: 0.0010000000475
+ type: FLOAT
+ }
+ }
+ initializer {
+ dims: 1
+ data_type: FLOAT
+ float_data: 5.0
+ name: "mean"
+ }
+ initializer {
+ dims: 1
+ data_type: FLOAT
+ float_data: 2.0
+ name: "var"
+ }
+ initializer {
+ dims: 1
+ data_type: FLOAT
+ float_data: 0.0
+ name: "bias"
+ }
+ initializer {
+ dims: 1
+ data_type: FLOAT
+ float_data: 1.0
+ name: "scale"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 3
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ Setup();
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationTest, BatchNormalizationMainFixture)
+{
+ RunTest<4>({{"Input", {1, 2, 3, 4, 5, 6, 7, 8, 9}}}, // Input data.
+ {{"Output", {-2.8277204f, -2.12079024f, -1.4138602f,
+ -0.7069301f, 0.0f, 0.7069301f,
+ 1.4138602f, 2.12079024f, 2.8277204f}}}); // Expected output data.
+}
+
+
+struct BatchNormalizationBisFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
+{
+ BatchNormalizationBisFixture()
+ {
+ m_Prototext = R"(
+ ir_version: 3
+ producer_name: "CNTK"
+ producer_version: "2.5.1"
+ domain: "ai.cntk"
+ model_version: 1
+ graph {
+ name: "CNTKGraph"
+ input {
+ name: "Input"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "mean"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "var"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "scale"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ input {
+ name: "bias"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 2
+ }
+ }
+ }
+ }
+ }
+ node {
+ input: "Input"
+ input: "scale"
+ input: "bias"
+ input: "mean"
+ input: "var"
+ output: "Output"
+ name: "batchNorm"
+ op_type: "BatchNormalization"
+ attribute {
+ name: "epsilon"
+ f: 0.00001
+ type: FLOAT
+ }
+ }
+ initializer {
+ dims: 2
+ data_type: FLOAT
+ float_data: 0.0
+ float_data: 3.0
+ name: "mean"
+ }
+ initializer {
+ dims: 2
+ data_type: FLOAT
+ float_data: 1.0
+ float_data: 1.5
+ name: "var"
+ }
+ initializer {
+ dims: 2
+ data_type: FLOAT
+ float_data: 0.0
+ float_data: 1.0
+ name: "bias"
+ }
+ initializer {
+ dims: 2
+ data_type: FLOAT
+ float_data: 1.0
+ float_data: 1.5
+ name: "scale"
+ }
+ output {
+ name: "Output"
+ type {
+ tensor_type {
+ elem_type: FLOAT
+ shape {
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 2
+ }
+ dim {
+ dim_value: 1
+ }
+ dim {
+ dim_value: 3
+ }
+ }
+ }
+ }
+ }
+ }
+ opset_import {
+ version: 7
+ })";
+ Setup();
+ }
+};
+
+BOOST_FIXTURE_TEST_CASE(ValidBatchNormalizationBisTest, BatchNormalizationBisFixture)
+{
+ RunTest<4>({{"Input", {-1, 0.0, 1, 2, 3.0, 4.0}}}, // Input data.
+ {{"Output", {-0.999995f, 0.0, 0.999995f,
+ -0.22474074f, 1.0f, 2.2247407f}}}); // Expected output data.
+}
+
+BOOST_AUTO_TEST_SUITE_END()