aboutsummaryrefslogtreecommitdiff
path: root/reference_model
diff options
context:
space:
mode:
Diffstat (limited to 'reference_model')
-rw-r--r--reference_model/CMakeLists.txt76
-rw-r--r--reference_model/src/arith_util.h194
-rw-r--r--reference_model/src/debug_modes.def20
-rw-r--r--reference_model/src/debug_types.h57
-rw-r--r--reference_model/src/func_config.cc632
-rw-r--r--reference_model/src/func_config.def90
-rw-r--r--reference_model/src/func_config.h55
-rw-r--r--reference_model/src/func_debug.cc436
-rw-r--r--reference_model/src/func_debug.h255
-rw-r--r--reference_model/src/graph_node.cc226
-rw-r--r--reference_model/src/graph_node.h354
-rw-r--r--reference_model/src/main.cpp295
-rw-r--r--reference_model/src/model_common.h28
-rw-r--r--reference_model/src/ops/activation_funcs.cc118
-rw-r--r--reference_model/src/ops/activation_funcs.h101
-rw-r--r--reference_model/src/ops/comparison.cc81
-rw-r--r--reference_model/src/ops/comparison.h71
-rw-r--r--reference_model/src/ops/control_flow.cc353
-rw-r--r--reference_model/src/ops/control_flow.h72
-rw-r--r--reference_model/src/ops/custom.cc40
-rw-r--r--reference_model/src/ops/custom.h38
-rw-r--r--reference_model/src/ops/data_layout.cc644
-rw-r--r--reference_model/src/ops/data_layout.h216
-rw-r--r--reference_model/src/ops/data_nodes.cc172
-rw-r--r--reference_model/src/ops/data_nodes.h86
-rw-r--r--reference_model/src/ops/ewise_binary.cc586
-rw-r--r--reference_model/src/ops/ewise_binary.h195
-rw-r--r--reference_model/src/ops/ewise_ternary.cc115
-rw-r--r--reference_model/src/ops/ewise_ternary.h83
-rw-r--r--reference_model/src/ops/ewise_unary.cc302
-rw-r--r--reference_model/src/ops/ewise_unary.h102
-rw-r--r--reference_model/src/ops/image.cc169
-rw-r--r--reference_model/src/ops/image.h53
-rw-r--r--reference_model/src/ops/op_factory.cc432
-rw-r--r--reference_model/src/ops/op_factory.h294
-rw-r--r--reference_model/src/ops/reduction.cc139
-rw-r--r--reference_model/src/ops/reduction.h109
-rw-r--r--reference_model/src/ops/scatter_gather.cc120
-rw-r--r--reference_model/src/ops/scatter_gather.h54
-rw-r--r--reference_model/src/ops/template_types.h277
-rw-r--r--reference_model/src/ops/tensor_ops.cc1229
-rw-r--r--reference_model/src/ops/tensor_ops.h253
-rw-r--r--reference_model/src/ops/type_conversion.cc299
-rw-r--r--reference_model/src/ops/type_conversion.h162
-rw-r--r--reference_model/src/quant_util.h103
-rw-r--r--reference_model/src/subgraph_traverser.cc649
-rw-r--r--reference_model/src/subgraph_traverser.h90
-rw-r--r--reference_model/src/tensor.cc3008
-rw-r--r--reference_model/src/tensor.h815
49 files changed, 14348 insertions, 0 deletions
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
new file mode 100644
index 0000000..0ba8afb
--- /dev/null
+++ b/reference_model/CMakeLists.txt
@@ -0,0 +1,76 @@
+cmake_minimum_required (VERSION 3.4)
+
+# Copyright (c) 2020, ARM Limited.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+project(tosa_reference_model LANGUAGES CXX)
+
+set(CMAKE_CXX_STANDARD 17)
+set(CMAKE_CXX_STANDARD_REQUIRED ON)
+
+if(CMAKE_CXX_COMPILER_ID STREQUAL GNU)
+ set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes -Wno-format-truncation")
+else()
+ set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes")
+endif()
+
+set(FLATBUFFERS_DIR "../thirdparty/flatbuffers/")
+set(SERIALIZATION_DIR "../serialization")
+
+set (CXX_SOURCE
+ src/main.cpp
+ src/tensor.cc
+ src/graph_node.cc
+ src/subgraph_traverser.cc
+ src/func_debug.cc
+ src/func_config.cc
+ src/ops/op_factory.cc
+ src/ops/tensor_ops.cc
+ src/ops/activation_funcs.cc
+ src/ops/ewise_binary.cc
+ src/ops/ewise_unary.cc
+ src/ops/ewise_ternary.cc
+ src/ops/comparison.cc
+ src/ops/reduction.cc
+ src/ops/data_layout.cc
+ src/ops/scatter_gather.cc
+ src/ops/image.cc
+ src/ops/type_conversion.cc
+ src/ops/data_nodes.cc
+ src/ops/custom.cc
+ src/ops/control_flow.cc
+)
+
+add_executable(tosa_reference_model ${CXX_SOURCE})
+
+target_include_directories(tosa_reference_model
+ PUBLIC
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include>
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ ${FLATBUFFERS_DIR}/include
+ ../thirdparty/eigen/
+ ../thirdparty/eigen/unsupported/
+ ${SERIALIZATION_DIR}
+)
+
+target_link_libraries(tosa_reference_model
+ PRIVATE
+ flatbuffers
+ tosa_serialization
+)
+
+install (TARGETS tosa_reference_model DESTINATION bin)
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h
new file mode 100644
index 0000000..554a7a2
--- /dev/null
+++ b/reference_model/src/arith_util.h
@@ -0,0 +1,194 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/arith_util.h
+ * Description:
+ * arithmetic utility macro, include:
+ * fp16 (float16_t ) type alias
+ * bitwise operation
+ * fix point arithmetic
+ * fp16 type conversion(in binary translation)
+ * fp16 arithmetic (disguised with fp32 now)
+ */
+
+#ifndef ARITH_UTIL_H
+#define ARITH_UTIL_H
+
+#include <fenv.h>
+#include <math.h>
+#define __STDC_LIMIT_MACROS //enable min/max of plain data type
+#include "func_debug.h"
+#include "inttypes.h"
+#include <cassert>
+#include <iostream>
+#include <limits>
+#include <stdint.h>
+#include <typeinfo>
+
+using namespace std;
+
+inline size_t _count_one(uint64_t val)
+{
+ size_t count = 0;
+ for (; val; count++)
+ {
+ val &= val - 1;
+ }
+ return count;
+}
+
+template <typename T>
+inline size_t _integer_log2(T val)
+{
+ size_t result = 0;
+ while (val >>= 1)
+ {
+ ++result;
+ }
+ return result;
+}
+
+template <typename T>
+inline size_t _count_leading_zeros(T val)
+{
+ size_t size = sizeof(T) * 8;
+ size_t count = 0;
+ T msb = static_cast<T>(1) << (size - 1);
+ for (size_t i = 0; i < size; i++)
+ {
+ if (!((val << i) & msb))
+ count++;
+ else
+ break;
+ }
+ return count;
+}
+
+template <typename T>
+inline size_t _count_leading_ones(T val)
+{
+ size_t size = sizeof(T) * 8;
+ size_t count = 0;
+ T msb = static_cast<T>(1) << (size - 1);
+ for (size_t i = 0; i < size; i++)
+ {
+ if ((val << i) & msb)
+ count++;
+ else
+ break;
+ }
+ return count;
+}
+
+#define MAX(a, b) ((a) > (b) ? (a) : (b))
+#define MIN(a, b) ((a) < (b) ? (a) : (b))
+// Compute ceiling of (a/b)
+#define DIV_CEIL(a, b) ((a) % (b) ? ((a) / (b) + 1) : ((a) / (b)))
+
+// Returns a mask of 1's of this size
+#define ONES_MASK(SIZE) ((uint64_t)((SIZE) >= 64 ? 0xffffffffffffffffULL : ((uint64_t)(1) << (SIZE)) - 1))
+
+// Returns a field of bits from HIGH_BIT to LOW_BIT, right-shifted
+// include both side, equivalent VAL[LOW_BIT:HIGH_BIT] in verilog
+
+#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) (((uint64_t)(VAL) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT)))
+
+// Returns a bit at a particular position
+#define BIT_EXTRACT(POS, VAL) (((uint64_t)(VAL) >> (POS)) & (1))
+
+// Use Brian Kernigahan's way: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan
+// Does this need to support floating point type?
+// Not sure if static_cast is the right thing to do, try to be type safe first
+#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL)))
+
+#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : ((SHF < 0) ? ((VAL) >> (-(SHF))) : (VAL)))
+#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B))
+#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B))
+#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? ((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT))))
+#define ILOG2(VAL) (_integer_log2(VAL))
+
+// Get negative value (2's complement)
+#define NEGATIVE_8(VAL) ((uint8_t)(~(VAL) + 1))
+#define NEGATIVE_16(VAL) ((uint16_t)(~(VAL) + 1))
+#define NEGATIVE_32(VAL) ((uint32_t)(~(VAL) + 1))
+#define NEGATIVE_64(VAL) ((uint64_t)(~(VAL) + 1))
+// Convert a bit quanity to the minimum bytes required to hold those bits
+#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8)
+
+// Count leading zeros/ones for 8/16/32/64-bit operands
+// (I don't see an obvious way to collapse this into a size-independent set)
+// treated as unsigned
+#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL)))
+#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL)))
+#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL)))
+#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL)))
+#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL))
+
+#define LEADING_ONES_64(VAL) _count_leading_ones((uint64_t)(VAL))
+#define LEADING_ONES_32(VAL) _count_leading_ones((uint32_t)(VAL))
+#define LEADING_ONES_16(VAL) _count_leading_ones((uint16_t)(VAL))
+#define LEADING_ONES_8(VAL) _count_leading_ones((uint8_t)(VAL))
+#define LEADING_ONES(VAL) _count_leading_ones(VAL)
+// math operation
+// sign-extended for signed version
+// extend different return type (8, 16, 32) + (S, U)
+// Saturate a value at a certain bitwidth, signed and unsigned versions
+// Format is as followed: SATURATE_VAL_{saturation_sign}_{return_type}
+// for example
+// SATURATE_VAL_U_8U(8,300) will return uint8_t with value of 255(0xff)
+// SATURATE_VAL_S_32S(5,-48) will return int32_t with value of -16(0x10)
+// note that negative value can cast to unsigned return type using native uint(int) cast
+// so SATURATE_VAL_S_8U(5,-40) will have value 0'b1110000 which is in turn 224 in uint8_t
+
+template <typename T>
+constexpr T bitmask(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return width == sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max())
+ : (static_cast<T>(1) << width) - 1;
+}
+
+template <typename T>
+constexpr T minval(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0;
+}
+
+template <typename T>
+constexpr T maxval(const uint32_t width)
+{
+ ASSERT(width <= sizeof(T) * 8);
+ return bitmask<T>(width - std::is_signed<T>::value);
+}
+
+template <typename T>
+constexpr T saturate(const uint32_t width, const intmax_t value)
+{
+ // clang-format off
+ return static_cast<T>(
+ std::min(
+ std::max(
+ value,
+ static_cast<intmax_t>(minval<T>(width))
+ ),
+ static_cast<intmax_t>(maxval<T>(width))
+ )
+ );
+ // clang-format on
+}
+
+#endif /* _ARITH_UTIL_H */
diff --git a/reference_model/src/debug_modes.def b/reference_model/src/debug_modes.def
new file mode 100644
index 0000000..51b151d
--- /dev/null
+++ b/reference_model/src/debug_modes.def
@@ -0,0 +1,20 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Defines the debugging printing modes
+
+DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization
+DEBUG_MODE(GT,1) // Graph traverser
+DEBUG_MODE(OP,2) // Operation
diff --git a/reference_model/src/debug_types.h b/reference_model/src/debug_types.h
new file mode 100644
index 0000000..bd93f19
--- /dev/null
+++ b/reference_model/src/debug_types.h
@@ -0,0 +1,57 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/debug_types.h
+ * Description:
+ * Defines fundamental debugger datatypes for the functional model
+ */
+
+#ifndef DEBUG_TYPES_H_
+#define DEBUG_TYPES_H_
+
+#ifdef __cplusplus
+extern "C"
+{
+#endif
+
+ // Debug verbosity mask
+ typedef enum func_debug_verbosity_e
+ {
+ DEBUG_VERB_NONE = 0x00,
+ DEBUG_VERB_INFO = 0x01, // Informational debugging messages
+ DEBUG_VERB_IFACE = 0x02, // Interface debugging support
+ DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout
+ DEBUG_VERB_MED = 0x08,
+ DEBUG_VERB_HIGH = 0x10
+ } func_debug_verbosity_e;
+
+ // Generated debug modes enumeration
+ typedef enum func_debug_mode_e
+ {
+ DEBUG_NONE = 0x0,
+#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT),
+#include "debug_modes.def"
+#undef DEBUG_MODE
+ DEBUG_ALL = 0xffffffffffffffffUL
+ } func_debug_mode_e;
+
+#define DEBUG_INST_ALL 0xffffffffffffffffUL
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
diff --git a/reference_model/src/func_config.cc b/reference_model/src/func_config.cc
new file mode 100644
index 0000000..bd1ce32
--- /dev/null
+++ b/reference_model/src/func_config.cc
@@ -0,0 +1,632 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <stdint.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+
+#include "func_config.h"
+#include "func_debug.h"
+
+#define MAX_NAME_LEN 128
+#define MAX_DESC_LEN 128
+
+#ifndef ARG_ERROR
+#define ARG_ERROR(...) \
+ fprintf(stderr, "ERROR: "); \
+ fprintf(stderr, __VA_ARGS__); \
+ fprintf(stderr, "\n"); \
+ return 1;
+#endif
+
+// Parameter base name string table
+const char* config_base_name_table[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #NAME,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #NAME,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #NAME,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #NAME,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+#undef DEF_UNIT_OPTION
+};
+
+// Parameter description table
+const char* config_param_desc_table[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #DESC,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #DESC,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #DESC,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #DESC,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+// String table and enum for the option hierarchy level/sub-levels
+// (no leaf options). Attribute at the top level have "BASE" as their
+// enum value and an empty string for the value.
+const char* config_hier_str_table[] = {
+ "",
+#define DEF_UNIT_START(UNIT) #UNIT,
+#define DEF_UNIT_END(UNIT) /**/
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+typedef enum config_hier_enum_t
+{
+ BASE,
+#define DEF_UNIT_START(UNIT) CURRENT_UNIT,
+#define DEF_UNIT_END(UNIT) /**/
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ MAX_CONFIG_HIER
+} config_hier_enum_t;
+
+// Mapping from a leaf parameter index to the
+// position in the hierarchy.
+config_hier_enum_t config_hierarchy_map[] = {
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) BASE,
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) BASE,
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) CURRENT_UNIT,
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) CURRENT_UNIT,
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+};
+
+#define CONFIG_PARAMETER_COUNT (sizeof(config_hierarchy_map) / sizeof(config_hier_enum_t))
+
+// Dynamically generated at initialization
+char** config_param_str_table = nullptr;
+
+// Initialize the configuration data structures
+int func_model_init_config()
+{
+ // Initialize string table (builds the hierarchical names)
+ config_param_str_table = (char**)calloc(CONFIG_PARAMETER_COUNT, sizeof(char*));
+ ASSERT_MEM(config_param_str_table);
+
+ for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ size_t len = strlen(config_base_name_table[i]) + 1;
+ if (config_hierarchy_map[i] != BASE)
+ {
+ ASSERT_MSG(config_hierarchy_map[i] <= MAX_CONFIG_HIER,
+ "Configuration parameter\'s hierarchy is out of bounds");
+ len += strlen(config_hier_str_table[config_hierarchy_map[i]]) + 1;
+ }
+ config_param_str_table[i] = (char*)calloc(len, 1);
+ ASSERT_MEM(config_param_str_table[i]);
+ ASSERT_MSG(len < MAX_NAME_LEN, "option expanded name is too long: %s", config_base_name_table[i]);
+
+ if (config_hierarchy_map[i] != BASE)
+ {
+ snprintf(config_param_str_table[i], len, "%s.%s", config_hier_str_table[config_hierarchy_map[i]],
+ config_base_name_table[i]);
+ }
+ else
+ {
+ snprintf(config_param_str_table[i], len, "%s", config_base_name_table[i]);
+ }
+ }
+
+ return 0;
+}
+
+int func_model_set_default_config(func_config_t* func_config)
+{
+ // Set default values in the global configuration data structure
+ bzero(func_config, sizeof(*func_config));
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) func_config->NAME = (DEFAULT);
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) strncpy(func_config->NAME, (DEFAULT), (LEN)-1);
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) func_config->UNIT.NAME = (DEFAULT);
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) strncpy(func_config->UNIT.NAME, (DEFAULT), (LEN)-1);
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ return 0;
+}
+
+int func_model_config_cleanup()
+{
+ uint32_t i;
+
+ if (!config_param_str_table)
+ return 1;
+
+ for (i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ free(config_param_str_table[i]);
+ }
+
+ free(config_param_str_table);
+ config_param_str_table = nullptr;
+
+ return 0;
+}
+
+int func_model_config_set_option(func_config_t* func_config, const char* name, const char* value)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+ char* endptr;
+
+ // TODO: does not handle strings yet. Can set magic values on FMT to
+ // choose a string copy vs strtoull
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ func_config->NAME = (uint64_t)strtoll(value, &endptr, 0); \
+ if (endptr == value) \
+ { \
+ ARG_ERROR("Cannot parse option: %s = %s", name, value); \
+ } \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ if (strlen(value) >= LEN) \
+ { \
+ ARG_ERROR("Option value is too long: %s = %s", name, value); \
+ } \
+ strncpy(func_config->NAME, value, (LEN)-1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ func_config->UNIT.NAME = (uint64_t)strtoll(value, &endptr, 0); \
+ if (endptr == value) \
+ { \
+ ARG_ERROR("Cannot parse option: %s = %s", name, value); \
+ } \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ if (strlen(value) >= LEN) \
+ { \
+ ARG_ERROR("Option value is too long: %s = %s", name, value); \
+ } \
+ strncpy(func_config->UNIT.NAME, value, (LEN)-1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ // No match!
+ ARG_ERROR("Cannot find option: %s", name);
+
+ return 1;
+}
+
+int func_model_config_get_option_by_name(func_config_t* func_config, const char* name, uint64_t* val)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, FMT, DEFAULT) param_idx++;
+
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ *val = func_config->NAME; \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ *val = func_config->UNIT.NAME; \
+ return 0; \
+ } \
+ param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+ // No match!
+ return 1;
+}
+int func_model_config_get_str_option_by_name(func_config_t* func_config,
+ const char* name,
+ char* value,
+ const uint32_t len)
+{
+ // Increment an index variable on each parameter position
+ // so that we can index both the position struct through the macro and the
+ // array of parameter names through a simple array of strings.
+ int param_idx = 0;
+
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ strncpy(value, func_config->NAME, len - 1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ if (!strcmp(config_param_str_table[param_idx], name)) \
+ { \
+ strncpy(value, func_config->UNIT.NAME, len - 1); \
+ return 0; \
+ } \
+ param_idx++;
+
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) param_idx++;
+
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) param_idx++;
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+ // No match!
+ return 1;
+}
+
+int func_config_print_config_help(FILE* out)
+{
+ fprintf(out, "%-40s %s\n", "Option", "Description");
+ fprintf(out, "%-40s %s\n", "------", "-----------");
+
+ for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++)
+ {
+ fprintf(out, "-C%-40s %s\n", config_param_str_table[i], config_param_desc_table[i]);
+ }
+
+ fprintf(out, "\n");
+
+ return 0;
+}
+
+int func_model_print_config(func_config_t* func_config, FILE* out)
+{
+#define DEF_UNIT_START(UNIT)
+#define DEF_UNIT_END(UNIT)
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) fprintf(out, "%-40s = " FMT "\n", #NAME, func_config->NAME);
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \
+ fprintf(out, "%-40s = " FMT "\n", #UNIT "." #NAME, func_config->UNIT.NAME);
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) fprintf(out, "%-40s = %s\n", #NAME, func_config->NAME);
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \
+ fprintf(out, "%-40s = %s\n", #UNIT "." #NAME, func_config->UNIT.NAME);
+
+#define FOF_HEX "0x%llx"
+#define FOF_DEC "%" PRIu32
+#define FOF_DECU64 "%" PRIu64
+
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_UNIT_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION_STR
+
+ return 0;
+}
+
+static const char* programname;
+
+void func_model_print_debug_masks(FILE* out)
+{
+ fprintf(out, "\t List of components:\n");
+#define DEBUG_MODE(string, value) fprintf(out, "\t\t" #string "\n");
+#include "debug_modes.def"
+#undef DEBUG_MODE
+}
+
+int func_model_print_help(FILE* out)
+{
+ fprintf(out, "TOSA Reference Model help\n\n");
+
+ fprintf(out,
+ "Usage: %s [-c] [-C <name=value>] [-d <Debug Mask>] [-h] [-i <uscriptfile>] [-l <verbosity>] [-F "
+ "<flatconfig>]\n",
+ programname);
+ fprintf(out, "\t-c - Print list of config options\n");
+ fprintf(out, "\t-C <name=value> - modify config option <name> to <value>\n");
+ fprintf(out, "\t-d <Debug Mask - set component debug mask\n");
+ func_model_print_debug_masks(out);
+ fprintf(out, "\t-F <flatconfig> - parse <flatconfig> as file of config options\n");
+ fprintf(out, "\t-h - show this help message and exit\n");
+ fprintf(
+ out,
+ "\t-i <input_tensor_name>,<filename> - set input tensor <input_tensor_name> to the values from <filename>\n");
+ fprintf(out, "\t-l <verbosity> - set log verbosity\n");
+ fprintf(out, "\t-o <debuglog> - set debug log file\n");
+ fprintf(out, "\n");
+
+ func_config_print_config_help(stdout);
+
+ return 0;
+}
+
+static const char* get_arg_text(int& index, const int argc, const char** argv)
+{
+ if (strlen(argv[index]) > 2)
+ {
+ return argv[index] + 2;
+ }
+
+ if ((index + 1 == argc) || (argv[index + 1][0] == '-'))
+ {
+ fprintf(stderr, "No option value found for option %s\n", argv[index]);
+ return "";
+ }
+
+ index++;
+ return argv[index];
+}
+
+// Read the command line arguments
+int func_model_parse_cmd_line(func_config_t* func_config, func_debug_t* func_debug, const int argc, const char** argv)
+{
+ int i;
+ programname = argv[0];
+ for (i = 1; i < argc; i++)
+ {
+ // All command line arguments must begin with -X where X is a recognized character
+ if (strlen(argv[i]) < 2 || argv[i][0] != '-')
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Command line argument at position %d not valid: %s", i, argv[i]);
+ }
+
+ switch (argv[i][1])
+ {
+ // Model parameters may be overridden with the -Cname=value switch
+ case 'c':
+ func_config_print_config_help(stderr);
+ return 1;
+
+ case 'C':
+ {
+ const char *name = nullptr, *value = nullptr;
+
+ // Break the string into name and value parts
+ name = get_arg_text(i, argc, argv);
+ value = strchr(name, '=');
+
+ if (value == nullptr)
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot parse -C argument at position %d: %s", i, argv[i]);
+ }
+
+ *const_cast<char*>(value) = 0;
+
+ if (func_model_config_set_option(func_config, name, value + 1))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot parse -C argument at position %d: %s", i, argv[i]);
+ }
+ break;
+ }
+
+ case 'd':
+ case 'D':
+ {
+ func_debug_set_mask(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ case 'F':
+ {
+ // Read a flat configuration file
+ if (func_model_parse_flat_config_file(func_config, get_arg_text(i, argc, argv)))
+ return 1;
+
+ break;
+ }
+ case 'h':
+ func_model_print_help(stderr);
+ return 1;
+
+ case 'i':
+ {
+ // shortcut for '-Cinput_tensor='
+ if (func_model_config_set_option(func_config, "input_tensor", get_arg_text(i, argc, argv)))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot set input tensor config value");
+ }
+ break;
+ }
+ case 'l':
+ {
+ // Debug verbosity/logging level
+ func_debug_set_verbosity(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ case 'o':
+ {
+ func_debug_set_file(func_debug, get_arg_text(i, argc, argv));
+ break;
+ }
+ default:
+ func_model_print_help(stderr);
+ ARG_ERROR("Unrecognized argument at position %d: %s", i, argv[i]);
+ }
+ }
+
+ return 0;
+}
+
+int func_model_parse_flat_config_file(func_config_t* func_config, const char* filename)
+{
+ const int MAX_LINE_LEN = 1024;
+
+ FILE* infile = nullptr;
+ char line_buf[MAX_LINE_LEN];
+ int line = 1;
+
+ infile = fopen(filename, "r");
+
+ if (infile == nullptr)
+ {
+ ARG_ERROR("Cannot open config file: %s\n", filename);
+ }
+
+ while (fgets(line_buf, MAX_LINE_LEN - 1, infile) != nullptr)
+ {
+ char *name = line_buf, *value = nullptr, *comment = nullptr, *ptr = nullptr;
+
+ // Remove comments
+ comment = strchr(line_buf, '#');
+
+ if (comment)
+ *comment = 0;
+
+ // Break the string into name and value parts
+ name = line_buf;
+
+ // Remove leading whitespace
+ while (*name && isspace(*name))
+ name++;
+
+ // Empty line?
+ if (*name == 0)
+ {
+ line++;
+ continue;
+ }
+
+ value = strchr(name, '=');
+
+ // Missing value
+ if (value == nullptr)
+ {
+ ARG_ERROR("Cannot parse parameter in %s at line %d: %s", filename, line, line_buf);
+ }
+
+ // Remove the =
+ *value = 0;
+ value++;
+
+ // Trim off any whitespace at the end of the value
+ ptr = value;
+ while (*ptr != 0 && !isspace(*ptr))
+ ptr++;
+ *ptr = 0;
+
+ // Include a nested file
+ if (!strcmp(name, "include"))
+ {
+ if (func_model_parse_flat_config_file(func_config, value))
+ return 1;
+ line++;
+ continue;
+ }
+
+ if (func_model_config_set_option(func_config, name, value))
+ {
+ func_model_print_help(stderr);
+ ARG_ERROR("Cannot set parameter in %s at line %d: %s", filename, line, line_buf)
+ }
+
+ line++;
+ }
+
+ fclose(infile);
+
+ return 0;
+}
diff --git a/reference_model/src/func_config.def b/reference_model/src/func_config.def
new file mode 100644
index 0000000..004cf36
--- /dev/null
+++ b/reference_model/src/func_config.def
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+/*
+ * Filename: src/func_config.def
+ * Description:
+ * Defines the model parameters/options for the functional model.
+ */
+
+// Placeholder values for the Functional model Option Formatting (FOF) fields
+//
+// FOF_DEC is decimal
+// FOF_HEX is hexidecimal
+//
+// Floating point values are not supported yet, but there is no fundamental reason
+// why we can't have them.
+#ifndef FOF_DEC
+#define FOF_DEC 1
+#endif
+
+#ifndef FOF_HEX
+#define FOF_HEX 1
+#endif
+
+#ifndef FOF_STR_LEN
+#define FOF_STR_LEN 1024
+#endif
+
+// Options are defined as follows:
+// DEF_OPTION() defines a top-level option
+// Arguments:
+// option_field_name: a C-syntax field name in the struct
+// description: a short string that describes the purpose of the option (printed out with help)
+// C type: the type of the option (typically a uint64_t, uint32_t, etc)
+// Format field: the FOF_* type used to figure out how to format/print the option
+// Default value: the default value assigned to the option, if it isn't assigned by an configuration file
+// or command line override
+
+// For defining hierarchical options (example hierarchy is 'cle', use the following formula).
+// All options within the hierarchical space must be grouped together:
+//
+
+// #define CURRENT_UNIT cle
+// DEF_UNIT_START(CURRENT_UNIT)
+// DEF_UNIT_OPTION(CURRENT_UNIT,...)
+// ...
+// DEF_UNIT_END(CURRENT_UNIT)
+// #undef CURRENT_UNIT
+//
+// The CURRENT_UNIT argument is required as a parameter in these definitions because
+// macro processing rules only allow stringification of macro parameters. Unfortunately,
+// Other tokens that are NOT passed in as macro parameters cannot be stringified.
+
+DEF_OPTION_STR(operator_fbs, "Flat buffer syntax file", FOF_STR_LEN, "../serialization/tosa.fbs")
+DEF_OPTION_STR(subgraph_dir, "Subgraph directory to load", FOF_STR_LEN, ".")
+DEF_OPTION_STR(subgraph_file, "Subgraph file to load", FOF_STR_LEN, "")
+DEF_OPTION_STR(input_dir, "Input directory path for dumps/files", FOF_STR_LEN, ".")
+DEF_OPTION_STR(input_tensor, "A list of pairs <name0>:<npy0>,<name1>:<npy1>", FOF_STR_LEN, "")
+DEF_OPTION_STR(output_dir, "Output directory path for output dumps/files", FOF_STR_LEN, ".")
+DEF_OPTION(eval, "Evaluate the network (0/1)", uint32_t, FOF_DEC, 1)
+DEF_OPTION(validate_only, "Validate the network, but do not read inputs or evaluate (0/1)", uint32_t, FOF_DEC, 0)
+DEF_OPTION(output_tensors, "Output tensors to a file (0/1)", uint32_t, FOF_DEC, 1)
+DEF_OPTION(tosa_profile, "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)", uint32_t, FOF_DEC, 1)
+DEF_OPTION_STR(output_tensor_prefix, "Optional output tensor prefix", FOF_STR_LEN, "output_")
+DEF_OPTION(dump_intermediates, "Dump intermediate tensors (0/1)", uint32_t, FOF_DEC, 0)
+DEF_OPTION_STR(fp_format, "Floating-point number dump format string (printf-style format, e.g. 0.5)", FOF_STR_LEN, "0.5")
+// Example of a hierarchical option
+//#define CURRENT_UNIT arch
+//DEF_UNIT_START(arch)
+//DEF_UNIT_OPTION(arch, ifm_width, "input feature map width(x dim)", uint32_t, FOF_DEC, 10)
+//DEF_UNIT_END(CURRENT_UNIT)
+///#undef CURRENT_UNIT
+
+// START Do not delete
+// Required for keeping the FOFs clean
+#undef FOF_DEC
+#undef FOF_HEX
+// END Do not delete^^
diff --git a/reference_model/src/func_config.h b/reference_model/src/func_config.h
new file mode 100644
index 0000000..f941300
--- /dev/null
+++ b/reference_model/src/func_config.h
@@ -0,0 +1,55 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUNC_CONFIG_H_
+#define FUNC_CONFIG_H_
+
+// Parameter value structure
+#define DEF_UNIT_START(UNIT) \
+ struct UNIT##_t \
+ {
+#define DEF_UNIT_END(UNIT) \
+ } \
+ UNIT;
+#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME;
+#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) char NAME[LEN];
+#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME;
+#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) char NAME[LEN];
+struct func_config_t
+{
+#include "func_config.def"
+#undef DEF_UNIT_START
+#undef DEF_UNIT_END
+#undef DEF_OPTION
+#undef DEF_OPTION_STR
+#undef DEF_UNIT_OPTION
+#undef DEF_UNIT_OPTION_STR
+};
+
+// Forward declaration
+struct func_debug_t;
+
+int func_model_init_config();
+int func_model_set_default_config(func_config_t*);
+int func_model_config_set_option(func_config_t*, const char* name, const char* value);
+int func_model_print_config(func_config_t*, FILE* out);
+int func_model_parse_cmd_line(func_config_t*, func_debug_t* func_debug, const int argc, const char** argv);
+int func_model_parse_flat_config_file(func_config_t*, const char* filename);
+int func_model_config_cleanup();
+int func_model_config_get_str_option_by_name(func_config_t*, const char* name, char* value, const uint32_t len);
+int func_model_config_get_option_by_name(func_config_t*, const char* name, uint64_t* val);
+int func_model_print_help(FILE* out);
+
+#endif
diff --git a/reference_model/src/func_debug.cc b/reference_model/src/func_debug.cc
new file mode 100644
index 0000000..f5f045e
--- /dev/null
+++ b/reference_model/src/func_debug.cc
@@ -0,0 +1,436 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <ctype.h>
+#include <signal.h>
+#include <stdarg.h>
+#include <stdio.h>
+#include <stdlib.h>
+#include <string.h>
+#include <sys/types.h>
+
+#ifndef _MSC_VER
+#include <execinfo.h>
+#include <sys/prctl.h>
+#include <sys/ptrace.h>
+#include <sys/wait.h>
+#include <unistd.h>
+#endif
+
+#include "func_debug.h"
+
+#define MAX_FRAMES 100
+
+#ifndef _MSC_VER
+pid_t func_print_backtrace_helper(int num_tries, int sig);
+#endif
+
+void func_print_backtrace(FILE* out, int sig)
+{
+#ifndef _MSC_VER
+ for (int i = 0; i < 2; i++)
+ {
+ const pid_t child_pid = func_print_backtrace_helper(i, sig);
+ if (child_pid < 0)
+ {
+ perror("Backtrace generation failed on fork");
+ break;
+ }
+
+ int status = 0;
+ waitpid(child_pid, &status, 0);
+ if (WEXITSTATUS(status) == 0)
+ {
+ break;
+ }
+ }
+#endif
+}
+
+#ifndef _MSC_VER
+pid_t func_print_backtrace_helper(int num_tries, int sig)
+{
+ const pid_t child_pid = fork();
+
+ if (child_pid)
+ {
+ return 0;
+ }
+
+ const pid_t ppid = getppid();
+
+ printf("Attaching debugger to pid %d\n", ppid);
+ // Check if we're in a debugger
+ if (ptrace(PTRACE_ATTACH, ppid, 0, 0) == 0)
+ {
+ // If we reach this point, no debugger is present
+ // Undo effects of PTRACE_ATTACH
+ waitpid(ppid, NULL, 0);
+ ptrace(PTRACE_CONT, 0, 0, 0);
+ ptrace(PTRACE_DETACH, ppid, 0, 0);
+
+ dup2(STDERR_FILENO, STDOUT_FILENO);
+
+ char parent_pid[20];
+ snprintf(parent_pid, sizeof(parent_pid), "attach %d", ppid);
+ fprintf(stdout, "Caught signal %d (%s)\n", sig, strsignal(sig));
+
+ execlp("gdb", "gdb", "--batch", "-n", "-ex",
+ // Don't print startup messages for each thread
+ "-ex", "set print thread-events off", "-ex", parent_pid,
+ // Turn off pagination
+ "-ex", "set height 0",
+ // Print a backtrace for the current thread
+ "-ex", "thread $_thread", "-ex", "bt",
+ // Print a backtrace for the main thread (uncomment the next two lines, if desired)
+ //"-ex", "thread 1",
+ //"-ex", "bt",
+ // Print a backtrace for all thread (TMI)
+ //"-ex", "thread apply all bt",
+ NULL);
+
+ // If we reach this point, it is bad. Attempt to print an error before exiting.
+ perror("Backtrace generation failed to invoke gdb");
+ exit(1);
+ }
+
+ // Debugger present. Exit here.
+ exit(0);
+
+ return 0;
+}
+#endif
+
+void func_backtrace_signal_handler(int sig)
+{
+ func_print_backtrace(NULL, sig);
+ exit(1);
+}
+
+// Note: this overwrites other signal handlers. May want to make this
+// more friendly sometime
+void func_enable_signal_handlers()
+{
+ static const int sig_list[] = { SIGABRT, SIGSEGV, SIGILL, SIGFPE };
+
+ if (getenv("FUNC_NO_SIG_HANDLERS"))
+ {
+ return;
+ }
+
+ for (size_t i = 0; i < sizeof(sig_list) / sizeof(int); i++)
+ {
+ struct sigaction act;
+
+ bzero(&act, sizeof(act));
+ act.sa_handler = func_backtrace_signal_handler;
+
+ if (sigaction(sig_list[i], &act, NULL))
+ {
+ perror("Error calling sigaction");
+ }
+ }
+}
+
+const char* func_debug_mode_str_table[] = {
+#define DEBUG_MODE(NAME, BIT) #NAME,
+#include "debug_modes.def"
+#undef DEBUG_MODE
+};
+
+#define DEBUG_MASK_COUNT (sizeof(func_debug_mode_str_table) / sizeof(const char*))
+
+const char* func_debug_verbosity_str_table[] = { "NONE", "INFO", "IFACE", "LOW", "MED", "HIGH" };
+
+const uint32_t func_debug_verbosity_mask_table[] = { DEBUG_VERB_NONE, DEBUG_VERB_INFO, DEBUG_VERB_IFACE,
+ DEBUG_VERB_LOW, DEBUG_VERB_MED, DEBUG_VERB_HIGH };
+
+#define DEBUG_VERBOSITY_COUNT (sizeof(func_debug_verbosity_str_table) / sizeof(const char*))
+
+// Initialize the debug mode
+int func_init_debug(func_debug_t* func_debug, uint64_t inst_id)
+{
+ // Set the default debug settings
+ bzero(func_debug, sizeof(func_debug_t));
+ func_debug_set_mask(func_debug, DEBUG_NONE);
+ func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE);
+ func_debug_set_inst_mask(func_debug, DEBUG_INST_ALL);
+ func_debug->func_debug_file = stderr;
+ func_debug_set_captured_warnings(func_debug, 0);
+ func_debug_set_output_unbuffered(func_debug, false);
+ func_debug->inst_id = inst_id;
+
+ return 0;
+}
+
+int func_fini_debug(func_debug_t* func_debug)
+{
+ if (func_debug->record_warnings)
+ {
+ func_debug_set_captured_warnings(func_debug, 0);
+ }
+
+#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H
+ if (func_debug->is_gzip && func_debug->func_debug_file)
+ {
+ pclose(func_debug->func_debug_file);
+ func_debug->func_debug_file = NULL;
+ }
+#endif
+
+ return 0;
+}
+
+int func_debug_set_file(func_debug_t* func_debug, const char* filename)
+{
+ int filenameLen = strlen(filename);
+
+ // Open the debug output file
+ ASSERT(filename != NULL);
+#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H
+ if (filenameLen > 3 && strcmp(filename + filenameLen - 3, ".gz") == 0)
+ {
+ char cmd[256];
+
+ snprintf(cmd, sizeof(cmd), "gzip > %s", filename);
+ func_debug->func_debug_file = popen(cmd, "w");
+ func_debug->is_gzip = 1;
+ }
+ else
+ {
+#else
+ {
+#endif
+ func_debug->func_debug_file = fopen(filename, "w");
+ }
+
+ if (!func_debug->func_debug_file)
+ {
+ perror(NULL);
+ FATAL_ERROR("Cannot open debug output file: %s\n", filename);
+ return 1;
+ }
+ if (func_debug->is_output_unbuffered)
+ {
+ setvbuf(func_debug->func_debug_file, nullptr, _IONBF, 0);
+ }
+
+ return 0;
+}
+
+void func_debug_set_verbosity(func_debug_t* func_debug, const char* str)
+{
+ if (!strcasecmp(str, "RESET"))
+ {
+ func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE);
+ return;
+ }
+
+ for (size_t i = 0; i < DEBUG_VERBOSITY_COUNT; i++)
+ {
+ if (!strcasecmp(str, func_debug_verbosity_str_table[i]))
+ {
+ func_debug_set_verbosity(func_debug, func_debug_verbosity_mask_table[i]);
+ return;
+ }
+ }
+
+ FATAL_ERROR("Invalid debug verbosity: %s", str);
+}
+
+void func_debug_set_verbosity(func_debug_t* func_debug, const uint32_t verb)
+{
+ uint32_t new_mask = verb;
+
+ switch (verb)
+ {
+ case DEBUG_VERB_NONE:
+ new_mask = DEBUG_VERB_NONE;
+ break;
+ case DEBUG_VERB_INFO:
+ new_mask = DEBUG_VERB_INFO;
+ break;
+ case DEBUG_VERB_IFACE:
+ new_mask = DEBUG_VERB_IFACE;
+ break;
+ case DEBUG_VERB_HIGH:
+ new_mask |= DEBUG_VERB_HIGH;
+ // Intentional fallthrough
+ case DEBUG_VERB_MED:
+ new_mask |= DEBUG_VERB_MED;
+ // Intentional fallthrough
+ case DEBUG_VERB_LOW:
+ new_mask |= DEBUG_VERB_LOW;
+ new_mask |= DEBUG_VERB_INFO;
+ new_mask |= DEBUG_VERB_IFACE;
+ break;
+ }
+
+ func_debug->func_debug_verbosity = new_mask;
+}
+
+void func_debug_set_suppress_arch_error_mask(func_debug_t* func_debug, const uint32_t suppress)
+{
+ func_debug->func_suppress_arch_error_mask = suppress;
+}
+
+void func_debug_set_mask(func_debug_t* func_debug, const uint64_t mask)
+{
+ if (mask == DEBUG_NONE)
+ func_debug->func_debug_mask = mask;
+ else
+ func_debug->func_debug_mask |= mask;
+
+ // Set a minimum verbosity level
+ if (func_debug->func_debug_verbosity == DEBUG_VERB_NONE)
+ func_debug->func_debug_verbosity = DEBUG_VERB_INFO;
+}
+
+void func_debug_set_inst_mask(func_debug_t* func_debug, const char* mask)
+{
+ uint64_t val;
+
+ val = strtoul(mask, NULL, 0);
+
+ return func_debug_set_inst_mask(func_debug, val);
+}
+
+void func_debug_set_inst_mask(func_debug_t* func_debug, const uint64_t mask)
+{
+ if (mask == 0)
+ func_debug->func_debug_inst_mask = DEBUG_INST_ALL;
+ else
+ func_debug->func_debug_inst_mask = mask;
+}
+
+void func_debug_set_mask(func_debug_t* func_debug, const char* str)
+{
+ if (!strcasecmp(str, "all"))
+ {
+ func_debug_set_mask(func_debug, UINT64_MAX - 1);
+ return;
+ }
+
+ size_t i;
+ for (i = 0; i < DEBUG_MASK_COUNT; i++)
+ {
+ if (!strcasecmp(str, func_debug_mode_str_table[i]))
+ {
+ func_debug_set_mask(func_debug, 1ULL << i);
+ return;
+ }
+ }
+
+ func_debug_print_masks(stderr);
+
+ FATAL_ERROR("Invalid debug mask: %s", str);
+}
+
+void func_debug_print_masks(FILE* out)
+{
+ uint32_t i;
+
+ fprintf(out, "Available debug masks:\n");
+
+ for (i = 0; i < DEBUG_MASK_COUNT; i++)
+ {
+ fprintf(out, "[%d] %s\n", i, func_debug_mode_str_table[i]);
+ }
+}
+
+void func_debug_set_output_unbuffered(func_debug_t* func_debug, const bool is_unbuffered)
+{
+ func_debug->is_output_unbuffered = is_unbuffered;
+}
+
+// Print warnings to the debug file or optionally store them in a buffer instead
+// Note that the buffer is circular and can be overwritten if enough messages are
+// written before removing a warning from the front.
+void func_debug_warning(
+ func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...)
+{
+ va_list args;
+ va_start(args, fmt);
+
+ if (func_debug->record_warnings)
+ {
+ // Record to the circular buffer
+ uint32_t len;
+
+ len = snprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail], WARNING_BUFFER_ENTRY_LENGTH,
+ "WARNING AT %s:%d %s(): ", file, line, func);
+ vsnprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail] + len, WARNING_BUFFER_ENTRY_LENGTH - len,
+ fmt, args);
+ func_debug->warning_buffer_tail = (func_debug->warning_buffer_tail + 1) % WARNING_BUFFER_SIZE;
+ }
+ else
+ {
+ // Print to the debug file (e.g., stderr)
+ fprintf(func_debug->func_debug_file, "WARNING AT %s:%d %s():\n", file, line, func);
+ vfprintf(func_debug->func_debug_file, fmt, args);
+ fprintf(func_debug->func_debug_file, "\n");
+ }
+ va_end(args);
+}
+
+// Initialize the warning buffer capture
+int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture)
+{
+ uint32_t i;
+ func_debug->record_warnings = capture;
+ if (capture)
+ {
+ func_debug->warning_buffer_head = 0;
+ func_debug->warning_buffer_tail = 0;
+
+ for (i = 0; i < WARNING_BUFFER_SIZE; i++)
+ {
+ func_debug->warning_buffer[i] = (char*)calloc(1, WARNING_BUFFER_ENTRY_LENGTH);
+ }
+ }
+ else
+ {
+ for (i = 0; i < WARNING_BUFFER_SIZE; i++)
+ {
+ if (func_debug->warning_buffer[i])
+ {
+ free(func_debug->warning_buffer[i]);
+ func_debug->warning_buffer[i] = NULL;
+ }
+ }
+ }
+
+ return 0;
+}
+
+int func_debug_has_captured_warning(func_debug_t* func_debug)
+{
+ if (func_debug->record_warnings && func_debug->warning_buffer_head != func_debug->warning_buffer_tail)
+ return 1;
+ else
+ return 0;
+}
+
+int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len)
+{
+ if (!func_debug_has_captured_warning(func_debug))
+ return 1;
+
+ strncpy(buf_ptr, func_debug->warning_buffer[func_debug->warning_buffer_head], buf_len);
+
+ func_debug->warning_buffer_head = (func_debug->warning_buffer_head + 1) % WARNING_BUFFER_SIZE;
+
+ return 0;
+}
diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h
new file mode 100644
index 0000000..2d47462
--- /dev/null
+++ b/reference_model/src/func_debug.h
@@ -0,0 +1,255 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef FUNC_DEBUG_H
+#define FUNC_DEBUG_H
+
+#include "debug_types.h"
+#include <assert.h>
+#include <cinttypes>
+#include <signal.h>
+#include <stdio.h>
+
+void func_print_backtrace(FILE* out, int sig = SIGABRT);
+
+void func_enable_signal_handlers();
+
+// Debug content container
+#define WARNING_BUFFER_SIZE 16
+#define WARNING_BUFFER_ENTRY_LENGTH 1024
+
+// STRINGIFY2 is needed expand expression passed to STRINGIFY
+#define STRINGIFY2(s) #s
+#define STRINGIFY(s) STRINGIFY2(s)
+
+// If TRACED_LOG is defined, add file:line to log messages
+#if defined(TRACED_LOG)
+#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__)
+#else
+#define WHERE
+#endif
+
+#if defined(COLORIZED_LOG)
+#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m"
+#define COL_FATAL(fmt) COL("1;41", fmt)
+#define COL_WARN(fmt) COL("1;43", fmt)
+#define COL_INFO(fmt) COL("2", fmt)
+#define COL_IFACE(fmt) fmt
+#define COL_LOW(fmt) COL("35", fmt)
+#define COL_MED(fmt) COL("2;33", fmt)
+#define COL_HIGH(fmt) COL("2;32", fmt)
+#else
+#define COL_FATAL(fmt) fmt
+#define COL_WARN(fmt) fmt
+#define COL_INFO(fmt) fmt
+#define COL_IFACE(fmt) fmt
+#define COL_LOW(fmt) fmt
+#define COL_MED(fmt) fmt
+#define COL_HIGH(fmt) fmt
+#endif
+
+struct func_debug_t
+{
+ uint32_t func_debug_verbosity; // What verbosity level is set? (bitmask)
+ uint64_t func_debug_mask; // Which units have debugging enabled? (bitmask)
+ uint64_t func_debug_inst_mask; // Which instances have debugging enabled (bitmask)
+ uint64_t inst_id; // The instance id for multiple model instances
+ uint32_t func_suppress_arch_error_mask; // Which architecture error should be suppressed? (bitmask)
+ FILE* func_debug_file; // Output file
+ uint32_t record_warnings;
+ char* warning_buffer[WARNING_BUFFER_SIZE];
+ uint32_t warning_buffer_head; // next unread message
+ uint32_t warning_buffer_tail; // next message to write
+ uint32_t is_gzip;
+ bool is_output_unbuffered; // should log files be opened with unbuffered I/O.
+};
+
+#ifndef ASSERT
+#define ASSERT(COND) \
+ if (!(COND)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \
+ func_print_backtrace(stderr); \
+ assert(COND); \
+ }
+#endif
+
+#ifndef ASSERT_MSG
+#define ASSERT_MSG(COND, fmt, ...) \
+ if (!(COND)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ func_print_backtrace(stderr); \
+ assert(COND); \
+ }
+#endif
+
+#ifndef ASSERT_MSG_NODE
+#define ASSERT_MSG_NODE(COND, fmt, ...) \
+ if (!(COND)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \
+ __func__, #COND); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ this->dumpNode(g_func_debug.func_debug_file); \
+ func_print_backtrace(g_func_debug.func_debug_file); \
+ assert(COND); \
+ }
+#endif
+
+// Assertion specific to allocating memory
+#ifndef ASSERT_MEM
+#define ASSERT_MEM(OBJ) \
+ if (!(OBJ)) \
+ { \
+ fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \
+ __func__); \
+ func_print_backtrace(stderr); \
+ assert(OBJ); \
+ }
+#endif
+
+#ifndef FATAL_ERROR
+#define FATAL_ERROR(fmt, ...) \
+ fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ func_print_backtrace(stderr); \
+ abort();
+#endif
+
+#ifndef FATAL_ERROR_NODE
+#define FATAL_ERROR_NODE(fmt, ...) \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ this->dumpNode(g_func_debug.func_debug_file); \
+ func_print_backtrace(g_func_debug.func_debug_file); \
+ abort();
+#endif
+#ifndef SIMPLE_FATAL_ERROR
+#define SIMPLE_FATAL_ERROR(fmt, ...) \
+ fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \
+ exit(1);
+#endif
+
+void func_debug_warning(
+ func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...);
+#ifndef WARNING
+#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__)
+#endif
+
+#ifndef WARNING_STDERR
+#define WARNING_STDERR(fmt, ...) \
+ fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \
+ fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__);
+#endif
+
+int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture);
+
+int func_debug_has_captured_warning(func_debug_t* func_debug);
+
+int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len);
+
+// Is this debug verbosity and unit level enabled?
+// Provide compiler hints that this is unlikely
+// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not
+//
+// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evalute
+// to the instance ID variable. The use of this define in header files is discouraged.
+
+#ifdef DEBUG_INSTANCE_EXPR
+// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts
+#ifdef DEBUG_INSTANCE_EXPR_2
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \
+ (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) \
+ fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2))
+
+#else // !DEBUG_INSTANCE_EXPR_2
+
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \
+ (int)(DEBUG_INSTANCE_EXPR))
+
+#endif // DEBUG_INSTANCE_EXPR_2
+
+#else // !DEBUG_INSTANCE_EXPR
+
+// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts
+#define DEBUG_ENABLED(VERB, LEVEL) \
+ (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \
+ (g_func_debug.func_debug_verbosity & (VERB)), \
+ 0))
+// Debug printing macro
+#define DEBUG(VERB, LEVEL, FMT, ...) \
+ if (DEBUG_ENABLED(VERB, LEVEL)) \
+ { \
+ fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \
+ ##__VA_ARGS__); \
+ }
+
+// Prints just the debugging prefix for properly marking free-form printouts
+#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ")
+
+#endif
+
+// Macros for different verbosity levels
+#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__)
+#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__)
+#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__)
+#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__)
+#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__)
+
+int func_init_debug(func_debug_t*, uint64_t inst_id);
+int func_fini_debug(func_debug_t*);
+int func_debug_set_file(func_debug_t*, const char* filename);
+void func_debug_set_mask(func_debug_t*, const char* str);
+void func_debug_set_mask(func_debug_t*, const uint64_t mask);
+void func_debug_print_masks(FILE* out);
+void func_debug_set_verbosity(func_debug_t*, const char* str);
+void func_debug_set_verbosity(func_debug_t*, const uint32_t verb);
+void func_debug_set_suppress_arch_error_mask(func_debug_t*, const uint32_t suppress);
+void func_debug_set_inst_mask(func_debug_t*, const char* mask);
+void func_debug_set_inst_mask(func_debug_t*, const uint64_t mask);
+void func_debug_set_output_unbuffered(func_debug_t*, const bool is_unbuffered);
+
+#endif
diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc
new file mode 100644
index 0000000..b57b9dd
--- /dev/null
+++ b/reference_model/src/graph_node.cc
@@ -0,0 +1,226 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "graph_node.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_)
+{
+ nodeType = nodeType_;
+ nodeId = id_;
+ inputs.clear();
+ outputs.clear();
+ inputNames.clear();
+ outputNames.clear();
+ clearNodeMarked();
+ evalCount = 0;
+ clearOnNextNodeList();
+ setRequiredOperands(-1, -1);
+ setRequiredRank(-1);
+}
+
+GraphNode::~GraphNode()
+{}
+
+int GraphNode::addInputName(std::string& name)
+{
+ inputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addOutputName(std::string& name)
+{
+ outputNames.push_back(name);
+ return 0;
+}
+
+int GraphNode::addInputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided");
+ inputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::addOutputTensor(Tensor* tens)
+{
+ ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided");
+ outputs.push_back(tens);
+ return 0;
+}
+
+int GraphNode::checkTensorAttributes()
+{
+ // Placeholder
+ return 0;
+}
+
+int GraphNode::eval()
+{
+ // Placeholder evaluation function
+ evalCount++;
+
+ // this should be set by derived op
+ for (auto ct : getOutputs())
+ {
+ ct->setIsValid();
+ }
+
+ return 0;
+}
+
+int GraphNode::hasAllInputsReady() const
+{
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ if (!inputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::hasAllOutputsReady() const
+{
+ for (size_t i = 0; i < outputs.size(); i++)
+ {
+ if (!outputs[i]->getIsValid())
+ return false;
+ }
+
+ return true;
+}
+
+int GraphNode::dumpNode(FILE* out)
+{
+ int i;
+ fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType],
+ nodeId, evalCount, onNextNodeList, isMarked);
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ fprintf(out, " Input[%d] ", i++);
+ ins->dumpTensorParams(out);
+ }
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ fprintf(out, " Output[%d] ", i++);
+ outs->dumpTensorParams(out);
+ }
+
+ return 0;
+}
+
+int GraphNode::dumpNode(std::ostream& out)
+{
+ int i;
+
+ out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount
+ << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl;
+
+ out << " Inputs:";
+ for (std::string& name : inputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* ins : inputs)
+ {
+ out << " Input[" << i++ << "]: ";
+ ins->dumpTensorParams(out);
+ }
+
+ out << " Outputs:";
+ for (std::string& name : outputNames)
+ {
+ out << " " << name;
+ }
+ out << std::endl;
+
+ i = 0;
+ for (Tensor* outs : outputs)
+ {
+ out << " Output[" << i++ << "]: ";
+ outs->dumpTensorParams(out);
+ }
+ return 0;
+}
+
+int GraphNode::printNodeValidationError(const std::string& msg)
+{
+ std::cout << "Operator validation error: " << msg << std::endl;
+ ;
+ dumpNode(std::cout);
+
+ return 0;
+}
+
+int GraphNode::validateRequiredOperands()
+{
+ if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " +
+ std::to_string(requiredInputCount) + " input(s)");
+ return 1;
+ }
+
+ if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount)
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output must have exactly " +
+ std::to_string(requiredOutputCount) + " output(s)");
+ return 1;
+ }
+
+ return 0;
+}
+
+int GraphNode::validateRequiredRank(const Tensor* t)
+{
+ if (requiredRankMin >= 0 && requiredRankMax >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin, requiredRankMax))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" +
+ std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) +
+ "]. tensorName: " + t->getName());
+ return 1;
+ }
+ else
+ {
+ return 0;
+ }
+ }
+
+ if (requiredRankMin >= 0)
+ {
+ if (t->checkRequiredRank(requiredRankMin))
+ {
+ printNodeValidationError(std::string(EnumNamesOp()[nodeType]) +
+ " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " +
+ std::to_string(requiredRankMin) + ". tensorName: " + t->getName());
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h
new file mode 100644
index 0000000..5b4a767
--- /dev/null
+++ b/reference_model/src/graph_node.h
@@ -0,0 +1,354 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GRAPH_NODE_H
+#define GRAPH_NODE_H
+
+#include "attribute.h"
+#include "quant_info.h"
+#include "tensor.h"
+#include "tosa_generated.h"
+#include <iostream>
+
+#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
+ template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
+ template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
+ template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \
+ template class TosaReference::OP<0>; \
+ template class TosaReference::OP<1>; \
+ template class TosaReference::OP<2>; \
+ template class TosaReference::OP<3>; \
+ template class TosaReference::OP<4>; \
+ template class TosaReference::OP<5>; \
+ template class TosaReference::OP<6>;
+
+#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>;
+
+#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>;
+
+#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
+
+#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE)
+
+#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
+ DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2)
+
+#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE)
+
+#define DEF_INSTANTIATE_GATHER(OP, DTYPE) \
+ /* gather op takes input and index rank as template argument */ \
+ /* note output rank = input rank - 1 + index rank */ \
+ /* and max rank allowed in tosa_reference is 6 */ \
+ /* so only specific input and index pair is instantiated */ \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE)
+
+#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \
+ if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \
+ { \
+ attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \
+ ASSERT_MEM(attribute); \
+ } \
+ else \
+ { \
+ FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \
+ }
+
+#define INIT_QINFO(QINFO_NAME) \
+ if (auto p = dynamic_cast<Tosa##QINFO_NAME##QuantInfo*>(qinfo_)) \
+ { \
+ qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \
+ ASSERT_MEM(qinfo); \
+ } \
+ else \
+ { \
+ qinfo = nullptr; \
+ }
+
+namespace TosaReference
+{
+
+// Nodes in the graph (e.g., tosa operators) are defined with this base
+// class.
+class GraphNode
+{
+public:
+ GraphNode(const tosa::Op& nodeType, const uint64_t id_);
+ virtual ~GraphNode();
+
+ int addInputName(std::string& name);
+ int addOutputName(std::string& name);
+
+ int addInputTensor(Tensor* tens);
+ int addOutputTensor(Tensor* tens);
+
+ // Validate that the input tensors match properly
+ // in their types, attributes, rank, etc well enough to be
+ // processed.
+ //
+ // This function should be pure virtual (eventually) in order to force
+ // derivative operators to implement the check, but we'll initially
+ // provide a default function so that GraphNode can be instantiated
+ // directly for testing purposes.
+ virtual int checkTensorAttributes();
+
+ // Evalute the node/operator
+ virtual int eval();
+
+ int hasAllInputsReady() const;
+ int hasAllOutputsReady() const;
+
+ int dumpNode(FILE* out);
+ int dumpNode(std::ostream& out);
+
+ int setNodeMarked()
+ {
+ isMarked = true;
+ return 0;
+ }
+
+ int getNodeMarked() const
+ {
+ return isMarked;
+ }
+
+ int clearNodeMarked()
+ {
+ isMarked = false;
+ return 0;
+ }
+
+ int getEvalCount() const
+ {
+ return evalCount;
+ }
+
+ uint64_t getID() const
+ {
+ return nodeId;
+ }
+
+ std::vector<std::string>& getInputNames()
+ {
+ return inputNames;
+ }
+
+ std::vector<std::string>& getOutputNames()
+ {
+ return outputNames;
+ }
+
+ std::vector<Tensor*>& getOutputs()
+ {
+ return outputs;
+ }
+
+ std::vector<Tensor*>& getInputs()
+ {
+ return inputs;
+ }
+
+ int getOnNextNodeList() const
+ {
+ return onNextNodeList;
+ }
+
+ int setOnNextNodeList()
+ {
+ onNextNodeList = true;
+ return 0;
+ }
+
+ int clearOnNextNodeList()
+ {
+ onNextNodeList = false;
+ return 0;
+ }
+
+ tosa::Op getOp() const
+ {
+ return nodeType;
+ }
+
+protected:
+ // Print out a node validation error
+ int printNodeValidationError(const std::string& msg);
+
+ int setRequiredOperands(const int in, const int out)
+ {
+ requiredInputCount = in;
+ requiredOutputCount = out;
+ return 0;
+ }
+
+ int setRequiredRank(const int min, const int max = -1)
+ {
+ if (max == -1)
+ {
+ requiredRankMin = requiredRankMax = min;
+ }
+ else
+ {
+ requiredRankMin = min;
+ requiredRankMax = max;
+ }
+
+ ASSERT_MSG(requiredRankMin <= requiredRankMax,
+ "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin,
+ requiredRankMax);
+
+ return 0;
+ }
+
+ int validateRequiredOperands();
+ int validateRequiredRank(const Tensor* t);
+
+ // Description of the node type (e.g., CONST, CONV2D, etc...)
+ tosa::Op nodeType;
+
+ // A list of input tensor names
+ std::vector<std::string> inputNames;
+
+ // A list of the output tensor names
+ std::vector<std::string> outputNames;
+
+ // A list of the input tensors (after names have been matched up)
+ std::vector<Tensor*> inputs;
+
+ // A list of the output tensors (after names have been matched up)
+ std::vector<Tensor*> outputs;
+
+ // Unique node ID for debugging
+ uint64_t nodeId;
+
+ // Flag used for graph analysis
+ int isMarked;
+
+ // Number of times eval() has been called for this node
+ int evalCount;
+
+ // Flag indicating that this node is ready and is on the
+ // next-node list.
+ int onNextNodeList;
+
+ // Required input/output tensor counts for node validation
+ // -1 means any number is allowed
+ int requiredInputCount;
+ int requiredOutputCount;
+
+ // Required rank ranges for input/output tensors
+ // -1 means n/a
+ int requiredRankMin;
+ int requiredRankMax;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
new file mode 100644
index 0000000..ec2fdc9
--- /dev/null
+++ b/reference_model/src/main.cpp
@@ -0,0 +1,295 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <stdio.h>
+
+#include "flatbuffers/idl.h"
+#include "flatbuffers/util.h"
+#include "model_common.h"
+#include "ops/op_factory.h"
+#include "subgraph_traverser.h"
+#include "tosa_serialization_handler.h"
+#include <Eigen/CXX11/Tensor>
+#include <iostream>
+
+using namespace TosaReference;
+using namespace tosa;
+
+// Global instantiation of configuration and debug objects
+func_config_t g_func_config;
+func_debug_t g_func_debug;
+
+int readInputTensors(SubgraphTraverser& gt);
+int writeFinalTensors(SubgraphTraverser& gt);
+int loadGraph(TosaSerializationHandler& tsh);
+
+int main(int argc, const char** argv)
+{
+ // Initialize configuration and debug subsystems
+ func_model_init_config();
+ func_model_set_default_config(&g_func_config);
+ func_init_debug(&g_func_debug, 0);
+ TosaSerializationHandler tsh;
+
+ if (func_model_parse_cmd_line(&g_func_config, &g_func_debug, argc, argv))
+ {
+ return 1;
+ }
+
+ if (loadGraph(tsh))
+ {
+ SIMPLE_FATAL_ERROR("Unable to load graph");
+ }
+
+ // load json first since it's easier debugging
+ SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh);
+
+ if (main_gt.initializeGraph())
+ {
+ SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\"");
+ }
+
+ if (main_gt.linkTensorsAndNodes())
+ {
+ SIMPLE_FATAL_ERROR("Failed to link tensors and nodes");
+ }
+
+ if (main_gt.validateGraph())
+ {
+ SIMPLE_FATAL_ERROR("Failed to validate graph");
+ }
+
+ if (g_func_config.validate_only)
+ {
+ goto done;
+ }
+
+ if (readInputTensors(main_gt))
+ {
+ SIMPLE_FATAL_ERROR("Unable to read input tensors");
+ }
+
+ if (g_func_config.eval)
+ {
+
+ if (main_gt.evaluateAll())
+ {
+ SIMPLE_FATAL_ERROR("Error evaluating network. Giving up.");
+ }
+
+ // make sure output tensor is evaluated and show its value
+ int num_output_tensors = main_gt.getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const Tensor* ct = main_gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ main_gt.dumpGraph(g_func_debug.func_debug_file);
+ SIMPLE_FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
+
+ if (g_func_config.output_tensors)
+ {
+ if (writeFinalTensors(main_gt))
+ {
+ WARNING("Errors encountered in saving output tensors");
+ }
+ }
+ }
+
+done:
+ func_fini_debug(&g_func_debug);
+ func_model_config_cleanup();
+
+ return 0;
+}
+
+int loadGraph(TosaSerializationHandler& tsh)
+{
+ char graph_fullname[1024];
+
+ snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file);
+
+ if (strlen(graph_fullname) <= 2)
+ {
+ func_model_print_help(stderr);
+ SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file=");
+ }
+
+ const char JSON_EXT[] = ".json";
+ int is_json = 0;
+ {
+ // look for JSON file extension
+ size_t suffix_len = strlen(JSON_EXT);
+ size_t str_len = strlen(graph_fullname);
+
+ if (str_len > suffix_len && strncasecmp(graph_fullname + (str_len - suffix_len), JSON_EXT, suffix_len) == 0)
+ {
+ is_json = 1;
+ }
+ }
+
+ if (is_json)
+ {
+ if (tsh.LoadFileSchema(g_func_config.operator_fbs))
+ {
+ SIMPLE_FATAL_ERROR(
+ "\nJSON file detected. Unable to load TOSA flatbuffer schema from: %s\nCheck -Coperator_fbs=",
+ g_func_config.operator_fbs);
+ }
+
+ if (tsh.LoadFileJson(graph_fullname))
+ {
+ SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
+ graph_fullname);
+ }
+ }
+ else
+ {
+ if (tsh.LoadFileTosaFlatbuffer(graph_fullname))
+ {
+ SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=",
+ graph_fullname);
+ }
+ }
+
+ return 0;
+}
+
+int readInputTensors(SubgraphTraverser& gt)
+{
+ int tensorCount = gt.getNumInputTensors();
+ Tensor* tensor;
+ char filename[1024];
+
+ // assuming filename doesn't have colons(:)
+ std::map<std::string, std::string> input_tensor_map;
+ std::string raw_str(g_func_config.input_tensor);
+ std::string name, npy;
+ bool last_pair = false;
+
+ std::string::size_type pair_start = 0, pair_end, colons_pos;
+ do
+ {
+ pair_end = raw_str.find(',', pair_start);
+ if (pair_end == std::string::npos)
+ last_pair = true;
+
+ colons_pos = raw_str.find(':', pair_start);
+
+ name = raw_str.substr(pair_start, colons_pos - pair_start);
+ npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1);
+
+ // Empty strings can make it to here
+ if (name.length() == 0 || npy.length() == 0)
+ break;
+
+ input_tensor_map[name] = npy;
+
+ pair_start = pair_end + 1; // skip colons
+ } while (!last_pair);
+
+ if ((size_t)tensorCount != input_tensor_map.size())
+ {
+ WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size());
+ return 1;
+ }
+
+ for (auto& tensor_pair : input_tensor_map)
+ {
+ tensor = gt.getInputTensorByName(tensor_pair.first);
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", tensor_pair.first.c_str());
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str());
+
+ DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readFromNpyFile(filename))
+ {
+ WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename);
+ tensor->dumpTensorParams(g_func_debug.func_debug_file);
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ gt.dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int writeFinalTensors(SubgraphTraverser& gt)
+{
+ int tensorCount = gt.getNumOutputTensors();
+ const Tensor* tensor;
+ char filename[1024];
+
+ for (int i = 0; i < tensorCount; i++)
+ {
+ tensor = gt.getOutputTensor(i);
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor[%d]", i);
+ return 1;
+ }
+
+ snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir,
+ g_func_config.output_tensor_prefix, tensor->getName().c_str());
+
+ DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+
+ if (tensor->writeToNpyFile(filename))
+ {
+ WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename);
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/model_common.h b/reference_model/src/model_common.h
new file mode 100644
index 0000000..d6dab6d
--- /dev/null
+++ b/reference_model/src/model_common.h
@@ -0,0 +1,28 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MODEL_COMMON_H
+#define MODEL_COMMON_H
+
+#include <iostream>
+#include <stdio.h>
+
+#include "func_config.h"
+#include "func_debug.h"
+
+extern func_config_t g_func_config;
+extern func_debug_t g_func_debug;
+
+#endif
diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc
new file mode 100644
index 0000000..bca9507
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.cc
@@ -0,0 +1,118 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "activation_funcs.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpClamp<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType min = (InEigenType)attribute->min_fp();
+ InEigenType max = (InEigenType)attribute->max_fp();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ case DType_AINT8:
+ case DType_INT16:
+ {
+ InEigenType min = (InEigenType)attribute->min_int();
+ InEigenType max = (InEigenType)attribute->max_int();
+ this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReluN<Rank, Dtype>::register_fcn()
+{
+
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ {
+ InEigenType N = (InEigenType)attribute->max_fp();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ case DType_INT32:
+ {
+ InEigenType N = (InEigenType)attribute->max_int();
+ this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; };
+ }
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSigmoid<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTanh<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h
new file mode 100644
index 0000000..b051b9d
--- /dev/null
+++ b/reference_model/src/ops/activation_funcs.h
@@ -0,0 +1,101 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_ACTIVATION_FUNCS_H
+#define OPS_ACTIVATION_FUNCS_H
+
+#include "ewise_unary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpClamp : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_CLAMP, id_)
+ {
+ INIT_ATTRIBUTE(Clamp);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaClampAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReluN : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_RELUN, id_)
+ {
+ INIT_ATTRIBUTE(ReluN);
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+
+protected:
+ TosaReluNAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpSigmoid : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpTanh : public UnaryNode<Rank, Dtype>
+{
+public:
+ OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : UnaryNode<Rank, Dtype>(Op_TANH, id_)
+ {
+ register_fcn();
+ }
+ static constexpr int32_t QMin = GetQMin<Dtype>::value;
+ static constexpr int32_t QMax = GetQMax<Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc
new file mode 100644
index 0000000..402e152
--- /dev/null
+++ b/reference_model/src/ops/comparison.cc
@@ -0,0 +1,81 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "comparison.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+int OpEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreater<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpGreaterEqual<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h
new file mode 100644
index 0000000..e75b1a6
--- /dev/null
+++ b/reference_model/src/ops/comparison.h
@@ -0,0 +1,71 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_COMPARISON_H
+#define OPS_COMPARISON_H
+
+#include "ewise_binary.h"
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+template <int Rank, DType Dtype>
+class OpGreaterEqual : public BinaryNode<Rank, Dtype, DType_BOOL>
+{
+public:
+ OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_)
+ {
+ register_fcn();
+ }
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ virtual int register_fcn();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc
new file mode 100644
index 0000000..9d5db40
--- /dev/null
+++ b/reference_model/src/ops/control_flow.cc
@@ -0,0 +1,353 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "control_flow.h"
+#include "subgraph_traverser.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ tsh = tsh_;
+}
+
+OpControlFlow::~OpControlFlow()
+{}
+
+int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs)
+{
+ std::string block_name = block->GetName();
+
+ DEBUG_MED(OP, "Evaluating block %s", block_name.c_str());
+
+ SubgraphTraverser gt(block, tsh);
+
+ if (gt.initializeGraph())
+ {
+ FATAL_ERROR("Unable to initialize graph traverser for block %s", block_name.c_str());
+ }
+
+ if (gt.linkTensorsAndNodes())
+ {
+ FATAL_ERROR("Failed to link tensors and nodes for block %s", block_name.c_str());
+ }
+
+ if (gt.validateGraph())
+ {
+ FATAL_ERROR("Failed to validate subgraph for block %s", block_name.c_str());
+ }
+
+ int num_input_tensors = gt.getNumInputTensors();
+ int num_output_tensors = gt.getNumOutputTensors();
+
+ for (size_t i = 0; i < block_inputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str());
+ }
+ for (size_t i = 0; i < block_outputs.size(); i++)
+ {
+ DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str());
+ }
+
+ ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(),
+ "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(),
+ block_inputs.size(), num_input_tensors);
+ ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(),
+ "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(),
+ block_outputs.size(), num_output_tensors);
+
+ // set graph traverser's input = basic block's input
+ for (int i = 0; i < num_input_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getInputTensor(i);
+ ASSERT_MSG(!tensor->is_allocated(), "block %s input tensors are unexpectedly initialized before",
+ block_name.c_str());
+
+ if (tensor->allocate())
+ {
+ WARNING("Fail to allocate tensor %s", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->copyValueFrom(block_inputs[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(),
+ tensor->getName().c_str());
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ gt.addToNextNodeList(gn);
+ }
+ }
+ }
+
+ if (gt.evaluateAll())
+ {
+ FATAL_ERROR("Error evaluating network. Giving up.");
+ }
+
+ // make sure output tensor is evaluated and show its value
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const TosaReference::Tensor* ct = gt.getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ gt.dumpGraph(g_func_debug.func_debug_file);
+ FATAL_ERROR("SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.",
+ block_name.c_str());
+ }
+
+ // set basic block's output = subgraph_traverser's output
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ TosaReference::Tensor* tensor = gt.getOutputTensor(i);
+ ASSERT_MSG(tensor->is_allocated(), "tensor %s is not allocated", tensor->getName().c_str());
+
+ if (block_outputs[i]->copyValueFrom(tensor))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str());
+ return 1;
+ }
+ }
+ return 0;
+}
+
+OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_COND_IF, id_)
+{
+ INIT_ATTRIBUTE(CondIf);
+}
+
+OpCondIf::~OpCondIf()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpCondIf::checkTensorAttributes()
+{
+ if (getInputs().size() < 1)
+ {
+ WARNING("OpCondIf: must have at least 1 operand");
+ return 1;
+ }
+
+ if (inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0)
+ {
+ WARNING("OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()],
+ inputs[0]->getRank());
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]);
+ ASSERT_MEM(cond);
+
+ then_block = tsh->GetBlockByName(attribute->then_branch());
+ else_block = tsh->GetBlockByName(attribute->else_branch());
+
+ if (!then_block)
+ {
+ WARNING("OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str());
+ return 1;
+ }
+
+ if (!else_block)
+ {
+ WARNING("OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str());
+ return 1;
+ }
+
+ return 0;
+}
+
+int OpCondIf::eval()
+{
+ bool cond_val = cond->getTensor()(0);
+ std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end());
+
+ if (cond_val)
+ {
+ if (evalBlock(then_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str());
+ return 1;
+ }
+ }
+ else
+ {
+ if (evalBlock(else_block, block_inputs, getOutputs()))
+ {
+ WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str());
+ return 1;
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_)
+ : OpControlFlow(tsh_, Op_WHILE_LOOP, id_)
+{
+ INIT_ATTRIBUTE(WhileLoop);
+}
+
+OpWhileLoop::~OpWhileLoop()
+{
+ if (attribute)
+ delete attribute;
+}
+
+int OpWhileLoop::checkTensorAttributes()
+{
+ if (getInputs().size() <= 0)
+ {
+ WARNING("OpWhileLoop: must have at least 1 operands");
+ return 1;
+ }
+
+ if (getInputs().size() != getOutputs().size())
+ {
+ WARNING("OpWhileLoop: inputs and outputs size must match");
+ return 1;
+ }
+
+ cond_block = tsh->GetBlockByName(attribute->cond_branch());
+ body_block = tsh->GetBlockByName(attribute->body_branch());
+
+ if (!cond_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+
+ if (!body_block)
+ {
+ WARNING("OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+ if (cond_block->GetOutputs().size() != 1)
+ {
+ WARNING("OpWhileLoop: invalid cond_block output size %lu", cond_block->GetOutputs().size());
+ return 1;
+ }
+
+ TosaSerializationTensor* cond_output_tensor = cond_block->GetTensorByName(cond_block->GetOutputs()[0]);
+
+ if (!cond_output_tensor)
+ {
+ WARNING("OpWhileLoop: fail to resolve cond_block's output tensor %s", cond_block->GetOutputs()[0].c_str());
+ return 1;
+ }
+
+ if (cond_output_tensor->GetDtype() != DType_BOOL)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output tensor data type %s",
+ EnumNamesDType()[cond_output_tensor->GetDtype()]);
+ return 1;
+ }
+ if (cond_output_tensor->GetShape().size() != 0)
+ {
+ WARNING("OpWhileLoop: invalid cond_block's output rank %lu", cond_output_tensor->GetShape().size());
+ return 1;
+ }
+
+ return 0;
+}
+
+int OpWhileLoop::eval()
+{
+
+ TosaReference::Tensor0<bool> cond_output_ctensor(
+ std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
+ std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false);
+
+ cond_output_ctensor.allocate();
+ std::vector<TosaReference::Tensor*> cond_block_outputs;
+ cond_block_outputs.push_back(&cond_output_ctensor);
+
+ size_t num_input_output = getInputs().size();
+ size_t eval_count = 0;
+
+ while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
+ {
+ if (evalBlock(cond_block, getInputs(), cond_block_outputs))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
+ return 1;
+ }
+ bool cond_val = cond_output_ctensor.getTensor()(0);
+ DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);
+
+ if (cond_val)
+ {
+ if (evalBlock(body_block, getInputs(), getOutputs()))
+ {
+ WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
+ return 1;
+ }
+
+ // assigning output tensors value back to input tensors value for next iteration
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
+ getInputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ }
+ else
+ {
+ // in last iteration or the case it never evaluates body block
+ // assign input tensors value to output tensors
+ for (size_t i = 0; i < num_input_output; i++)
+ {
+ if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
+ {
+ WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
+ getOutputs()[i]->getName().c_str());
+ return 1;
+ }
+ }
+ break;
+ }
+ }
+
+ return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
new file mode 100644
index 0000000..14c11bc
--- /dev/null
+++ b/reference_model/src/ops/control_flow.h
@@ -0,0 +1,72 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CONTROL_FLOW_H
+#define OPS_CONTROL_FLOW_H
+
+#include "graph_node.h"
+
+#define MAX_WHILE_LOOP_ITERATION 10000
+
+namespace TosaReference
+{
+class OpControlFlow : public GraphNode
+{
+public:
+ OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
+ ~OpControlFlow();
+
+ virtual int evalBlock(TosaSerializationBasicBlock* block,
+ std::vector<TosaReference::Tensor*>& block_inputs,
+ std::vector<TosaReference::Tensor*>& block_outputs);
+
+protected:
+ TosaSerializationHandler* tsh;
+};
+
+class OpCondIf : public OpControlFlow
+{
+public:
+ OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpCondIf();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaCondIfAttribute* attribute;
+ TosaReference::Tensor0<bool>* cond;
+ TosaSerializationBasicBlock* then_block;
+ TosaSerializationBasicBlock* else_block;
+};
+
+class OpWhileLoop : public OpControlFlow
+{
+public:
+ OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
+ virtual ~OpWhileLoop();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+protected:
+ TosaWhileLoopAttribute* attribute;
+ TosaSerializationBasicBlock* cond_block;
+ TosaSerializationBasicBlock* body_block;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
new file mode 100644
index 0000000..5c4f29b
--- /dev/null
+++ b/reference_model/src/ops/custom.cc
@@ -0,0 +1,40 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "custom.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpCustom::OpCustom(uint64_t id_)
+ : GraphNode(Op_CUSTOM, id_)
+{}
+
+OpCustom::~OpCustom()
+{}
+
+int OpCustom::checkTensorAttributes()
+{
+ return 0;
+}
+
+int OpCustom::eval()
+{
+ FATAL_ERROR_NODE("not supported yet");
+
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
new file mode 100644
index 0000000..b1085a5
--- /dev/null
+++ b/reference_model/src/ops/custom.h
@@ -0,0 +1,38 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_CUSTOM_H
+#define OPS_CUSTOM_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+class OpCustom : public GraphNode
+{
+public:
+ OpCustom(uint64_t id_);
+ virtual ~OpCustom();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
new file mode 100644
index 0000000..32029b9
--- /dev/null
+++ b/reference_model/src/ops/data_layout.cc
@@ -0,0 +1,644 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_layout.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONCAT, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpConcat<Rank, Dtype>::~OpConcat()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types and rank
+ // inputs[0] and inputs[1] should also match type and rank
+ if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Concat operator input ranks and types must match");
+ return 1;
+ }
+
+ lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size())
+ {
+ printNodeValidationError("Axis is beyond input tensor rank");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpConcat<Rank, Dtype>::eval()
+{
+
+ int32_t reversed_axis = Rank - 1 - attribute->axis();
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ reverser[d] = Rank - 1 - d;
+ }
+
+ TIn lhs_reversed = lhs->getTensor().shuffle(reverser);
+ TIn rhs_reversed = rhs->getTensor().shuffle(reverser);
+
+ TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis);
+ out->getTensor() = reversed_result.shuffle(reverser);
+ // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_PAD, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ INIT_QINFO(Pad);
+}
+
+template <int Rank, DType Dtype>
+OpPad<Rank, Dtype>::~OpPad()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type and rank");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings =
+ dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]);
+
+ for (int i = 0; i < Rank; i++)
+ {
+ paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1));
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpPad<Rank, Dtype>::eval()
+{
+ InEigenType pad_value = 0;
+ if (this->qinfo)
+ {
+ pad_value = (InEigenType)this->qinfo->input_zp();
+ }
+
+ this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value);
+
+ return GraphNode::eval();
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESHAPE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Reshape);
+}
+
+template <int InRank, int OutRank, DType Dtype>
+OpReshape<InRank, OutRank, Dtype>::~OpReshape()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes()
+{
+ uint32_t minusOneCount = 0;
+
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpReshape: Input and output types must match");
+ return 1;
+ }
+
+ for (uint32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] == -1)
+ {
+ minusOneCount++;
+ }
+ }
+
+ if (minusOneCount > 1)
+ {
+ printNodeValidationError("OpReshape: new shape has more than one -1 dimension");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int InRank, int OutRank, DType Dtype>
+int OpReshape<InRank, OutRank, Dtype>::eval()
+{
+ uint32_t remainingSize = in->getElementCount();
+
+ // If there is a -1 dimension, find the remainder in one pass over the output shape
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ if (attribute->shape()[d] != -1)
+ {
+ remainingSize = remainingSize / attribute->shape()[d];
+ }
+ }
+
+ for (int32_t d = 0; d < OutRank; d++)
+ {
+ array_shape[d] = attribute->shape()[OutRank - 1 - d];
+ out_reverser[d] = OutRank - 1 - d;
+
+ // Jam in the remainder here
+ if (array_shape[d] == -1)
+ {
+ array_shape[d] = remainingSize;
+ }
+ }
+
+ for (int32_t d = 0; d < InRank; d++)
+ {
+ in_reverser[d] = InRank - 1 - d;
+ }
+
+ // Eigen Tensor is col-major, and we're referencing row-major result
+ // need to reverse it to row-major before reshape, and perform another reverse afterward
+
+ // input tensor rank 0 can't do .shuffle(), need to be handled otherwise
+ TIn in_reversed;
+ if (InRank > 1)
+ {
+ in_reversed = in->getTensor().shuffle(in_reverser);
+ }
+ else
+ {
+ in_reversed = in->getTensor();
+ }
+
+ TOut in_reshaped = in_reversed.reshape(array_shape);
+
+ // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise
+ if (OutRank > 1)
+ {
+ out->getTensor() = in_reshaped.shuffle(out_reverser);
+ }
+ else
+ {
+ out->getTensor() = in_reshaped;
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_REVERSE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpReverse<Rank, Dtype>::~OpReverse()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankTypeShape(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank/type/shape");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+ printNodeValidationError("Reverse axis must between [0, input_rank - 1]");
+ return 1;
+ }
+
+ // transform list of axis into true or false list
+ // e.g. rank=4, axis=[1,2], reverse array would be [false, true, true, false]
+ for (int i = 0; i < Rank; i++)
+ {
+ reverse_array[i] = false;
+ }
+ reverse_array[attribute->axis()] = true;
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReverse<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().reverse(reverse_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SLICE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Slice);
+}
+
+template <int Rank, DType Dtype>
+OpSlice<Rank, Dtype>::~OpSlice()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ for (size_t i = 0; i < attribute->begin().size(); i++)
+ {
+ begin_array[i] = attribute->begin()[i];
+ }
+
+ for (size_t i = 0; i < attribute->size().size(); i++)
+ {
+ if (attribute->size()[i] != 0)
+ {
+ size_array[i] = attribute->size()[i];
+ }
+ else
+ {
+ // Tensorflow assigns a zero size to dimensions that are kept
+ // Eigen expects size to be the full size of the dimension
+ size_array[i] = in->getTensor().dimension(0);
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSlice<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor().slice(begin_array, size_array);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TILE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Tile);
+}
+
+template <int Rank, DType Dtype>
+OpTileBase<Rank, Dtype>::~OpTileBase()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpTileBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same ranks and types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank or type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (attribute->multiples().size() != Rank)
+ {
+ printNodeValidationError("1D list 'multiples' must have size equal to input rank");
+ return 1;
+ }
+
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d])
+ {
+ printNodeValidationError("unexpected output shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTile<Rank, Dtype>::eval()
+{
+ // primary template shouldn't be called
+ FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]);
+}
+
+template <DType Dtype>
+int OpTile<1, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ this->out->getTensor()(od0) = this->in->getTensor()(id0);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<2, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<3, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2);
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpTile<4, Dtype>::eval()
+{
+ for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++)
+ {
+ int32_t id0 = od0 % this->in->getShape()[0];
+ for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++)
+ {
+ int32_t id1 = od1 % this->in->getShape()[1];
+ for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++)
+ {
+ int32_t id2 = od2 % this->in->getShape()[2];
+ for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++)
+ {
+ int32_t id3 = od3 % this->in->getShape()[3];
+ this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3);
+ }
+ }
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TRANSPOSE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpTranspose<Rank, Dtype>::~OpTranspose()
+{}
+
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ if (inputs[0]->getElementCount() != outputs[0]->getElementCount())
+ {
+ printNodeValidationError("Failure to match input and output total element count");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+ perm_tensor = dynamic_cast<TosaReference::TensorTemplate<ETensor1<int32_t>>*>(inputs[1]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpTranspose<Rank, Dtype>::eval()
+{
+ for (int32_t d = 0; d < Rank; d++)
+ {
+ perm_array[d] = this->perm_tensor->getTensor().data()[d];
+ }
+
+ out->getTensor() = in->getTensor().shuffle(perm_array);
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32)
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL)
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+
+DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT);
+DEF_INSTANTIATE_RESHAPE(OpReshape, AINT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT8);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT16);
+DEF_INSTANTIATE_RESHAPE(OpReshape, INT32);
+DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h
new file mode 100644
index 0000000..100bd6b
--- /dev/null
+++ b/reference_model/src/ops/data_layout.h
@@ -0,0 +1,216 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_LAYOUT_H
+#define OPS_DATA_LAYOUT_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpConcat : public GraphNode
+{
+public:
+ OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConcat();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> reverser;
+ TosaReference::TensorTemplate<TIn>* lhs;
+ TosaReference::TensorTemplate<TIn>* rhs;
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpPad : public GraphNode
+{
+public:
+ OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpPad();
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaPadQuantInfo* qinfo;
+};
+
+template <int InRank, int OutRank, DType Dtype>
+class OpReshape : public GraphNode
+{
+public:
+ OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReshape();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ Eigen::array<Eigen::Index, OutRank> array_shape;
+ Eigen::array<Eigen::Index, InRank> in_reverser;
+ Eigen::array<Eigen::Index, OutRank> out_reverser;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReshapeAttribute* attribute;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpReverse : public GraphNode
+{
+public:
+ OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpReverse();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ Eigen::array<bool, Rank> reverse_array;
+};
+
+template <int Rank, DType Dtype>
+class OpSlice : public GraphNode
+{
+public:
+ OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSlice();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaSliceAttribute* attribute;
+ Eigen::array<Eigen::Index, Rank> begin_array;
+ Eigen::array<Eigen::Index, Rank> size_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpTileBase : public GraphNode
+{
+public:
+ OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTileBase();
+
+ virtual int checkTensorAttributes();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaTileAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+// primary template for op tile
+template <int Rank, DType Dtype>
+class OpTile : public OpTileBase<Rank, Dtype>
+{
+public:
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+
+protected:
+ virtual int eval();
+};
+
+// partial specialization for specific rank
+#define DEF_OP_TILE_RANK(N) \
+ template <DType Dtype> \
+ class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \
+ { \
+ public: \
+ OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \
+ {} \
+ \
+ protected: \
+ virtual int eval(); \
+ };
+
+DEF_OP_TILE_RANK(1)
+DEF_OP_TILE_RANK(2)
+DEF_OP_TILE_RANK(3)
+DEF_OP_TILE_RANK(4)
+
+#undef DEF_OP_TILE_RANK
+
+template <int Rank, DType Dtype>
+class OpTranspose : public GraphNode
+{
+public:
+ OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTranspose();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, Rank> perm_array;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<ETensor1<int32_t>>* perm_tensor;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc
new file mode 100644
index 0000000..2ee4935
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.cc
@@ -0,0 +1,172 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "data_nodes.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+OpConst::OpConst(uint64_t id_)
+ : GraphNode(Op_CONST, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpConst::~OpConst()
+{}
+
+int OpConst::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpConst::eval()
+{
+ // Evaluation is trivial for constants
+ return GraphNode::eval();
+}
+
+OpPlaceholder::OpPlaceholder(uint64_t id_)
+ : GraphNode(Op_PLACEHOLDER, id_)
+{
+ setRequiredOperands(0, 1);
+}
+
+OpPlaceholder::~OpPlaceholder()
+{}
+
+int OpPlaceholder::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ return 0;
+}
+
+int OpPlaceholder::eval()
+{
+ // Evaluation is trivial for placeholders
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITY, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentity<Rank, Dtype>::~OpIdentity()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (in->matchRankTypeShape(*out))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentity<Rank, Dtype>::eval()
+{
+ out->getTensor() = in->getTensor();
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_IDENTITYN, id_)
+{
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpIdentityN<Rank, Dtype>::~OpIdentityN()
+{}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::checkTensorAttributes()
+{
+
+ if (inputs.size() != outputs.size())
+ {
+ printNodeValidationError("Input and output tensor list lengths are not equal");
+ return 1;
+ }
+
+ for (size_t i = 0; i < inputs.size(); i++)
+ {
+ ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i]));
+ outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i]));
+
+ if (ins[i]->matchRankTypeShape(*outs[i]))
+ {
+ printNodeValidationError("Input and output tensor rank, type, or shape do not match");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpIdentityN<Rank, Dtype>::eval()
+{
+ for (size_t i = 0; i < ins.size(); i++)
+ {
+ outs[i]->getTensor() = ins[i]->getTensor();
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+// note OpConst and OpPlaceholder are not templated
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h
new file mode 100644
index 0000000..bec4669
--- /dev/null
+++ b/reference_model/src/ops/data_nodes.h
@@ -0,0 +1,86 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_DATA_NODES_H
+#define OPS_DATA_NODES_H
+
+#include "graph_node.h"
+
+namespace TosaReference
+{
+
+class OpConst : public GraphNode
+{
+public:
+ OpConst(uint64_t id_);
+ virtual ~OpConst();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+class OpPlaceholder : public GraphNode
+{
+public:
+ OpPlaceholder(uint64_t id_);
+ virtual ~OpPlaceholder();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpIdentity : public GraphNode
+{
+public:
+ OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentity();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <int Rank, DType Dtype>
+class OpIdentityN : public GraphNode
+{
+public:
+ OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpIdentityN();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::vector<TosaReference::TensorTemplate<TIn>*> ins;
+ std::vector<TosaReference::TensorTemplate<TOut>*> outs;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc
new file mode 100644
index 0000000..4d4f8b9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.cc
@@ -0,0 +1,586 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_binary.h"
+#include "arith_util.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+
+ a_rank = b_rank = max_input_rank = -1;
+ a = b = nullptr;
+ a_rank0 = b_rank0 = nullptr;
+ result = nullptr;
+
+ fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a_rank = inputs[0]->getRank();
+ b_rank = inputs[1]->getRank();
+ if (a_rank != 0 && b_rank != 0 && a_rank != b_rank)
+ {
+ printNodeValidationError("Binary operator input ranks must match");
+ return 1;
+ }
+
+ max_input_rank = a_rank >= b_rank ? a_rank : b_rank;
+
+ // A & B must be the same types
+ if (inputs[0]->matchType(*inputs[1]))
+ {
+ printNodeValidationError("Binary operator input types must match");
+ return 1;
+ }
+
+ // Result's geometry must match, but the type may be wider
+ if (outputs[0]->getRank() != max_input_rank)
+ {
+ printNodeValidationError("Binary operator input and output genometry must match");
+ return 1;
+ }
+
+ if (a_rank == max_input_rank)
+ {
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ }
+ else
+ {
+ a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ }
+ else
+ {
+ b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]);
+ }
+
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ // either a or b can be rank0
+ // a_rank0 and b_rank0 can't be valid at the same time.
+ // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0'
+ ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast()
+{
+ auto output_shape = result->getTensor().dimensions();
+
+ std::vector<int> a_shape, b_shape;
+
+ if (a_rank == max_input_rank)
+ {
+ a_shape = a->getShape();
+ }
+ else
+ {
+ a_shape.assign(max_input_rank, 1);
+ }
+
+ if (b_rank == max_input_rank)
+ {
+ b_shape = b->getShape();
+ }
+ else
+ {
+ b_shape.assign(max_input_rank, 1);
+ }
+
+ for (int i = 0; i < max_input_rank; i++)
+ {
+ if (a_shape[i] != output_shape[i] && a_shape[i] == 1)
+ {
+ bcast_a[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_a[i] = 1;
+ }
+ if (b_shape[i] != output_shape[i] && b_shape[i] == 1)
+ {
+ bcast_b[i] = output_shape[i];
+ }
+ else
+ {
+ bcast_b[i] = 1;
+ }
+ }
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int BinaryNode<Rank, InDtype, OutDtype>::eval()
+{
+ this->broadcast();
+
+ Eigen::array<int, Rank> reshaper;
+ reshaper.fill(1);
+ TIn ia, ib;
+
+ if (this->a_rank == this->max_input_rank)
+ {
+ ia = this->a->getTensor().broadcast(this->bcast_a);
+ }
+ else
+ {
+ ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a);
+ }
+
+ if (this->b_rank == this->max_input_rank)
+ {
+ ib = this->b->getTensor().broadcast(this->bcast_b);
+ }
+ else
+ {
+ ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b);
+ }
+
+ this->result->getTensor() = ia.binaryExpr(ib, this->fcn);
+
+ return GraphNode::eval();
+}
+
+// still need to partial specialize this, or Eigen will throw static assertion
+template <DType InDtype, DType OutDtype>
+int BinaryNode<0, InDtype, OutDtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAdd<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArithmeticRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t sign = a & (1 << (num_bits - 1));
+ uint32_t ones_mask = ONES_MASK(b) << (num_bits - b);
+ if (sign)
+ return ones_mask | (a >> b);
+ else
+ return (~ones_mask) & (a >> b);
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalAnd<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalLeftShift<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_INT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalRightShift<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits = 0;
+ switch (Dtype)
+ {
+ case DType_INT8:
+ num_bits = 8;
+ break;
+ case DType_INT16:
+ num_bits = 16;
+ break;
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType {
+ uint32_t mask = ONES_MASK(num_bits) >> b;
+ return (a >> b) & mask;
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalOr<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalXor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMaximum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpMinimum<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? a : b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpMul<Rank, InDtype, OutDtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; };
+ break;
+ case DType_INT8:
+ case DType_INT16:
+ this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType {
+ OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs;
+
+ OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin));
+
+ return clamped_output;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpPow<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSub<Rank, Dtype>::register_fcn()
+{
+ switch (InDtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank>
+OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_TABLE, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank>
+OpTable<Rank>::~OpTable()
+{}
+
+template <int Rank>
+int OpTable<Rank>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getRank() != 1 || inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16)
+ {
+ FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && table && out);
+
+ return 0;
+}
+
+template <int Rank>
+int OpTable<Rank>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType {
+ // 1. make sure input is int16 range
+ int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax);
+
+ // 2. calculate index and interpolation fraction
+ int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1));
+ index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index
+ int32_t frac = (input_truncated)&0x7F; // 7-bit fraction
+
+ // 3. interpolate, generate 16.7 (23-bit) output
+ int32_t base = this->table->getTensor()(index);
+ int32_t next = this->table->getTensor()(index + 1);
+ int32_t value = (base << 7) + (next - base) * frac;
+
+ return value;
+ });
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+
+DEF_INSTANTIATE_ONE_RANK_0_6(OpTable);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL);
diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h
new file mode 100644
index 0000000..00fb3d9
--- /dev/null
+++ b/reference_model/src/ops/ewise_binary.h
@@ -0,0 +1,195 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_BINARY_H
+#define OPS_EWISE_BINARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// class BinaryNodeBase: hold common functions of all the binary nodes
+// when an binary op is created, the virtual OpXXX::register_fcn() will be called
+// and 'fcn' will be register with lambda function which has two inputs
+// class BinaryNode: the level of indirection to partially specialize template for rank 0
+// eval() from toplevel called should call the .binaryExpr(dims, fcn) here
+// this needs to be partially specialize or
+// compiler will statically fail when trying to broadcast rank0 tensor
+// class OpXXX: implement per-element lambda function based on different data type
+// unlike BinaryNode, this doesn't need to be partially specialized
+
+// Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.)
+// which might be faster since it could be implemented with SIMD instructions
+// the way of registering lambda + .binaryExpr() might sacrifice performance here
+// but it can avoid partially specialization for combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}
+// needs to revisit if performance becomes a bottleneck here
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNodeBase : public GraphNode
+{
+public:
+ BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
+ virtual ~BinaryNodeBase();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() = 0;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ int broadcast();
+
+protected:
+ std::function<OutEigenType(InEigenType, InEigenType)> fcn;
+ Eigen::array<int, Rank> bcast_a;
+ Eigen::array<int, Rank> bcast_b;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* a_rank0;
+ TosaReference::TensorTemplate<ETensor0<InEigenType>>* b_rank0;
+ TosaReference::TensorTemplate<TOut>* result;
+ int a_rank;
+ int b_rank;
+ int max_input_rank;
+};
+
+// primary class
+template <int Rank, DType InDtype, DType OutDtype>
+class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+};
+
+// partial specialization for rank 0
+template <DType InDtype, DType OutDtype>
+class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
+{
+public:
+ BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
+ : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
+ {}
+ virtual ~BinaryNode()
+ {}
+
+ virtual int eval();
+};
+
+#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public BinaryNode<Rank, Dtype, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr DType InDtype = Dtype; \
+ static constexpr DType OutDtype = Dtype; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME) \
+ template <int Rank, DType InDtype, DType OutDtype> \
+ class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value; \
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value; \
+ using InEigenType = typename GetEigenType<InDtype>::type; \
+ using OutEigenType = typename GetEigenType<OutDtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
+DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
+DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)
+
+#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
+#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE
+
+template <int Rank>
+class OpTable : public GraphNode
+{
+public:
+ OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTable();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType InDtype = DType_INT16;
+ static constexpr DType TableDtype = DType_INT16;
+ static constexpr DType OutDtype = DType_INT32;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using TableEigenType = typename GetEigenType<TableDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TTable = Eigen::Tensor<TableEigenType, 1>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+ static constexpr int32_t IntegerBits = 9;
+ static constexpr int32_t FractionBits = 7;
+ static constexpr int32_t NumTableEntries = (1 << IntegerBits);
+ static constexpr int32_t QInMin = GetQMin<InDtype>::value;
+ static constexpr int32_t QInMax = GetQMax<InDtype>::value;
+ static constexpr int32_t QOutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QOutMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TTable>* table;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc
new file mode 100644
index 0000000..eded0d7
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.cc
@@ -0,0 +1,115 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_ternary.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_SELECT, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType Dtype>
+OpSelectBase<Rank, Dtype>::~OpSelectBase()
+{}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) ||
+ validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) ||
+ inputs[2]->matchRankType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output rank and type");
+ return 1;
+ }
+
+ cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]);
+ then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+ else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSelectBase<Rank, Dtype>::eval()
+{
+ FATAL_ERROR_NODE("shouldn't be called");
+}
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::broadcast()
+{
+ std::vector<int> cond_shape = this->cond->getShape();
+ std::vector<int> then_shape = this->then_val->getShape();
+ std::vector<int> else_shape = this->else_val->getShape();
+ std::vector<int> out_shape = this->out->getShape();
+
+ for (int i = 0; i < Rank; i++)
+ {
+ this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1;
+ this->bcast_then[i] = (then_shape[i] == 1) ? std::max(cond_shape[i], else_shape[i]) : 1;
+ this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1;
+ ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed");
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpSelect<Rank, Dtype>::eval()
+{
+ this->broadcast();
+ this->out->getTensor() = this->cond->getTensor()
+ .broadcast(this->bcast_cond)
+ .select(this->then_val->getTensor().broadcast(this->bcast_then),
+ this->else_val->getTensor().broadcast(this->bcast_else));
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+int OpSelect<0, Dtype>::eval()
+{
+ this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h
new file mode 100644
index 0000000..b354247
--- /dev/null
+++ b/reference_model/src/ops/ewise_ternary.h
@@ -0,0 +1,83 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TERNARY_H
+#define OPS_TERNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// The Ternary Select op has the following operands:
+// 1. Cond: rank N, type=bool
+// 2. Then_val: Rank N, type=<V>
+// 3. Else_val: Rank N, type=<V>
+// 4. Result: Rank N, type=<V>
+// Cond, Then_val, Else_val need to be mutually-broadcastable
+template <int Rank, DType Dtype>
+class OpSelectBase : public GraphNode
+{
+public:
+ OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpSelectBase();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using CondEigenType = typename GetEigenType<DType_BOOL>::type;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using TCond = Eigen::Tensor<CondEigenType, Rank>;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+
+protected:
+ TosaReference::TensorTemplate<TCond>* cond;
+ Eigen::array<int, Rank> bcast_cond;
+ Eigen::array<int, Rank> bcast_then;
+ Eigen::array<int, Rank> bcast_else;
+ TosaReference::TensorTemplate<TIn>* then_val;
+ TosaReference::TensorTemplate<TIn>* else_val;
+ TosaReference::TensorTemplate<TIn>* out;
+};
+
+// primary class
+template <int Rank, DType Dtype>
+class OpSelect : public OpSelectBase<Rank, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+ int broadcast();
+
+ using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
+};
+
+// partial specialization for rank 0
+template <DType Dtype>
+class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
+{
+public:
+ OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_)
+ {}
+ virtual int eval();
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc
new file mode 100644
index 0000000..d7bddc0
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.cc
@@ -0,0 +1,302 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "ewise_unary.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); };
+}
+
+template <int Rank, DType Dtype>
+UnaryNode<Rank, Dtype>::~UnaryNode()
+{}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("UnaryNode: input and output rank must match");
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(a && result);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int UnaryNode<Rank, Dtype>::eval()
+{
+ this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpAbs<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? a : (-a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpBitwiseNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_AINT8:
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ~a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpCeil<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpClz<Rank, Dtype>::register_fcn()
+{
+ int32_t num_bits;
+ switch (Dtype)
+ {
+ case DType_INT32:
+ num_bits = 32;
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ this->fcn = [num_bits](int32_t a) -> int32_t {
+ int32_t leading_zeros = 0;
+ for (int bit = num_bits - 1; bit >= 0; bit--)
+ {
+ if (((a >> bit) & 0x1) == 0)
+ {
+ leading_zeros++;
+ }
+ else
+ {
+ break;
+ }
+ }
+ return leading_zeros;
+ };
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpExp<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpFloor<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLog<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpLogicalNot<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_BOOL:
+ this->fcn = [](InEigenType a) -> OutEigenType { return !a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpNegate<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_INT16:
+ case DType_INT32:
+ this->fcn = [](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a);
+ return result;
+ };
+ break;
+ case DType_AINT8:
+ ASSERT(this->qinfo);
+ this->fcn = [this](InEigenType a) -> OutEigenType {
+ InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp();
+ return result;
+ };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReciprocal<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpRsqrt<Rank, Dtype>::register_fcn()
+{
+ switch (Dtype)
+ {
+ case DType_FLOAT:
+ this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); };
+ break;
+ default:
+ FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]);
+ }
+
+ return 0;
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h
new file mode 100644
index 0000000..0db3cfb
--- /dev/null
+++ b/reference_model/src/ops/ewise_unary.h
@@ -0,0 +1,102 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_EWISE_UNARY_H
+#define OPS_EWISE_UNARY_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType Dtype>
+class UnaryNode : public GraphNode
+{
+public:
+ UnaryNode(const Op& nodeType, const uint64_t id_);
+ virtual ~UnaryNode();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+ virtual int register_fcn() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ std::function<OutEigenType(InEigenType)> fcn;
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TOut>* result;
+};
+
+#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ };
+
+#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \
+ template <int Rank, DType Dtype> \
+ class Op##Opname : public UnaryNode<Rank, Dtype> \
+ { \
+ public: \
+ Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \
+ : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \
+ { \
+ INIT_QINFO(Unary); \
+ register_fcn(); \
+ } \
+ static constexpr int32_t QMin = GetQMin<Dtype>::value; \
+ static constexpr int32_t QMax = GetQMax<Dtype>::value; \
+ using InEigenType = typename GetEigenType<Dtype>::type; \
+ using OutEigenType = typename GetEigenType<Dtype>::type; \
+ virtual int register_fcn(); \
+ \
+ protected: \
+ TosaUnaryQuantInfo* qinfo; \
+ };
+
+DEF_TEMPLATE_UNARY_OP(Abs, ABS)
+DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT)
+DEF_TEMPLATE_UNARY_OP(Ceil, CEIL)
+DEF_TEMPLATE_UNARY_OP(Clz, CLZ)
+DEF_TEMPLATE_UNARY_OP(Exp, EXP)
+DEF_TEMPLATE_UNARY_OP(Floor, FLOOR)
+DEF_TEMPLATE_UNARY_OP(Log, LOG)
+DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT)
+DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE)
+DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL)
+DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT)
+
+#undef DEF_TEMPLATE_UNARY_OP
+#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc
new file mode 100644
index 0000000..d3352ce
--- /dev/null
+++ b/reference_model/src/ops/image.cc
@@ -0,0 +1,169 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "image.h"
+#include "arith_util.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESIZE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4, 4);
+
+ INIT_ATTRIBUTE(Resize);
+}
+
+template <DType InDtype, DType OutDtype>
+OpResize<InDtype, OutDtype>::~OpResize()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ return 1;
+
+ output_size = this->attribute->output_size();
+ stride = this->attribute->stride();
+ offset = this->attribute->offset();
+ shift = this->attribute->shift();
+ mode = this->attribute->mode();
+
+ int output_height = outputs[0]->getShape()[1];
+ int output_width = outputs[0]->getShape()[2];
+
+ if (this->mode == ResizeMode_BILINEAR)
+ {
+ if (OutDtype != DType_INT32 && OutDtype != DType_INT48)
+ {
+ printNodeValidationError("OpResize: invalid data type for BILINEAR");
+ return 1;
+ }
+ }
+ else
+ {
+ if (OutDtype != DType_INT8 && OutDtype != DType_INT16)
+ {
+ printNodeValidationError("OpResize: invalid data type for NEAREST");
+ return 1;
+ }
+ }
+
+ if (output_size[0] != output_height || output_size[1] != output_width)
+ {
+ printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]");
+ return 1;
+ }
+
+ if (shift < 1 || shift > 11)
+ {
+ printNodeValidationError("OpResize: attribute shift should be within [1, 11]");
+ return 1;
+ }
+
+ if (stride[0] <= 0 || stride[1] <= 0)
+ {
+ printNodeValidationError("OpResize: invalid attribute stride");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpResize<InDtype, OutDtype>::eval()
+{
+ int in_batch = in->getShape()[0];
+ int in_height = in->getShape()[1];
+ int in_width = in->getShape()[2];
+ int in_channels = in->getShape()[3];
+
+ int out_batch = out->getShape()[0];
+ int out_height = out->getShape()[1];
+ int out_width = out->getShape()[2];
+ int out_channels = out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch");
+ ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch");
+
+ for (int b = 0; b < out_batch; b++)
+ for (int c = 0; c < out_channels; c++)
+ for (int oy = 0; oy < out_height; oy++)
+ for (int ox = 0; ox < out_width; ox++)
+ {
+ int y = oy * stride[0] + offset[0];
+ int x = ox * stride[1] + offset[1];
+
+ int iy = y >> shift;
+ int dy = y - (iy << shift);
+ int ix = x >> shift;
+ int dx = x - (ix << shift);
+
+ int iy0 = MAX(iy, 0);
+ int iy1 = MIN(iy + 1, in_height - 1);
+ int ix0 = MAX(ix, 0);
+ int ix1 = MIN(ix + 1, in_width - 1);
+
+ ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)",
+ iy0, iy1, ix0, ix1);
+
+ InEigenType v00 = in->getTensor()(b, iy0, ix0, c);
+ InEigenType v01 = in->getTensor()(b, iy0, ix1, c);
+ InEigenType v10 = in->getTensor()(b, iy1, ix0, c);
+ InEigenType v11 = in->getTensor()(b, iy1, ix1, c);
+
+ OutEigenType acc;
+ if (mode == ResizeMode_BILINEAR)
+ {
+ acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx;
+ acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx);
+ acc = acc + (OutEigenType)v11 * dy * dx;
+ }
+ else
+ {
+ iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0;
+ ix = (dx >> (shift - 1)) != 0 ? ix1 : ix0;
+ acc = in->getTensor()(b, iy, ix, c);
+ }
+
+ out->getTensor()(b, oy, ox, c) = acc;
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48);
+DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16);
diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h
new file mode 100644
index 0000000..9d15d49
--- /dev/null
+++ b/reference_model/src/ops/image.h
@@ -0,0 +1,53 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_IMAGE_H
+#define OPS_IMAGE_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <DType InDtype, DType OutDtype>
+class OpResize : public GraphNode
+{
+public:
+ OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpResize();
+ virtual int checkTensorAttributes() final;
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaResizeAttribute* attribute;
+ std::vector<int32_t> output_size;
+ std::vector<int32_t> stride;
+ std::vector<int32_t> offset;
+ int32_t shift;
+ ResizeMode mode;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc
new file mode 100644
index 0000000..bad0c40
--- /dev/null
+++ b/reference_model/src/ops/op_factory.cc
@@ -0,0 +1,432 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "op_factory.h"
+#include "activation_funcs.h"
+#include "comparison.h"
+#include "control_flow.h"
+#include "custom.h"
+#include "data_layout.h"
+#include "data_nodes.h"
+#include "ewise_binary.h"
+#include "ewise_ternary.h"
+#include "ewise_unary.h"
+#include "image.h"
+#include "reduction.h"
+#include "scatter_gather.h"
+#include "tensor_ops.h"
+#include "type_conversion.h"
+
+using namespace TosaReference;
+using namespace tosa;
+
+GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh,
+ Op opType,
+ TosaAttributeBase* attribute,
+ TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ DType inputDType,
+ int inputRank,
+ DType outputDType,
+ int outputRank,
+ DType weightDType,
+ int weightRank)
+{
+ switch (opType)
+ {
+ // tensor_ops
+ case Op_ARGMAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+ break;
+ case Op_AVG_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16);
+ break;
+ case Op_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8);
+ break;
+ case Op_DEPTHWISE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+ break;
+ case Op_FULLY_CONNECTED:
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, INT8);
+ break;
+ case Op_MATMUL:
+ DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMatMul, INT16);
+ break;
+ case Op_MAX_POOL2D:
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8);
+ DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16);
+ break;
+ case Op_TRANSPOSE_CONV2D:
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+ DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
+ break;
+
+ // activation_funcs
+ case Op_CLAMP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16);
+ break;
+ case Op_RELUN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32);
+ break;
+ case Op_SIGMOID:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT);
+ break;
+ case Op_TANH:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT);
+ break;
+
+ // ewise_binary
+ case Op_ADD:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32);
+ break;
+ case Op_ARITHMETIC_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32);
+ break;
+ case Op_BITWISE_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32);
+ break;
+ case Op_BITWISE_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32);
+ break;
+ case Op_BITWISE_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32);
+ break;
+ case Op_LOGICAL_AND:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL);
+ break;
+ case Op_LOGICAL_LEFT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32);
+ break;
+ case Op_LOGICAL_RIGHT_SHIFT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32);
+ break;
+ case Op_LOGICAL_OR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+ break;
+ case Op_LOGICAL_XOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL);
+ break;
+ case Op_MAXIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32);
+ break;
+ case Op_MINIMUM:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32);
+ break;
+ case Op_MUL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32);
+ break;
+ case Op_POW:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT);
+ break;
+ case Op_SUB:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32);
+ break;
+ case Op_TABLE:
+ DEF_FACTORY_ONE_RANK_0_6(OpTable);
+ break;
+
+ // ewise_unary
+ case Op_ABS:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32);
+ break;
+ case Op_BITWISE_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32);
+ break;
+ case Op_CEIL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT);
+ break;
+ case Op_CLZ:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32);
+ break;
+ case Op_EXP:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT);
+ break;
+ case Op_FLOOR:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT);
+ break;
+ case Op_LOG:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT);
+ break;
+ case Op_LOGICAL_NOT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL);
+ break;
+ case Op_NEGATE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32);
+ break;
+ case Op_RECIPROCAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT);
+ break;
+ case Op_RSQRT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT);
+ break;
+
+ // ewise_ternary
+ case Op_SELECT:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL);
+ break;
+
+ // comparison
+ case Op_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32);
+ break;
+ case Op_GREATER:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32);
+ break;
+ case Op_GREATER_EQUAL:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32);
+ break;
+
+ // reduction
+ case Op_REDUCE_ALL:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+ break;
+ case Op_REDUCE_ANY:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+ break;
+ case Op_REDUCE_MAX:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+ break;
+ case Op_REDUCE_MIN:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+ break;
+ case Op_REDUCE_PRODUCT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+ break;
+ case Op_REDUCE_SUM:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
+ break;
+
+ // data layout
+ case Op_CONCAT:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL);
+ break;
+ case Op_PAD:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL);
+ break;
+ case Op_RESHAPE:
+ DEF_FACTORY_RESHAPE(OpReshape, FLOAT);
+ DEF_FACTORY_RESHAPE(OpReshape, AINT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT8);
+ DEF_FACTORY_RESHAPE(OpReshape, INT16);
+ DEF_FACTORY_RESHAPE(OpReshape, INT32);
+ DEF_FACTORY_RESHAPE(OpReshape, BOOL);
+ break;
+ case Op_REVERSE:
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32);
+ DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL);
+ break;
+ case Op_SLICE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL);
+ break;
+ case Op_TILE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL);
+ break;
+ case Op_TRANSPOSE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32);
+ break;
+
+ // scatter_gather
+ case Op_GATHER:
+ {
+ // output.rank = input.rank - 1 + index.rank
+ int32_t index_rank = outputRank - inputRank + 1;
+ DEF_FACTORY_GATHER(OpGather, AINT8);
+ DEF_FACTORY_GATHER(OpGather, INT16);
+ DEF_FACTORY_GATHER(OpGather, INT32);
+ }
+ break;
+
+ // image
+ case Op_RESIZE:
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT32);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48);
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16);
+ break;
+
+ // data_nodes
+ case Op_CONST:
+ return new OpConst(id);
+ case Op_PLACEHOLDER:
+ return new OpPlaceholder(id);
+ case Op_IDENTITY:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL);
+ break;
+ case Op_IDENTITYN:
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL);
+ break;
+
+ // type_conversion
+ case Op_CAST:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+ break;
+ case Op_RESCALE:
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+ DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
+ break;
+
+ // custom
+ case Op_CUSTOM:
+ return new OpCustom(id);
+
+ // control_flow
+ case Op_COND_IF:
+ return new OpCondIf(tsh, attribute, id);
+ case Op_WHILE_LOOP:
+ return new OpWhileLoop(tsh, attribute, id);
+
+ // Ops not recognized
+ default:
+ goto done;
+
+ } // End of switch(opType)
+
+done:
+ return nullptr;
+}
diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h
new file mode 100644
index 0000000..cde6841
--- /dev/null
+++ b/reference_model/src/ops/op_factory.h
@@ -0,0 +1,294 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_OP_FACTORY_H
+#define OPS_OP_FACTORY_H
+
+#include "attribute.h"
+#include "graph_node.h"
+#include "quant_info.h"
+#include "template_types.h"
+#include "tosa_serialization_handler.h"
+
+#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id);
+
+#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \
+ case RANK: \
+ return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id);
+
+#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \
+ case RANK2: \
+ return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id);
+
+#define DEF_FACTORY_ONE_RANK_0_6(OP) \
+ switch (inputRank) \
+ { \
+ case 0: \
+ return new OP<0>(attribute, qinfo, id); \
+ case 1: \
+ return new OP<1>(attribute, qinfo, id); \
+ case 2: \
+ return new OP<2>(attribute, qinfo, id); \
+ case 3: \
+ return new OP<3>(attribute, qinfo, id); \
+ case 4: \
+ return new OP<4>(attribute, qinfo, id); \
+ case 5: \
+ return new OP<5>(attribute, qinfo, id); \
+ case 6: \
+ return new OP<6>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ return new OP<DType_##DTYPE>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \
+ DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \
+ } \
+ }
+
+#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \
+ if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \
+ { \
+ switch (inputRank) \
+ { \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \
+ DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) \
+ } \
+ }
+
+#define DEF_FACTORY_RESHAPE(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 0: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \
+ } \
+ } \
+ case 1: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \
+ } \
+ } \
+ case 2: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \
+ } \
+ } \
+ case 3: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \
+ } \
+ } \
+ case 4: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \
+ } \
+ } \
+ case 5: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \
+ } \
+ } \
+ case 6: \
+ { \
+ switch (outputRank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) \
+ } \
+ } \
+ } \
+ }
+
+#define DEF_FACTORY_GATHER(OP, DTYPE) \
+ if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \
+ { \
+ switch (inputRank) \
+ { \
+ case 1: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \
+ } \
+ case 2: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \
+ } \
+ case 3: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \
+ } \
+ case 4: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \
+ } \
+ case 5: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \
+ } \
+ case 6: \
+ switch (index_rank) \
+ { \
+ DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \
+ } \
+ } \
+ }
+
+namespace TosaReference
+{
+
+class OpFactory
+{
+public:
+ static GraphNode* newOp(tosa::TosaSerializationHandler* tsh,
+ tosa::Op opType,
+ tosa::TosaAttributeBase* attribute,
+ tosa::TosaQuantInfoBase* qinfo,
+ uint64_t id,
+ tosa::DType inputDType,
+ int inputRank,
+ tosa::DType outputDType,
+ int outputRank,
+ tosa::DType weightDType,
+ int weightRank);
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc
new file mode 100644
index 0000000..a2adfdb
--- /dev/null
+++ b/reference_model/src/ops/reduction.cc
@@ -0,0 +1,139 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "reduction.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_)
+ : GraphNode(op_, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 4);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+ReduceNode<Rank, Dtype>::~ReduceNode()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int ReduceNode<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank())
+ {
+ printNodeValidationError("Reduce axis must between [0, input_rank - 1]");
+ return 1;
+ }
+
+ if (inputs[0]->matchRank(*outputs[0]))
+ {
+ printNodeValidationError("Input and output tensor ranks must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ dims[0] = this->attribute->axis();
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpReduceAll<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceAny<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMax<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceMin<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceProduct<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType Dtype>
+int OpReduceSum<Rank, Dtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions());
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT);
+
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32);
diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h
new file mode 100644
index 0000000..cf75812
--- /dev/null
+++ b/reference_model/src/ops/reduction.h
@@ -0,0 +1,109 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_REDUCTION_H
+#define OPS_REDUCTION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class ReduceNode : public GraphNode
+{
+public:
+ ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_);
+ virtual ~ReduceNode();
+ virtual int checkTensorAttributes();
+ virtual int eval() = 0;
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ Eigen::array<int, 1> dims;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ TosaAxisAttribute* attribute;
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAll : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceAny : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMax : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceMin : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceProduct : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+template <int Rank, DType Dtype>
+class OpReduceSum : public ReduceNode<Rank, Dtype>
+{
+public:
+ OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_)
+ {}
+ virtual int eval();
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc
new file mode 100644
index 0000000..c54204a
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.cc
@@ -0,0 +1,120 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "scatter_gather.h"
+#include "quant_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_GATHER, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(1, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+OpGather<InRank, IndexRank, Dtype>::~OpGather()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same types
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("Failure to match input and output type");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && index && out);
+
+ return 0;
+}
+
+template <int InRank, int IndexRank, DType Dtype>
+int OpGather<InRank, IndexRank, Dtype>::eval()
+{
+ int axis = attribute->axis();
+
+ // calculate size left and right to axis
+ int left_size = 1;
+ for (int i = 0; i < axis; ++i)
+ {
+ left_size *= in->getShape()[i];
+ }
+
+ int right_size = 1;
+ for (int i = axis + 1; i < in->getRank(); ++i)
+ {
+ right_size *= in->getShape()[i];
+ }
+
+ InEigenType* input_data = in->getTensor().data();
+ int32_t* index_data = index->getTensor().data();
+ OutEigenType* output_data = out->getTensor().data();
+
+ int32_t axis_size = in->getShape()[axis];
+ int32_t index_count = index->getElementCount();
+
+ // sanity check if index is valid
+ // need to check until this point since index is not known until runtime
+ for (size_t i = 0; i < index->getElementCount(); i++)
+ {
+ if (index_data[i] >= axis_size)
+ {
+ FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size);
+ }
+ }
+
+ // Eigen stores tensor in column-major
+ // so we iterate through dimension right to axis and the index array
+ // do memory copy with size of left size each time
+ for (int right = 0; right < right_size; ++right)
+ {
+ for (int i = 0; i < index_count; ++i)
+ {
+ std::memcpy(output_data + (right * index_count + i) * left_size,
+ input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size);
+ }
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_GATHER(OpGather, AINT8);
+DEF_INSTANTIATE_GATHER(OpGather, INT16);
+DEF_INSTANTIATE_GATHER(OpGather, INT32);
diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h
new file mode 100644
index 0000000..d9b1263
--- /dev/null
+++ b/reference_model/src/ops/scatter_gather.h
@@ -0,0 +1,54 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_SCATTER_GATHER_H
+#define OPS_SCATTER_GATHER_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+// input and index can have different rank
+// and infer OutRank statically
+template <int InRank, int IndexRank, DType Dtype>
+class OpGather : public GraphNode
+{
+public:
+ OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpGather();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr int OutRank = InRank - 1 + IndexRank;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, InRank>;
+ using TIndex = Eigen::Tensor<int32_t, IndexRank>;
+ using TOut = Eigen::Tensor<OutEigenType, OutRank>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TIndex>* index;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
new file mode 100644
index 0000000..1859e03
--- /dev/null
+++ b/reference_model/src/ops/template_types.h
@@ -0,0 +1,277 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OP_TEMPLATE_TYPES_H
+#define OP_TEMPLATE_TYPES_H
+
+#include "tosa_generated.h"
+#include <Eigen/CXX11/Tensor>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+// Shorter aliase templates for common Eigen::Tensor types
+template <typename T>
+using ETensor0 = Eigen::Tensor<T, 0>;
+template <typename T>
+using ETensor1 = Eigen::Tensor<T, 1>;
+template <typename T>
+using ETensor2 = Eigen::Tensor<T, 2>;
+template <typename T>
+using ETensor3 = Eigen::Tensor<T, 3>;
+template <typename T>
+using ETensor4 = Eigen::Tensor<T, 4>;
+template <typename T>
+using ETensor5 = Eigen::Tensor<T, 5>;
+template <typename T>
+using ETensor6 = Eigen::Tensor<T, 6>;
+
+// Forward declaration
+template <class T>
+class TensorTemplate;
+
+// Shortcut to hide the TensorTemplate class.
+// For example, declare Tensor1<float> to get a TensorTemplate
+// with an Eigen::Tensor<float, 1>
+template <typename T>
+using Tensor0 = TensorTemplate<ETensor0<T>>;
+template <typename T>
+using Tensor1 = TensorTemplate<ETensor1<T>>;
+template <typename T>
+using Tensor2 = TensorTemplate<ETensor2<T>>;
+template <typename T>
+using Tensor3 = TensorTemplate<ETensor3<T>>;
+template <typename T>
+using Tensor4 = TensorTemplate<ETensor4<T>>;
+template <typename T>
+using Tensor5 = TensorTemplate<ETensor5<T>>;
+template <typename T>
+using Tensor6 = TensorTemplate<ETensor6<T>>;
+
+template <DType type>
+struct GetEigenType;
+template <>
+struct GetEigenType<DType_FLOAT>
+{
+ using type = float;
+};
+template <>
+struct GetEigenType<DType_INT32>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT48>
+{
+ using type = int64_t;
+};
+template <>
+struct GetEigenType<DType_BOOL>
+{
+ using type = bool;
+};
+template <>
+struct GetEigenType<DType_AINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_UINT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT4>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT8>
+{
+ using type = int32_t;
+};
+template <>
+struct GetEigenType<DType_INT16>
+{
+ using type = int32_t;
+};
+
+// Meta function to get number of bits
+template <DType T>
+struct GetNumBits
+{
+ static constexpr int32_t value = 0;
+};
+template <>
+struct GetNumBits<DType_BOOL>
+{
+ static constexpr int32_t value = 1;
+};
+template <>
+struct GetNumBits<DType_AINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_UINT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT4>
+{
+ static constexpr int32_t value = 4;
+};
+template <>
+struct GetNumBits<DType_INT8>
+{
+ static constexpr int32_t value = 8;
+};
+template <>
+struct GetNumBits<DType_INT16>
+{
+ static constexpr int32_t value = 16;
+};
+template <>
+struct GetNumBits<DType_INT32>
+{
+ static constexpr int32_t value = 32;
+};
+template <>
+struct GetNumBits<DType_INT48>
+{
+ static constexpr int32_t value = 48;
+};
+
+// Meta function to get quantized min/max in compile time
+template <DType T>
+struct GetQMin
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_AINT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_UINT8>
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMin<DType_INT4>
+{
+ static constexpr int64_t value = -8L;
+};
+template <>
+struct GetQMin<DType_INT8>
+{
+ static constexpr int64_t value = -128L;
+};
+template <>
+struct GetQMin<DType_INT16>
+{
+ static constexpr int64_t value = -32768L;
+};
+template <>
+struct GetQMin<DType_INT32>
+{
+ static constexpr int64_t value = -(1L << 31);
+};
+template <>
+struct GetQMin<DType_INT48>
+{
+ static constexpr int64_t value = -(1L << 47);
+};
+
+template <DType T>
+struct GetQMax
+{
+ static constexpr int64_t value = 0L;
+};
+template <>
+struct GetQMax<DType_AINT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_UINT8>
+{
+ static constexpr int64_t value = 255L;
+};
+template <>
+struct GetQMax<DType_INT4>
+{
+ static constexpr int64_t value = 7L;
+};
+template <>
+struct GetQMax<DType_INT8>
+{
+ static constexpr int64_t value = 127L;
+};
+template <>
+struct GetQMax<DType_INT16>
+{
+ static constexpr int64_t value = 32767L;
+};
+template <>
+struct GetQMax<DType_INT32>
+{
+ static constexpr int64_t value = (1L << 31) - 1;
+};
+template <>
+struct GetQMax<DType_INT48>
+{
+ static constexpr int64_t value = (1L << 47) - 1;
+};
+
+template <DType TIn1, DType TIn2>
+struct GetAccDType;
+template <>
+struct GetAccDType<DType_AINT8, DType_AINT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT4>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_AINT8, DType_INT8>
+{
+ static constexpr DType value = DType_INT32;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT8>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_INT16, DType_INT16>
+{
+ static constexpr DType value = DType_INT48;
+};
+template <>
+struct GetAccDType<DType_FLOAT, DType_FLOAT>
+{
+ static constexpr DType value = DType_FLOAT;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc
new file mode 100644
index 0000000..a735334
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.cc
@@ -0,0 +1,1229 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensor_ops.h"
+#include "quant_util.h"
+#include "template_types.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_ARGMAX, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+
+ INIT_ATTRIBUTE(Axis);
+}
+
+template <int Rank, DType Dtype>
+OpArgMax<Rank, Dtype>::~OpArgMax()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ return 0;
+}
+
+template <int Rank, DType Dtype>
+int OpArgMax<Rank, Dtype>::eval()
+{
+ Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis());
+
+ this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; });
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_AVG_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+ INIT_QINFO(Unary);
+}
+
+template <DType Dtype>
+OpAvgPool2d<Dtype>::~OpAvgPool2d()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpAvgPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpAvgPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride)
+{
+ ETensor1<int32_t> result(out_size);
+
+ int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size;
+ total_pad = total_pad < 0 ? 0 : total_pad;
+
+ int32_t pad_left = total_pad >> 1;
+ int32_t pad_right = total_pad - pad_left;
+
+ result.setConstant(kernel_size);
+
+ // the index left to 'left_index' and index right to 'right_index' indicates
+ // the input window of this output covers a pad bit
+ int32_t left_index = pad_left / stride;
+ int32_t right_index = pad_right / stride;
+
+ // not handle ultra small activation yet
+ ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet");
+
+ // minus the number of pad bit this index cover
+ while (left_index >= 0)
+ {
+ result(left_index) -= (pad_left - left_index * stride);
+ left_index--;
+ }
+
+ while (right_index >= 0)
+ {
+ result(out_size - 1 - right_index) -= (pad_right - right_index * stride);
+ right_index--;
+ }
+
+ return result;
+}
+
+// assuming input and output tensor have same scales like tflite reference
+// so no need to scale input and output
+template <DType Dtype>
+int OpAvgPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_val = this->in->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // assuming input and output have same scales
+ // so input and output scaling is not required
+ // TODO: check if this assumption TOSA made
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<AccEigenType> out_1d(this->out->getElementCount());
+ out_1d.setZero();
+
+ // sum pool
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ for (int32_t j = 0; j < kernel_h * kernel_w; j++)
+ {
+ out_1d(i) += (AccEigenType)input_extract_patches(j, i);
+ }
+ }
+
+ // reshape result to [N, H, W, C] and divide with div_map
+ ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims);
+
+ // calculate 1d height/width div_map (number of elements this pooling window covers)
+ // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C]
+ ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h);
+ ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w);
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+ Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels };
+
+ ETensor4<int32_t> div_map =
+ div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 })
+ .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims)
+ .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 })
+ .broadcast(bcast);
+
+ if (Dtype != DType_FLOAT)
+ {
+ this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType {
+ int32_t multiplier, shift;
+ TosaReference::QuantUtil<AccDtype>::reciprocal_scale(div, multiplier, shift);
+
+ return (OutEigenType)TosaReference::QuantUtil<AccDtype>::apply_scale(value, multiplier, shift, false);
+ });
+ this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp());
+ this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin);
+ this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax);
+ }
+ else
+ {
+ this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>();
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpConv2d<InDtype, WeightDtype>::~OpConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
+ if (inputs[2]->getRank() != 1)
+ {
+ printNodeValidationError("OpConv2d: bias tensor must be rank 1");
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels,
+ in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels,
+ out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], "
+ "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ // GEMM-conv2d, left matrix is input, right matrix is weight
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = out_batch * out_height * out_width;
+ im2col_input_dims[1] = f_height * f_width * f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> im2col_weight_dims;
+ im2col_weight_dims[0] = f_height * f_width * f_in_channels;
+ im2col_weight_dims[1] = f_out_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_reshaped_dims;
+ bias_reshaped_dims[0] = 1;
+ bias_reshaped_dims[1] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims;
+ weight_zp_bcast_dims[0] = f_height;
+ weight_zp_bcast_dims[1] = f_width;
+ weight_zp_bcast_dims[2] = f_in_channels;
+
+ Eigen::array<Eigen::Index, 2> bias_bcast_dims;
+ bias_bcast_dims[0] = out_batch * out_height * out_width;
+ bias_bcast_dims[1] = 1;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) };
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // need to transpose to [N, H * W, KH, KW, C]
+ ETensor5<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID)
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 });
+
+ // reshape input to [N * H * W, KH * KW * C]
+ ETensor2<InEigenType> im2col_input = input_extract_patches.reshape(im2col_input_dims);
+
+ // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC]
+ ETensor2<WeightEigenType> im2col_weight =
+ weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims);
+
+ // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it
+ // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C]
+ ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims);
+
+ // output matrix is [N * H * W, C]
+ ETensor2<AccEigenType> contracted_result =
+ im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims);
+
+ // adding bias
+ ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>();
+
+ // reshape back to [N, H, W, C]
+ this->output->getTensor() = biased_output.reshape(col2im_output_dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_DEPTHWISE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Conv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4
+ if (inputs[2]->getRank() != 1)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1");
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_HWIM))
+ {
+ printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpDepthwiseConv2d<InDtype, WeightDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_height = this->weight->getShape()[0];
+ int f_width = this->weight->getShape()[1];
+ int f_in_channels = this->weight->getShape()[2];
+ int f_multiplier = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels,
+ "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier,
+ out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ DEBUG_INFO(OP,
+ "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_bottom, padding_left, padding_right);
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ ETensor4<InEigenType> input_padded = input_val.pad(padding);
+
+ // GEMM doesn't fit well with DepthwiseConv2d
+ // 1. use extract_image_patches() to handle stride/dilation/padding
+ // 2. perform direct convolution
+
+ // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC]
+ ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches(
+ f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID);
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ // 2. direct depthwise convolution
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int oh = 0; oh < out_height; oh++)
+ {
+ for (int ow = 0; ow < out_width; ow++)
+ {
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int cm = 0; cm < f_multiplier; cm++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) +=
+ ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) *
+ (AccEigenType)weight_val(fh, fw, ic, cm));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_FULLY_CONNECTED, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType WeightDtype>
+OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+
+ if (input->getShape()[1] != weight->getShape()[1])
+ {
+ printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]");
+ return 1;
+ }
+
+ if (weight->getShape()[0] != bias->getShape()[0])
+ {
+ printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]");
+ return 1;
+ }
+
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType InDtype, DType WeightDtype>
+int OpFullyConnected<InDtype, WeightDtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
+
+ Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 };
+
+ Eigen::array<Eigen::Index, 2> bias_reshape;
+ bias_reshape[0] = 1;
+ bias_reshape[1] = this->bias->getShape()[0];
+
+ Eigen::array<Eigen::Index, 2> bias_bcast;
+ bias_bcast[0] = this->input->getShape()[0];
+ bias_bcast[1] = 1;
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle);
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ this->output->getTensor() =
+ input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) +
+ this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MATMUL, id_)
+{
+ setRequiredOperands(2, 1);
+ setRequiredRank(2);
+
+ INIT_QINFO(MatMul);
+}
+
+template <DType Dtype>
+OpMatMul<Dtype>::~OpMatMul()
+{
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]);
+
+ if (a->getShape()[1] != b->getShape()[0])
+ {
+ printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]");
+ return 1;
+ }
+
+ c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMatMul<Dtype>::eval()
+{
+ typedef Eigen::Tensor<int, 1>::DimensionPair DimPair;
+ Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } };
+
+ TIn a_val = this->a->getTensor();
+ TIn b_val = this->b->getTensor();
+ if (this->qinfo)
+ {
+ a_val = a_val - (InEigenType)this->qinfo->a_zp();
+ b_val = b_val - (InEigenType)this->qinfo->b_zp();
+ }
+
+ this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims);
+
+ if (AccDtype == DType_INT48)
+ {
+ this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+template <DType Dtype>
+OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_MAX_POOL2D, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(Pool2d);
+}
+
+template <DType Dtype>
+OpMaxPool2d<Dtype>::~OpMaxPool2d()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[0]->matchType(*outputs[0]))
+ {
+ printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ if (!in->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpMaxPool2d: unsupported tensor format");
+ return 1;
+ }
+
+ if (attribute->padding().size() != 4)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute padding");
+ return 1;
+ }
+
+ if (attribute->kernel().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute kernel");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpMaxPool2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ return 0;
+}
+
+template <DType Dtype>
+int OpMaxPool2d<Dtype>::eval()
+{
+ int in_batch = this->in->getShape()[0];
+ int in_height = this->in->getShape()[1];
+ int in_width = this->in->getShape()[2];
+ int in_channels = this->in->getShape()[3];
+
+ int out_batch = this->out->getShape()[0];
+ int out_height = this->out->getShape()[1];
+ int out_width = this->out->getShape()[2];
+ int out_channels = this->out->getShape()[3];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+
+ int padding_top = this->attribute->padding()[0];
+ int padding_bottom = this->attribute->padding()[1];
+ int padding_left = this->attribute->padding()[2];
+ int padding_right = this->attribute->padding()[3];
+ int kernel_h = this->attribute->kernel()[0];
+ int kernel_w = this->attribute->kernel()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+
+ DEBUG_INFO(OP,
+ "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], "
+ "stride=[%d,%d], padding=[%d,%d,%d,%d]",
+ in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h,
+ kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right);
+
+ Eigen::array<Eigen::Index, 2> im2col_input_dims;
+ im2col_input_dims[0] = kernel_h * kernel_w;
+ im2col_input_dims[1] = out_batch * out_height * out_width * out_channels;
+
+ Eigen::array<Eigen::Index, 4> col2im_output_dims;
+ col2im_output_dims[0] = out_batch;
+ col2im_output_dims[1] = out_height;
+ col2im_output_dims[2] = out_width;
+ col2im_output_dims[3] = out_channels;
+
+ Eigen::array<std::pair<int32_t, int32_t>, 4> padding;
+ padding[0] = std::make_pair(0, 0);
+ padding[1] = std::make_pair(padding_top, padding_bottom);
+ padding[2] = std::make_pair(padding_left, padding_right);
+ padding[3] = std::make_pair(0, 0);
+
+ ETensor4<InEigenType> input_padded = this->in->getTensor().pad(padding, std::numeric_limits<InEigenType>::lowest());
+
+ // extract_image_patches() output [N, KH, KW, H * W, C]
+ // transpose to [KH, KW, N, H * W, C]
+ // reshape to [KH * KW, N * H * W * C]
+ //
+ // Set the padding value to be the most negative value that can be
+ // represented by the datatype to ensure that any padding values will be equal
+ // to or smaller than the actual maximum in the KH x KW patch.
+ ETensor2<InEigenType> input_extract_patches =
+ input_padded
+ .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID,
+ std::numeric_limits<InEigenType>::lowest())
+ .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 })
+ .reshape(im2col_input_dims);
+
+ // Get the maximum of the KHxHW patches along axis 0
+ Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0);
+
+ // 1D result with [N * H * W * C]
+ ETensor1<OutEigenType> out_1d(this->out->getElementCount());
+
+ // index input_patches with argmax array should give the result
+ for (size_t i = 0; i < this->out->getElementCount(); i++)
+ {
+ out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i);
+ }
+
+ // reshape result to [N, H, W, C]
+ this->out->getTensor() = out_1d.reshape(col2im_output_dims);
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_,
+ TosaQuantInfoBase* qinfo_,
+ uint64_t id_)
+ : GraphNode(Op_TRANSPOSE_CONV2D, id_)
+{
+ setRequiredOperands(3, 1);
+ setRequiredRank(4);
+
+ INIT_ATTRIBUTE(TransposeConv2d);
+ INIT_QINFO(Conv);
+}
+
+template <DType InDtype, DType OutDtype>
+OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d()
+{
+ if (attribute)
+ delete attribute;
+ if (qinfo)
+ delete qinfo;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ if (inputs[1]->getIsConst() == 0)
+ {
+ printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed");
+ }
+
+ input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]);
+ bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]);
+ output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]);
+
+ if (!input->hasFormat(Format_NHWC))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported input tensor format");
+ return 1;
+ }
+
+ if (!weight->hasFormat(Format_OHWI))
+ {
+ printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format");
+ return 1;
+ }
+
+ if (attribute->outpad().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad");
+ return 1;
+ }
+
+ if (attribute->stride().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride");
+ return 1;
+ }
+
+ if (attribute->dilation().size() != 2)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation");
+ return 1;
+ }
+
+ if (attribute->output_shape().size() != 4)
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+
+ for (int d = 0; d < 4; d++)
+ {
+ if (attribute->output_shape()[d] != this->output->getShape()[d])
+ {
+ printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape");
+ return 1;
+ }
+ }
+
+ return 0;
+}
+
+template <DType InDtype, DType OutDtype>
+int OpTransposeConv2d<InDtype, OutDtype>::eval()
+{
+ int in_batch = this->input->getShape()[0];
+ int in_height = this->input->getShape()[1];
+ int in_width = this->input->getShape()[2];
+ int in_channels = this->input->getShape()[3];
+
+ int f_out_channels = this->weight->getShape()[0];
+ int f_height = this->weight->getShape()[1];
+ int f_width = this->weight->getShape()[2];
+ int f_in_channels = this->weight->getShape()[3];
+
+ int b_out_channels = this->bias->getShape()[0];
+
+ int out_batch = this->output->getShape()[0];
+ int out_height = this->output->getShape()[1];
+ int out_width = this->output->getShape()[2];
+ int out_channels = this->output->getShape()[3];
+
+ int padding_top = this->attribute->outpad()[0];
+ int padding_left = this->attribute->outpad()[1];
+ int stride_h = this->attribute->stride()[0];
+ int stride_w = this->attribute->stride()[1];
+ int dilation_h = this->attribute->dilation()[0];
+ int dilation_w = this->attribute->dilation()[1];
+
+ ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch);
+ ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d",
+ f_in_channels, in_channels);
+ ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d",
+ f_out_channels, out_channels);
+ ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d",
+ b_out_channels, out_channels);
+
+ DEBUG_INFO(OP,
+ "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], "
+ "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]",
+ in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch,
+ out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top,
+ padding_left);
+
+ TIn input_val = this->input->getTensor();
+ TWeight weight_val = this->weight->getTensor();
+ if (this->qinfo)
+ {
+ input_val = input_val - (InEigenType)this->qinfo->input_zp();
+ weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp();
+ }
+
+ Eigen::array<Eigen::Index, 4> reshape_dim;
+ reshape_dim.fill(1);
+ reshape_dim[3] = b_out_channels;
+
+ Eigen::array<Eigen::Index, 4> bcast;
+ bcast[0] = out_batch;
+ bcast[1] = out_height;
+ bcast[2] = out_width;
+ bcast[3] = 1;
+
+ // initialize with bias
+ this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast);
+
+ int out_x_origin, out_y_origin;
+ int out_x, out_y;
+
+ // reference implementation from: tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h
+ for (int ob = 0; ob < out_batch; ob++)
+ {
+ for (int ih = 0; ih < in_height; ih++)
+ {
+ for (int iw = 0; iw < in_width; iw++)
+ {
+ out_x_origin = iw * stride_w - padding_left;
+ out_y_origin = ih * stride_h - padding_top;
+ for (int ic = 0; ic < in_channels; ic++)
+ {
+ for (int fh = 0; fh < f_height; fh++)
+ {
+ for (int fw = 0; fw < f_width; fw++)
+ {
+ out_x = out_x_origin + fw * dilation_w;
+ out_y = out_y_origin + fh * dilation_h;
+ for (int oc = 0; oc < out_channels; oc++)
+ {
+ if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height))
+ {
+ this->output->getTensor()(ob, out_y, out_x, oc) +=
+ ((AccEigenType)input_val(ob, ih, iw, ic) *
+ (AccEigenType)weight_val(oc, fh, fw, ic));
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+ }
+
+ if (AccDtype == DType_INT48)
+ {
+ this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin);
+ this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax);
+ }
+
+ return GraphNode::eval();
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8);
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16);
+
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, AINT8)
+DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16)
+
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8);
+
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16);
+DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT);
+
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, AINT8);
+DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16);
+
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT4);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8);
+DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8);
diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h
new file mode 100644
index 0000000..26ce84b
--- /dev/null
+++ b/reference_model/src/ops/tensor_ops.h
@@ -0,0 +1,253 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TENSOR_OPS_H
+#define OPS_TENSOR_OPS_H
+
+#include "graph_node.h"
+#include "quant_util.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <int Rank, DType Dtype>
+class OpArgMax : public GraphNode
+{
+public:
+ OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpArgMax();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<DType_INT32>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank - 1>;
+
+protected:
+ TosaAxisAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TOut>* output;
+};
+
+template <DType Dtype>
+class OpAvgPool2d : public GraphNode
+{
+public:
+ OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpAvgPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+ static constexpr int64_t QMin = GetQMin<Dtype>::value;
+ static constexpr int64_t QMax = GetQMax<Dtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+ tosa::TosaUnaryQuantInfo* qinfo;
+
+protected:
+ // return a 1D [N] tensor that describes a how many valid elements covered in the input space
+ ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpConv2d : public GraphNode
+{
+public:
+ OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpDepthwiseConv2d : public GraphNode
+{
+public:
+ OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpDepthwiseConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConv2dAttribute* attribute;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpFullyConnected : public GraphNode
+{
+public:
+ OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpFullyConnected();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 2>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ tosa::TosaConvQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMatMul : public GraphNode
+{
+public:
+ OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMatMul();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 2>;
+ using TAcc = Eigen::Tensor<AccEigenType, 2>;
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* a;
+ TosaReference::TensorTemplate<TIn>* b;
+ TosaReference::TensorTemplate<TAcc>* c;
+ tosa::TosaMatMulQuantInfo* qinfo;
+};
+
+template <DType Dtype>
+class OpMaxPool2d : public GraphNode
+{
+public:
+ OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpMaxPool2d();
+
+ virtual int checkTensorAttributes();
+ virtual int eval();
+
+ using InEigenType = typename GetEigenType<Dtype>::type;
+ using OutEigenType = typename GetEigenType<Dtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TOut = Eigen::Tensor<OutEigenType, 4>;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+ tosa::TosaPool2dAttribute* attribute;
+};
+
+template <DType InDtype, DType WeightDtype>
+class OpTransposeConv2d : public GraphNode
+{
+public:
+ OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpTransposeConv2d();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using WeightEigenType = typename GetEigenType<WeightDtype>::type;
+ using AccEigenType = typename GetEigenType<AccDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, 4>;
+ using TWeight = Eigen::Tensor<WeightEigenType, 4>;
+ using TBias = Eigen::Tensor<AccEigenType, 1>;
+ using TAcc = Eigen::Tensor<AccEigenType, 4>;
+
+ static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
+ static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;
+
+protected:
+ TosaReference::TensorTemplate<TIn>* input;
+ TosaReference::TensorTemplate<TWeight>* weight;
+ TosaReference::TensorTemplate<TBias>* bias;
+ TosaReference::TensorTemplate<TAcc>* output;
+ TosaTransposeConv2dAttribute* attribute;
+ TosaConvQuantInfo* qinfo;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc
new file mode 100644
index 0000000..61a19f4
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.cc
@@ -0,0 +1,299 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "type_conversion.h"
+#include "quant_util.h"
+#include "template_types.h"
+#include <cmath>
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_RESCALE, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+ INIT_ATTRIBUTE(Rescale);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpRescale<Rank, InDtype, OutDtype>::~OpRescale()
+{
+ if (attribute)
+ delete attribute;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpRescale: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpRescale<Rank, InDtype, OutDtype>::eval()
+{
+ int32_t input_zp = attribute->input_zp();
+ int32_t output_zp = attribute->output_zp();
+ std::vector<int32_t> multiplier = attribute->multiplier();
+ std::vector<int32_t> shift = attribute->shift();
+ //bool scale32 = attribute->scale32();
+ bool double_round = attribute->double_round();
+ bool per_channel = attribute->per_channel();
+
+ if (TosaReference::TypeChecker::is_symmetric(InDtype))
+ {
+ if (input_zp != 0)
+ {
+ FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[InDtype], input_zp);
+ }
+ }
+
+ if (TosaReference::TypeChecker::is_symmetric(OutDtype))
+ {
+ if (output_zp != 0)
+ {
+ FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0",
+ EnumNamesDType()[OutDtype], output_zp);
+ }
+ }
+
+ // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn]
+ Eigen::array<Eigen::Index, 2> shape_2d;
+ shape_2d[0] = 1;
+ if (Rank > 0)
+ {
+ for (int i = 0; i < Rank - 1; i++)
+ {
+ shape_2d[0] *= this->in->getShape()[i];
+ }
+ shape_2d[1] = this->in->getShape()[Rank - 1];
+ }
+ else
+ {
+ shape_2d[1] = 1;
+ }
+ ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d);
+
+ ETensor2<OutEigenType> output_2d(shape_2d);
+
+ // TODO: pass scale32 in when 16-bit mode implemented
+ if (per_channel)
+ {
+ ETensor2<InEigenType> curr_channel_slice_prescaled;
+ ETensor2<OutEigenType> curr_channel_slice_postscaled;
+ int32_t channel_multiplier, channel_shift;
+ Eigen::array<Eigen::Index, 2> begin, size;
+ size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 });
+ for (int32_t i = 0; i < shape_2d[1]; i++)
+ {
+ begin = Eigen::array<Eigen::Index, 2>({ 0, i });
+ curr_channel_slice_prescaled = input_reshaped.slice(begin, size);
+ channel_multiplier = multiplier[i];
+ channel_shift = shift[i];
+ curr_channel_slice_postscaled =
+ curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift,
+ double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(
+ input_zp_shifted, channel_multiplier, channel_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+
+ for (int32_t j = 0; j < shape_2d[0]; j++)
+ {
+ output_2d(j, i) = curr_channel_slice_postscaled(j, 0);
+ }
+ }
+ }
+ else
+ {
+ int32_t tensor_multiplier = multiplier[0];
+ int32_t tensor_shift = shift[0];
+ output_2d = input_reshaped.unaryExpr(
+ [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType {
+ InEigenType input_zp_shifted = in_val - (InEigenType)input_zp;
+ int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(input_zp_shifted, tensor_multiplier,
+ tensor_shift, double_round);
+ OutEigenType out_val = (OutEigenType)(scaled + output_zp);
+ out_val = std::max<OutEigenType>(out_val, QMin);
+ out_val = std::min<OutEigenType>(out_val, QMax);
+ return out_val;
+ });
+ }
+
+ // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn]
+ Eigen::array<Eigen::Index, Rank> output_shape;
+ for (int i = 0; i < Rank; i++)
+ {
+ output_shape[i] = this->out->getShape()[i];
+ }
+ this->out->getTensor() = output_2d.reshape(output_shape);
+
+ return GraphNode::eval();
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
+ : GraphNode(Op_CAST, id_)
+{
+ setRequiredOperands(1, 1);
+ setRequiredRank(0, 6);
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+OpCast<Rank, InDtype, OutDtype>::~OpCast()
+{}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes()
+{
+ if (validateRequiredOperands())
+ return 1;
+
+ if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0]))
+ {
+ return 1;
+ }
+
+ // output and input must be the same rank and size
+ if (inputs[0]->matchRankSize(*outputs[0]))
+ {
+ printNodeValidationError("OpCast: input and output rank/size must match");
+ return 1;
+ }
+
+ in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]);
+ out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]);
+
+ ASSERT_MEM(in && out);
+
+ return 0;
+}
+
+template <int Rank, DType InDtype, DType OutDtype>
+int OpCast<Rank, InDtype, OutDtype>::eval()
+{
+ this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn());
+
+ return GraphNode::eval();
+}
+
+template <DType InDtype, DType OutDtype>
+CastHelper<InDtype, OutDtype>::CastHelper()
+{
+ fcn = [](InEigenType in) -> OutEigenType {
+ OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t)
+ int64_t mask = (1L << OutBits) - 1;
+ out = out & mask;
+ return out;
+ };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_BOOL>::CastHelper()
+{
+ fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; };
+}
+
+template <DType OutDtype>
+CastHelper<DType_BOOL, OutDtype>::CastHelper()
+{
+ fcn = [](bool in) -> OutEigenType {
+ OutEigenType out = in ? (OutEigenType)1 : (OutEigenType)0;
+ return out;
+ };
+}
+
+template <DType InDtype>
+CastHelper<InDtype, DType_FLOAT>::CastHelper()
+{
+ fcn = [](InEigenType in) -> float {
+ float out = (OutEigenType)in; // default cast to float is round_to_nearest_float()
+ return out;
+ };
+}
+
+template <DType OutDtype>
+CastHelper<DType_FLOAT, OutDtype>::CastHelper()
+{
+ fcn = [](float in) -> OutEigenType {
+ OutEigenType out = std::round(in);
+ out = std::max<OutEigenType>(out, OutMin);
+ out = std::min<OutEigenType>(out, OutMax);
+ return out;
+ };
+}
+
+// template explicit instantiation
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32);
+
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8);
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8);
diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h
new file mode 100644
index 0000000..6ec4d6d
--- /dev/null
+++ b/reference_model/src/ops/type_conversion.h
@@ -0,0 +1,162 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef OPS_TYPE_CONVERSION_H
+#define OPS_TYPE_CONVERSION_H
+
+#include "graph_node.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+template <int Rank, DType InDtype, DType OutDtype>
+class OpRescale : public GraphNode
+{
+public:
+ OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpRescale();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+ static constexpr int32_t QMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t QMax = GetQMax<OutDtype>::value;
+
+protected:
+ TosaRescaleAttribute* attribute;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+template <DType InDtype, DType OutDtype>
+class CastHelper
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_BOOL>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_BOOL>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_BOOL, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_BOOL>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType InDtype>
+class CastHelper<InDtype, DType_FLOAT>
+{
+public:
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <DType OutDtype>
+class CastHelper<DType_FLOAT, OutDtype>
+{
+public:
+ using InEigenType = typename GetEigenType<DType_FLOAT>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using FcnType = std::function<OutEigenType(InEigenType)>;
+ static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
+ static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
+ CastHelper();
+ const FcnType& get_fcn() const
+ {
+ return fcn;
+ }
+
+private:
+ FcnType fcn;
+};
+
+template <int Rank, DType InDtype, DType OutDtype>
+class OpCast : public GraphNode
+{
+public:
+ OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
+ virtual ~OpCast();
+
+ virtual int checkTensorAttributes() final;
+ virtual int eval() final;
+
+ using InEigenType = typename GetEigenType<InDtype>::type;
+ using OutEigenType = typename GetEigenType<OutDtype>::type;
+ using TIn = Eigen::Tensor<InEigenType, Rank>;
+ using TOut = Eigen::Tensor<OutEigenType, Rank>;
+
+protected:
+ CastHelper<InDtype, OutDtype> cast_helper;
+ TosaReference::TensorTemplate<TIn>* in;
+ TosaReference::TensorTemplate<TOut>* out;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h
new file mode 100644
index 0000000..3638b3b
--- /dev/null
+++ b/reference_model/src/quant_util.h
@@ -0,0 +1,103 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_QUANT_UTIL_H
+#define TOSA_REFERENCE_QUANT_UTIL_H
+
+#include "arith_util.h"
+#include "func_debug.h"
+#include "ops/template_types.h"
+#include "tosa_generated.h"
+
+using namespace tosa;
+
+namespace TosaReference
+{
+
+template <DType AccDType>
+class QuantUtil
+{
+public:
+ using T = typename GetEigenType<AccDType>::type;
+
+ static void reciprocal_scale(int32_t value,
+ // Output
+ int32_t& multiplier,
+ int32_t& shift)
+ {
+ ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value);
+ uint32_t value_u32 = (uint32_t)value;
+ int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k)
+ int64_t numerator = ((1L << 30) + 1) << k;
+ multiplier = numerator / value; // (1<<30) <= multiplier < (1<<31)
+ shift = 30 + k;
+ }
+
+ static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true)
+ {
+ if (AccDType == DType_FLOAT)
+ {
+ return value;
+ }
+ ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier);
+ int64_t round = (shift > 0) ? (1L << (shift - 1)) : 0;
+ if (enabled_adjusted_rounding)
+ {
+ if (AccDType != DType_INT48)
+ {
+ if (shift > 31 && value >= 0)
+ round += (1L << 30);
+ if (shift > 31 && value < 0)
+ round -= (1L << 30);
+ }
+ else
+ { // input data could be int16, which leads to 48 bits accumulator
+ ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode",
+ (1 << 15));
+ }
+ }
+ int64_t result = (int64_t)value * multiplier + round;
+ result = result >> shift;
+ ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31),
+ "apply_scale() error: scaled result exceed int32 numeric range");
+ return static_cast<int32_t>(result);
+ }
+};
+
+class TypeChecker
+{
+public:
+ static bool is_integer(DType dtype)
+ {
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_AINT8 || dtype == DType_UINT8 ||
+ dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48)
+ {
+ return true;
+ }
+ return false;
+ }
+ static bool is_symmetric(DType dtype)
+ {
+ if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_INT16 || dtype == DType_INT32 ||
+ dtype == DType_INT48)
+ {
+ return true;
+ }
+ return false;
+ }
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc
new file mode 100644
index 0000000..789bcae
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.cc
@@ -0,0 +1,649 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "subgraph_traverser.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
+{
+ block = _block;
+ tsh = _tsh;
+
+ tensors.clear();
+ nodes.clear();
+ nextNodeList.clear();
+}
+
+SubgraphTraverser::~SubgraphTraverser()
+{
+ nextNodeList.clear();
+
+ for (GraphNode* n : nodes)
+ {
+ delete n;
+ }
+ nodes.clear();
+
+ for (TosaReference::Tensor* t : tensors)
+ {
+ if (t->is_allocated())
+ {
+ t->deallocate();
+ }
+ delete t;
+ }
+ tensors.clear();
+}
+
+int SubgraphTraverser::getNumInputTensors() const
+{
+ return inputTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
+{
+ return inputTensors[idx];
+}
+
+TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
+{
+ for (auto t : inputTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+int SubgraphTraverser::getNumOutputTensors() const
+{
+ return outputTensors.size();
+}
+
+TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
+{
+ return outputTensors[idx];
+}
+
+TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
+{
+ for (auto t : outputTensors)
+ {
+ if (t->getName() == name)
+ {
+ return t;
+ }
+ }
+
+ return nullptr;
+}
+
+int SubgraphTraverser::initializeGraph()
+{
+ char tensor_fullname[1000];
+ int idx = 0;
+ for (auto op : block->GetOperators())
+ {
+ // translated TosaSerializationOperator to GraphNode
+ DType in_dtype = DType_UNKNOWN, out_dtype = DType_UNKNOWN, weight_dtype = DType_UNKNOWN;
+ uint32_t in_rank = 0, out_rank = 0, weight_rank = 0;
+ for (auto name : op->GetInputTensorNames())
+ {
+
+ TosaSerializationTensor* ts = block->GetTensorByName(name);
+ ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
+
+ if (ts->HasUsage(Usage_WEIGHT))
+ {
+ weight_dtype = ts->GetDtype();
+ weight_rank = ts->GetShape().size();
+ }
+ else if (ts->HasUsage(Usage_INDEX))
+ {
+ // do nothing, but this will prevent tensor's dtype/rank being wrongly used as template argument when initializing this op
+ }
+ else if (ts->HasUsage(Usage_ACTIVATION))
+ {
+ if (ts->GetShape().size() >= in_rank)
+ {
+ in_dtype = ts->GetDtype();
+ in_rank = ts->GetShape().size();
+ }
+ }
+ }
+
+ for (auto name : op->GetOutputTensorNames())
+ {
+
+ TosaSerializationTensor* ts = block->GetTensorByName(name);
+ ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());
+
+ out_dtype = ts->GetDtype();
+ out_rank = ts->GetShape().size();
+ }
+
+ DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
+ EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());
+
+ GraphNode* cn = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, in_dtype, in_rank,
+ out_dtype, out_rank, weight_dtype, weight_rank);
+ if (!cn)
+ {
+ if (weight_dtype == DType_UNKNOWN && weight_rank == 0)
+ {
+ fprintf(g_func_debug.func_debug_file,
+ "OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)",
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[out_dtype],
+ out_rank);
+ }
+ else
+ {
+ fprintf(g_func_debug.func_debug_file,
+ "OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)",
+ EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[weight_dtype],
+ weight_rank, EnumNamesDType()[out_dtype], out_rank);
+ }
+
+ for (auto ts : op->GetInputTensors())
+ {
+ fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts->GetName().c_str());
+ }
+
+ for (auto ts : op->GetOutputTensors())
+ {
+ fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts->GetName().c_str());
+ }
+ FATAL_ERROR("Unsupported operation type or rank.");
+ }
+
+ for (auto name : op->GetInputTensorNames())
+ {
+ cn->addInputName(name);
+ }
+
+ for (auto name : op->GetOutputTensorNames())
+ {
+ cn->addOutputName(name);
+ }
+
+ addNode(cn);
+
+ // if node doesn't have any inputs (i.e. CONST)
+ // it should be ready for evaluation
+ if (op->GetInputTensorNames().empty() && !cn->getOnNextNodeList())
+ {
+ addToNextNodeList(cn);
+ }
+
+ idx++;
+ }
+
+ for (auto ts : block->GetTensors())
+ {
+
+ bool is_const = false;
+ if (ts->HasUsage(Usage_WEIGHT))
+ {
+ is_const = true;
+ }
+
+ DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
+ TosaReference::Tensor* ct =
+ TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetUsage(), ts->GetFormat(), ts->GetShape(),
+ is_const, ts->GetShape().size());
+
+ if (ts->GetNpyFilePtr())
+ {
+ if (ct->allocate())
+ {
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ }
+
+ bzero(tensor_fullname, sizeof(tensor_fullname));
+ snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.subgraph_dir,
+ ts->GetNpyFilePtr()->c_str());
+ if (ct->readFromNpyFile(tensor_fullname))
+ {
+ FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", ct->getName().c_str(),
+ block->GetName().c_str());
+ }
+ }
+
+ // update this->tensors
+ addTensor(ct);
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
+ for (auto& input_name : block->GetInputs())
+ {
+ TosaReference::Tensor* ct = findTensorByName(input_name);
+ DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
+ if (ct)
+ {
+ ct->setIsSubgraphInput();
+ inputTensors.push_back(ct);
+ }
+ else
+ {
+ FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str());
+ }
+ }
+
+ DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
+ for (auto& output_name : block->GetOutputs())
+ {
+ TosaReference::Tensor* ct = findTensorByName(output_name);
+ DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
+ if (ct)
+ {
+ ct->setIsSubgraphOutput();
+ outputTensors.push_back(ct);
+ }
+ else
+ {
+ FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str());
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::isFullyEvaluated() const
+{
+ return nextNodeList.empty();
+}
+
+GraphNode* SubgraphTraverser::getNextNode()
+{
+ GraphNode* nextNode = nextNodeList.front();
+ ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
+ ASSERT_MSG(nextNode->getOnNextNodeList(),
+ "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");
+
+ nextNodeList.pop_front();
+
+ nextNode->clearOnNextNodeList();
+ return nextNode;
+}
+
+int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
+{
+ ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
+ ASSERT_MSG(!nextNode->getOnNextNodeList(),
+ "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");
+
+ nextNode->setOnNextNodeList();
+ nextNodeList.push_back(nextNode);
+
+ return 0;
+}
+
+int SubgraphTraverser::evaluateNextNode()
+{
+ if (isFullyEvaluated())
+ return 0;
+
+ GraphNode* currNode = getNextNode();
+
+ DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
+ currNode->getOutputNames()[0].c_str());
+
+ // Sanity check for never-ending loops
+ if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
+ {
+ WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount());
+ }
+
+ for (auto ct : currNode->getOutputs())
+ {
+ if (!ct->is_allocated())
+ if (ct->allocate())
+ {
+ FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
+ }
+ }
+
+ if (currNode->eval())
+ {
+ FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID());
+ }
+
+ // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
+ for (auto ct : currNode->getInputs())
+ {
+ bool in_use = false;
+ for (auto cn : ct->getConsumers())
+ {
+ if (!cn->hasAllOutputsReady())
+ {
+ in_use = true;
+ }
+ }
+ for (auto name : block->GetOutputs())
+ {
+ if (name == ct->getName())
+ {
+ in_use = true;
+ }
+ }
+ if (!in_use)
+ {
+ ct->deallocate();
+ }
+ }
+
+ // Search the output tensors of this node to see if
+ // there are now new ready nodes available from completing this node
+ for (TosaReference::Tensor* tensor : currNode->getOutputs())
+ {
+ for (GraphNode* node : tensor->getConsumers())
+ {
+ if (!node->getOnNextNodeList() && node->hasAllInputsReady())
+ {
+ addToNextNodeList(node);
+ }
+ }
+ }
+
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ dumpNextNodeList(g_func_debug.func_debug_file);
+ }
+
+ if (g_func_config.dump_intermediates)
+ {
+ currNode->dumpNode(g_func_debug.func_debug_file);
+ for (auto outs : currNode->getOutputs())
+ {
+ outs->dumpTensorParams(g_func_debug.func_debug_file);
+ outs->dumpTensor(g_func_debug.func_debug_file);
+ fprintf(g_func_debug.func_debug_file, "\n");
+ }
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::dumpNextNodeList(FILE* out) const
+{
+
+ // Dump next node list
+ fprintf(out, "Next node list\n");
+
+ if (nextNodeList.empty())
+ {
+ fprintf(out, "<empty>\n");
+ }
+
+ for (auto gn : nextNodeList)
+ {
+ gn->dumpNode(out);
+ }
+
+ fprintf(out, "Done.\n");
+ return 0;
+}
+
+int SubgraphTraverser::clearAllNodeMarkings()
+{
+ for (GraphNode* currNode : nodes)
+ {
+ currNode->clearNodeMarked();
+ }
+
+ return false;
+}
+
+int SubgraphTraverser::addTensor(TosaReference::Tensor* ct)
+{
+ // Enforce no duplicate tensors/tensor names
+ // O(N), but the number of tensors is small
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+ if (ct == currTensor || currTensor->getName() == ct->getName())
+ {
+ FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", ct->getName().c_str());
+ return 1;
+ }
+ }
+
+ tensors.push_back(ct);
+
+ if (ct->getIsSubgraphInput())
+ {
+ inputTensors.push_back(ct);
+ }
+
+ if (ct->getIsSubgraphOutput())
+ {
+ outputTensors.push_back(ct);
+ }
+
+ return 0;
+}
+int SubgraphTraverser::addNode(GraphNode* newNode)
+{
+ // Enforce no duplicate nodes
+ for (GraphNode* currNode : nodes)
+ {
+ if (currNode == newNode)
+ {
+ FATAL_ERROR("Error: duplicate node being added to graph");
+ return 1;
+ }
+ }
+
+ nodes.push_back(newNode);
+
+ return 0;
+}
+
+TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
+{
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+ if (currTensor->getName() == name)
+ {
+ return currTensor;
+ }
+ }
+
+ WARNING("Unable to find tensor with name: %s\n", name.c_str());
+
+ return nullptr;
+}
+
+int SubgraphTraverser::linkTensorsAndNodes()
+{
+ // Nodes have a list of input/output tensor names
+ // For each node, read this list, link up the tensors with their inputs/outputs
+ for (GraphNode* currNode : nodes)
+ {
+
+ // Link inputs/consuming nodes
+ for (std::string& name : currNode->getInputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t)
+ {
+ FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (currNode->addInputTensor(t))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (t->addConsumer(currNode))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link consumer node %lu to tensor %s\n", currNode->getID(),
+ name.c_str());
+ return 1;
+ }
+ }
+
+ // Link outputs/producing nodes
+ for (std::string& name : currNode->getOutputNames())
+ {
+ TosaReference::Tensor* t = findTensorByName(name);
+ if (!t)
+ {
+ FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (currNode->addOutputTensor(t))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
+ currNode->getID());
+ return 1;
+ }
+
+ if (t->setProducer(currNode))
+ {
+ FATAL_ERROR("linkTensorsAndNodes: cannot link producer node %lu to tensor tensor %s\n",
+ currNode->getID(), name.c_str());
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::validateGraph()
+{
+ // Need to make sure that:
+ // - each tensor is actually used
+ // - input and output tesnsors truly are just input and just output
+ // Graph building already determined that each node has found its input/output tensors
+
+ for (TosaReference::Tensor* currTensor : tensors)
+ {
+
+ if (!currTensor->getProducer() && currTensor->getConsumers().empty())
+ {
+ WARNING("Graph inconsistency: TosaReference::Tensor %s has no producers or consumers\n",
+ currTensor->getName().c_str());
+ return 1;
+ }
+
+ if (currTensor->getIsSubgraphInput())
+ {
+ if (currTensor->getProducer() && currTensor->getProducer()->getOp() != Op_PLACEHOLDER)
+ {
+ WARNING("Graph inconsistency: TosaReference::Tensor %s is a subgraph input and has a producer\n",
+ currTensor->getName().c_str());
+ return 1;
+ }
+ }
+
+ // comment this check out as this is possible when graph have multiple output
+ // for example:
+ // %0 = add(%arg0, %arg1)
+ // %1 = mul(%arg0, %0)
+ // yields(%0, %1)
+ //if (currTensor->getIsSubgraphOutput()) {
+ // if (!currTensor->getConsumers().empty()) {
+ // WARNING ("Graph inconsistency: TosaReference::Tensor %s is a subgraph output and has a consumer\n",
+ // currTensor->getName().c_str());
+ // return 1;
+ // }
+ //}
+
+ if (g_func_config.tosa_profile == 0)
+ {
+ DType dtype = currTensor->getDtype();
+
+ // Float-point disallowed
+ if (dtype == DType_FLOAT)
+ {
+ WARNING("TOSA Base Inference profile selected: All floating point disabled, but %s tensor %s found\n",
+ EnumNamesDType()[dtype], currTensor->getName().c_str());
+ return 1;
+ }
+ }
+ else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
+ {
+ // Do nothing. All FP types allowed
+ // Currently no implementation difference between Main Inference and Main Training modes
+ }
+ else
+ {
+ FATAL_ERROR("TOSA profile not recognized: %d", g_func_config.tosa_profile);
+ }
+ }
+
+ for (GraphNode* currNode : nodes)
+ {
+ if (currNode->checkTensorAttributes())
+ {
+ WARNING("TosaReference::Tensor attribute check failed");
+ return 1;
+ }
+ }
+
+ if (outputTensors.size() <= 0)
+ {
+ DEBUG_MED(GT, "Graph output tensor empty");
+ return 0;
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::dumpGraph(FILE* out) const
+{
+ int i = 0;
+
+ fprintf(out, "Full graph dump:\n");
+ for (GraphNode* currNode : nodes)
+ {
+ fprintf(out, "Node [%d]: ", i++);
+ currNode->dumpNode(out);
+ }
+
+ return 0;
+}
+
+int SubgraphTraverser::evaluateAll()
+{
+ // evaluation loop
+ while (!isFullyEvaluated())
+ {
+ if (evaluateNextNode())
+ {
+ return 1;
+ }
+ }
+
+ return 0;
+}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
new file mode 100644
index 0000000..3f4eecf
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.h
@@ -0,0 +1,90 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef SUBGRAPH_TRAVERSER_H
+#define SUBGRAPH_TRAVERSER_H
+
+#include "model_common.h"
+
+#include "graph_node.h"
+#include "ops/op_factory.h"
+#include "tosa_serialization_handler.h"
+
+namespace TosaReference
+{
+
+class SubgraphTraverser
+{
+public:
+ SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
+ ~SubgraphTraverser();
+
+ int initializeGraph();
+ int isFullyEvaluated() const;
+ int evaluateNextNode();
+ int evaluateAll();
+
+ int linkTensorsAndNodes();
+ int validateGraph();
+
+ int dumpGraph(FILE* out) const;
+ int dumpNextNodeList(FILE* out) const;
+ int clearAllNodeMarkings();
+
+ int getNumInputTensors() const;
+ Tensor* getInputTensor(const unsigned int idx) const;
+ Tensor* getInputTensorByName(const std::string name) const;
+ int getNumOutputTensors() const;
+ Tensor* getOutputTensor(const unsigned int idx) const;
+ Tensor* getOutputTensorByName(const std::string name) const;
+ int addToNextNodeList(GraphNode*);
+
+private:
+ int addTensor(Tensor* ct);
+ int addNode(GraphNode* cn);
+
+ Tensor* findTensorByName(const std::string& name) const;
+
+ GraphNode* getNextNode();
+
+ // pointer to serialization library and corresponding basic block
+ TosaSerializationBasicBlock* block;
+ TosaSerializationHandler* tsh;
+
+ // The definitive list of all tensors
+ std::vector<Tensor*> tensors;
+
+ // The subset of tensors that are also input tensors
+ std::vector<Tensor*> inputTensors;
+
+ // The subset of tensors that are also output tensors
+ std::vector<Tensor*> outputTensors;
+
+ // The definitive list of all nodes in the graph
+ std::vector<GraphNode*> nodes;
+
+ // The subset of node that have all of their input tensors ready, but
+ // have not yet been evaluated to produce their output tensors.
+ // With control flow, a node may appear on this list more than once during its
+ // lifetime, although the list itself should only contain unique nodes.
+ std::list<GraphNode*> nextNodeList;
+
+ // Maximum number of times to evalute a node before
+ // warning.
+ const int MAX_EVAL_COUNT = 10000;
+};
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
new file mode 100644
index 0000000..179484e
--- /dev/null
+++ b/reference_model/src/tensor.cc
@@ -0,0 +1,3008 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "tensor.h"
+#include "arith_util.h"
+
+using namespace TosaReference;
+using namespace Eigen;
+using namespace tosa;
+
+TosaReference::Tensor::Tensor(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_)
+{
+ tensorName = std::string(tensorName_);
+ tensorDtype = tensorDtype_;
+ tensorUsage = std::vector<Usage>(tensorUsage_);
+ tensorFormat = std::vector<Format>(tensorFormat_);
+ shape = std::vector<int>(shape_);
+ isConst = isConst_;
+ producer = nullptr;
+ isValid = false;
+ consumers.clear();
+ isSubgraphInput = false;
+ isSubgraphOutput = false;
+}
+
+TosaReference::Tensor::~Tensor()
+{}
+
+int TosaReference::Tensor::setIsSubgraphInput()
+{
+ isSubgraphInput = true;
+ return 0;
+}
+
+int TosaReference::Tensor::setIsSubgraphOutput()
+{
+ isSubgraphOutput = true;
+ return 0;
+}
+
+int TosaReference::Tensor::setProducer(GraphNode* node)
+{
+ ASSERT_MSG(node, "Tensor::setProducer: no node passed in");
+ ASSERT_MSG(!producer, "Tensor::setProducer: producer node already set, tensor %s", tensorName.c_str());
+ producer = node;
+
+ return 0;
+}
+
+int TosaReference::Tensor::addConsumer(GraphNode* node)
+{
+ ASSERT_MSG(node, "Tensor::addConsumer: no node passed in");
+ consumers.push_back(node);
+
+ return 0;
+}
+
+int TosaReference::Tensor::dumpTensorParams(FILE* out) const
+{
+ fprintf(out, "Name: %s DType=%s Usage=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(),
+ EnumNamesDType()[getDtype()], getUsageAsString().c_str(), getIsValid(), getRank(),
+ getShapeAsString().c_str());
+
+ return 0;
+}
+
+int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const
+{
+ out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " Usage=" << getUsageAsString()
+ << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n";
+
+ return 0;
+}
+
+int TosaReference::Tensor::readFromNpyFile(const char* filename)
+{
+ uint32_t elements = getElementCount();
+ float* fdatabuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
+ NumpyUtilities::NPError nperror;
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf);
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
+ ASSERT_MEM(i32databuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, i32databuf);
+ break;
+ case DType_INT48:
+ i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
+ ASSERT_MEM(i64databuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, i64databuf);
+ break;
+ case DType_BOOL:
+ bdatabuf = (bool*)calloc(sizeof(bool), elements);
+ ASSERT_MEM(bdatabuf);
+
+ nperror = NumpyUtilities::readFromNpyFile(filename, elements, bdatabuf);
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ switch (nperror)
+ {
+ case NumpyUtilities::NO_ERROR:
+ break;
+ case NumpyUtilities::FILE_NOT_FOUND:
+ FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename);
+ case NumpyUtilities::FILE_IO_ERROR:
+ FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename);
+ case NumpyUtilities::FILE_TYPE_MISMATCH:
+ FATAL_ERROR("readFromNpyFile: Tensor type %s and Numpy file type mismatch for tensor %s filename %s",
+ EnumNamesDType()[getDtype()], getName().c_str(), filename);
+ case NumpyUtilities::HEADER_PARSE_ERROR:
+ FATAL_ERROR("Numpy header parsing error for file: %s", filename);
+ case NumpyUtilities::BUFFER_SIZE_MISMATCH:
+ FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(),
+ filename);
+ default:
+ FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
+ }
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (setTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ if (setTensorValueInt32(elements, i32databuf))
+ {
+ free(i32databuf);
+ return 1;
+ }
+ break;
+ case DType_INT48:
+ if (setTensorValueInt64(elements, i64databuf))
+ {
+ free(i64databuf);
+ return 1;
+ }
+ break;
+ case DType_BOOL:
+ if (setTensorValueBool(elements, bdatabuf))
+ {
+ free(i32databuf);
+ return 1;
+ }
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ setIsValid();
+
+ if (fdatabuf)
+ free(fdatabuf);
+ if (i32databuf)
+ free(i32databuf);
+ if (i64databuf)
+ free(i64databuf);
+ if (bdatabuf)
+ free(bdatabuf);
+
+ return 0;
+}
+
+int TosaReference::Tensor::writeToNpyFile(const char* filename) const
+{
+ float* fdatabuf = nullptr;
+ int32_t* i32databuf = nullptr;
+ int64_t* i64databuf = nullptr;
+ bool* bdatabuf = nullptr;
+ NumpyUtilities::NPError nperror;
+ int elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ fdatabuf = (float*)calloc(sizeof(float), elements);
+ ASSERT_MEM(fdatabuf);
+
+ if (getTensorValueFloat(elements, fdatabuf))
+ {
+ free(fdatabuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf);
+
+ free(fdatabuf);
+ break;
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ i32databuf = (int32_t*)calloc(sizeof(int32_t), elements);
+ ASSERT_MEM(i32databuf);
+
+ if (getTensorValueInt32(elements, i32databuf))
+ {
+ free(i32databuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, i32databuf);
+
+ free(i32databuf);
+ break;
+ case DType_INT48:
+ i64databuf = (int64_t*)calloc(sizeof(int64_t), elements);
+ ASSERT_MEM(i64databuf);
+
+ if (getTensorValueInt64(elements, i64databuf))
+ {
+ free(i64databuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, i64databuf);
+
+ free(i64databuf);
+ break;
+ case DType_BOOL:
+ bdatabuf = (bool*)calloc(sizeof(bool), elements);
+ ASSERT_MEM(bdatabuf);
+
+ if (getTensorValueBool(elements, bdatabuf))
+ {
+ free(bdatabuf);
+ return 1;
+ }
+
+ nperror = NumpyUtilities::writeToNpyFile(filename, shape, bdatabuf);
+
+ free(bdatabuf);
+ break;
+ default:
+ FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]);
+ }
+
+ switch (nperror)
+ {
+ case NumpyUtilities::NO_ERROR:
+ break;
+ case NumpyUtilities::FILE_NOT_FOUND:
+ FATAL_ERROR("writeToNpyFile: Cannot open output file %s", filename);
+ case NumpyUtilities::FILE_IO_ERROR:
+ FATAL_ERROR("writeToNpyFile: IO error writing file: %s", filename);
+ case NumpyUtilities::FILE_TYPE_MISMATCH:
+ FATAL_ERROR("writeToNpyFile: Tensor type and Numpy file type mismatch for tensor %s filename %s",
+ getName().c_str(), filename);
+ case NumpyUtilities::HEADER_PARSE_ERROR:
+ FATAL_ERROR("Numpy header parsing error for file: %s", filename);
+ case NumpyUtilities::BUFFER_SIZE_MISMATCH:
+ FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(),
+ filename);
+ default:
+ FATAL_ERROR("Unknown error writing Numpy file: %s", filename);
+ }
+
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::copyValueFrom(TosaReference::Tensor* src)
+{
+ FATAL_ERROR("TensorTemplate<T>::copyValueFrom should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+#define DEF_CTENSOR_COPY_VALUE_FROM(RANK, TYPE) \
+ template <> \
+ int TosaReference::Tensor##RANK<TYPE>::copyValueFrom(TosaReference::Tensor* src) \
+ { \
+ TosaReference::Tensor##RANK<TYPE>* t = dynamic_cast<Tensor##RANK<TYPE>*>(src); \
+ if (!t) \
+ { \
+ WARNING("tensor %s templated class does not match %s", src->getName().c_str(), this->getName().c_str()); \
+ return 1; \
+ } \
+ \
+ uint32_t src_rank = src->getRank(); \
+ uint32_t dst_rank = this->getRank(); \
+ DType src_dtype = src->getDtype(); \
+ DType dst_dtype = this->getDtype(); \
+ bool tensor_match = true; \
+ \
+ if ((src_rank != dst_rank) || (src_dtype != dst_dtype)) \
+ { \
+ tensor_match = false; \
+ } \
+ else \
+ { \
+ for (uint32_t i = 0; i < src_rank; i++) \
+ { \
+ int src_dim = src->getShape()[i]; \
+ int dst_dim = this->getShape()[i]; \
+ if (src_dim != dst_dim) \
+ { \
+ tensor_match = false; \
+ } \
+ } \
+ } \
+ \
+ if (!tensor_match) \
+ { \
+ WARNING("source tensor %s (rank=%u, dtype=%s, shape=%s) doesn't match destination tensor %s (rank=%u, " \
+ "dtype=%s, shape=%s)", \
+ src->getName().c_str(), src_rank, EnumNamesDType()[src_dtype], src->getShapeAsString().c_str(), \
+ this->getName().c_str(), dst_rank, EnumNamesDType()[dst_dtype], this->getShapeAsString().c_str()); \
+ return 1; \
+ } \
+ \
+ this->getTensor() = t->getTensor(); \
+ return 0; \
+ }
+
+DEF_CTENSOR_COPY_VALUE_FROM(0, float)
+DEF_CTENSOR_COPY_VALUE_FROM(1, float)
+DEF_CTENSOR_COPY_VALUE_FROM(2, float)
+DEF_CTENSOR_COPY_VALUE_FROM(3, float)
+DEF_CTENSOR_COPY_VALUE_FROM(4, float)
+DEF_CTENSOR_COPY_VALUE_FROM(5, float)
+DEF_CTENSOR_COPY_VALUE_FROM(6, float)
+DEF_CTENSOR_COPY_VALUE_FROM(0, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(1, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(2, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(3, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(4, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(5, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(6, int32_t)
+DEF_CTENSOR_COPY_VALUE_FROM(0, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(1, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(2, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(3, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(4, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(5, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(6, int64_t)
+DEF_CTENSOR_COPY_VALUE_FROM(0, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(1, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(2, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(3, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(4, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(5, bool)
+DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
+
+#undef DEF_CTENSOR_COPY_VALUE_FROM
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueInt32 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueInt64 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::setTensorValueBool(const size_t buflen, const bool* vals)
+{
+ FATAL_ERROR("TensorTemplate<T>::setTensorValueBool should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ (*tensor)(0) = vals[0];
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ (*tensor)(i0) = vals[idx++];
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ (*tensor)(i0, i1) = vals[idx++];
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ (*tensor)(i0, i1, i2) = vals[idx++];
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ (*tensor)(i0, i1, i2, i3) = vals[idx++];
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals)
+{
+ uint32_t idx = 0;
+
+ ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueFloat should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueInt32 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueInt64 should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ FATAL_ERROR("TensorTemplate<T>::getTensorValueBool should not be called. "
+ "Implement template specialization version.");
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ int totalVals = 1;
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ vals[0] = (*tensor)(0);
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ vals[idx++] = (*tensor)(i0);
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ vals[idx++] = (*tensor)(i0, i1);
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2);
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3);
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4);
+ }
+ }
+ }
+ }
+ }
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const
+{
+ uint32_t idx = 0;
+ int totalVals = 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ totalVals *= shape[i];
+ }
+
+ ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match");
+
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5);
+ }
+ }
+ }
+ }
+ }
+ }
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<float>();
+
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<float>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<float>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<float>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<float>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<float>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<float>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<float>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<int32_t>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<int32_t>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<int32_t>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<int32_t>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<int32_t>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<int64_t>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<int64_t>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<int64_t>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<int64_t>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<int64_t>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor0<bool>();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor1<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor1<bool>(shape[0]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+template <>
+int TosaReference::Tensor2<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor2<bool>(shape[0], shape[1]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor3<bool>(shape[0], shape[1], shape[2]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor4<bool>(shape[0], shape[1], shape[2], shape[3]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor5<bool>(shape[0], shape[1], shape[2], shape[3], shape[4]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::allocate()
+{
+ ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
+ tensor = new ETensor6<bool>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
+ if (tensor)
+ return 0;
+ else
+ return 1;
+}
+
+template <>
+int TosaReference::Tensor0<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, "[ %%%sf ]\n", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, fp_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<float>::dumpTensor(FILE* out) const
+{
+ char fp_fmt[FOF_STR_LEN];
+ snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, "[ %%ld ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, i64_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int64_t>::dumpTensor(FILE* out) const
+{
+ char i64_fmt[FOF_STR_LEN];
+ snprintf(i64_fmt, FOF_STR_LEN, " %%ld ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, "[ %%d ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, i32_fmt, (*tensor)(0));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<int32_t>::dumpTensor(FILE* out) const
+{
+ char i32_fmt[FOF_STR_LEN];
+ snprintf(i32_fmt, FOF_STR_LEN, " %%d ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4, i5));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor0<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, "[ %%s ]\n");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(0)));
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor1<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0)));
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor2<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor3<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor4<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor5<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <>
+int TosaReference::Tensor6<bool>::dumpTensor(FILE* out) const
+{
+ char bool_fmt[FOF_STR_LEN];
+ snprintf(bool_fmt, FOF_STR_LEN, " %%s ");
+
+ if (tensor == nullptr)
+ {
+ fprintf(out, "<Not allocated>\n");
+ return 0;
+ }
+
+ fprintf(out, "[");
+ for (int i0 = 0; i0 < shape[0]; i0++)
+ {
+ fprintf(out, "[");
+ for (int i1 = 0; i1 < shape[1]; i1++)
+ {
+ fprintf(out, "[");
+ for (int i2 = 0; i2 < shape[2]; i2++)
+ {
+ fprintf(out, "[");
+ for (int i3 = 0; i3 < shape[3]; i3++)
+ {
+ fprintf(out, "[");
+ for (int i4 = 0; i4 < shape[4]; i4++)
+ {
+ fprintf(out, "[");
+ for (int i5 = 0; i5 < shape[5]; i5++)
+ {
+ fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4, i5)));
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+ }
+ fprintf(out, "]\n");
+
+ return 0;
+}
+
+template <class T>
+int TosaReference::TensorTemplate<T>::dumpTensor(FILE* out) const
+{
+ return 0;
+}
+
+// template explicit specialization
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<float, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 6>>;
+
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 0>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 1>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 2>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 3>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 4>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 5>>;
+template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 6>>;
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
new file mode 100644
index 0000000..2fd37cd
--- /dev/null
+++ b/reference_model/src/tensor.h
@@ -0,0 +1,815 @@
+
+// Copyright (c) 2020, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef TOSA_REFERENCE_TENSOR_H
+#define TOSA_REFERENCE_TENSOR_H
+
+#include "model_common.h"
+#include "ops/template_types.h"
+#include "tosa_generated.h"
+#include "tosa_serialization_handler.h"
+#include <Eigen/CXX11/Tensor>
+#include <list>
+#include <vector>
+
+using namespace tosa;
+
+namespace TosaReference
+{
+class GraphNode;
+
+class Tensor
+{
+public:
+ Tensor(std::string tensorName_,
+ DType tensorDtype__,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_);
+
+ virtual ~Tensor();
+
+ int setIsSubgraphInput();
+ int setIsSubgraphOutput();
+
+ int getIsSubgraphInput() const
+ {
+ return isSubgraphInput;
+ }
+
+ int getIsSubgraphOutput() const
+ {
+ return isSubgraphOutput;
+ }
+
+ int setProducer(GraphNode* node);
+ int addConsumer(GraphNode* node);
+
+ int setIsValid()
+ {
+ isValid = 1;
+ return 0;
+ }
+
+ int clearIsValid()
+ {
+ isValid = 0;
+ return 0;
+ }
+
+ int getIsValid() const
+ {
+ return isValid;
+ }
+
+ int getIsConst() const
+ {
+ return isConst;
+ }
+
+ GraphNode* getProducer()
+ {
+ return producer;
+ }
+
+ std::vector<GraphNode*>& getConsumers()
+ {
+ return consumers;
+ }
+
+ const std::string& getName() const
+ {
+ return tensorName;
+ }
+
+ const std::vector<int>& getShape() const
+ {
+ return shape;
+ }
+
+ std::string getShapeAsString() const
+ {
+ std::string shape_str("[");
+ for (auto& dim : shape)
+ {
+ shape_str += (std::to_string(dim) + ", ");
+ }
+ shape_str.append("]");
+ return shape_str;
+ }
+
+ const std::vector<Usage>& getUsage() const
+ {
+ return tensorUsage;
+ }
+
+ bool hasUsage(Usage usage) const
+ {
+ for (auto& usg : tensorUsage)
+ {
+ if (usg == usage)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getUsageAsString() const
+ {
+ std::string usage_str("[");
+ for (auto& usg : tensorUsage)
+ {
+ usage_str += (std::string(EnumNamesUsage()[usg]) + ", ");
+ }
+ usage_str.append("]");
+ return usage_str;
+ }
+
+ const std::vector<Format>& getFormat() const
+ {
+ return tensorFormat;
+ }
+
+ bool hasFormat(Format format) const
+ {
+ for (auto& fmt : tensorFormat)
+ {
+ if (fmt == format)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ std::string getFormatAsString() const
+ {
+ std::string format_str("[");
+ for (auto& fmt : tensorFormat)
+ {
+ format_str += (std::string(EnumNamesFormat()[fmt]) + ", ");
+ }
+ format_str.append("]");
+ return format_str;
+ }
+
+ const uint32_t getElementCount() const
+ {
+ uint32_t elements = 1;
+ for (size_t i = 0; i < shape.size(); i++)
+ elements *= shape[i];
+
+ return elements;
+ }
+
+ // Comparison of rank and type with other tensors
+ const int matchRank(const Tensor& ref) const
+ {
+ return (ref.shape.size() == shape.size()) ? 0 : 1;
+ }
+
+ const int matchType(const Tensor& ref) const
+ {
+ return (ref.tensorDtype == tensorDtype) ? 0 : 1;
+ }
+
+ const int matchRankType(const Tensor& ref) const
+ {
+ return (matchType(ref) || matchRank(ref));
+ }
+
+ const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const
+ {
+ if (matchRankType(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ {
+ if (!broadcastOk ||
+ // For broadcasts, at least one operand must have size 1
+ // if they don't both match
+ (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1)))
+ {
+ return 1;
+ }
+ }
+ }
+
+ return 0;
+ }
+
+ // Sometimes we might want to match several semi-compatible types,
+ // so just check rank and size here
+ const int matchRankSize(const Tensor& ref) const
+ {
+ if (matchRank(ref))
+ return 1;
+
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ if (shape[i] != ref.shape[i])
+ return 1;
+ }
+
+ return 0;
+ }
+
+ // Unary check to make sure rank matches
+ const int checkRequiredRank(const int exactRank) const
+ {
+ return (shape.size() == (size_t)exactRank) ? 0 : 1;
+ }
+
+ const int checkRequiredRank(const int minRank, const int maxRank) const
+ {
+ return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 0 : 1;
+ }
+
+ const int getRank() const
+ {
+ return shape.size();
+ }
+
+ const DType getDtype() const
+ {
+ return tensorDtype;
+ }
+
+ virtual int dumpTensor(FILE* out) const = 0;
+ virtual int dumpTensorParams(FILE* out) const;
+ virtual int dumpTensorParams(std::ostream& out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0;
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0;
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0;
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0;
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0;
+ virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0;
+
+ virtual int readFromNpyFile(const char* filename);
+ virtual int writeToNpyFile(const char* filename) const;
+ virtual int copyValueFrom(Tensor* tensor) = 0;
+
+ const char* bool_to_str(bool in) const
+ {
+ static const char* true_str = "true";
+ static const char* false_str = "false";
+ return in ? true_str : false_str;
+ }
+
+ virtual int allocate() = 0;
+ virtual int deallocate() = 0;
+ virtual bool is_allocated() = 0;
+
+protected:
+ std::string tensorName;
+ DType tensorDtype;
+ std::vector<Usage> tensorUsage;
+ std::vector<Format> tensorFormat;
+ int isConst;
+ int isValid;
+ std::vector<int> shape;
+ int isSubgraphInput;
+ int isSubgraphOutput;
+ bool isAllocated;
+
+ GraphNode* producer;
+ std::vector<GraphNode*> consumers;
+
+ // Note: the Eigen::Tensor is not declared in Tensor
+ // Instead, the TensorTemplate class keeps the templated tensor
+ // declaration so that the graph manipulation tools are isolated
+ // from the templated tensor type.
+ //
+ // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type
+ // so that they can operate on the right types.
+};
+
+template <class T>
+class TensorTemplate : public Tensor
+{
+public:
+ TensorTemplate(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_)
+ : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_)
+ {
+ tensor = nullptr;
+ }
+
+ virtual ~TensorTemplate()
+ {
+ deallocate();
+ }
+
+ virtual int allocate()
+ {
+ tensor = new T();
+ if (tensor)
+ return 0;
+ else
+ return 1;
+ }
+
+ virtual int deallocate()
+ {
+ if (tensor)
+ {
+ delete tensor;
+ }
+ tensor = nullptr;
+ return 0;
+ }
+
+ virtual bool is_allocated()
+ {
+ if (tensor)
+ {
+ return true;
+ }
+ return false;
+ }
+
+ T& getTensor()
+ {
+ return *tensor;
+ }
+
+ virtual int dumpTensor(FILE* out) const;
+
+ virtual int setTensorValueFloat(const size_t bufLen, const float* vals);
+ virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+ virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+ virtual int setTensorValueBool(const size_t bufLen, const bool* vals);
+ virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const;
+ virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const;
+ virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const;
+ virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const;
+
+ virtual int copyValueFrom(Tensor* tensor);
+
+protected:
+ T* tensor;
+};
+
+// allocate() template specializations to allocate the different tensor sizes
+// Let the compiler know here before the factory uses them, but define them in the .cc file.
+template <>
+int Tensor0<float>::allocate();
+template <>
+int Tensor1<float>::allocate();
+template <>
+int Tensor2<float>::allocate();
+template <>
+int Tensor3<float>::allocate();
+template <>
+int Tensor4<float>::allocate();
+template <>
+int Tensor5<float>::allocate();
+template <>
+int Tensor6<float>::allocate();
+
+template <>
+int Tensor0<int32_t>::allocate();
+template <>
+int Tensor1<int32_t>::allocate();
+template <>
+int Tensor2<int32_t>::allocate();
+template <>
+int Tensor3<int32_t>::allocate();
+template <>
+int Tensor4<int32_t>::allocate();
+template <>
+int Tensor5<int32_t>::allocate();
+template <>
+int Tensor6<int32_t>::allocate();
+
+template <>
+int Tensor0<int64_t>::allocate();
+template <>
+int Tensor1<int64_t>::allocate();
+template <>
+int Tensor2<int64_t>::allocate();
+template <>
+int Tensor3<int64_t>::allocate();
+template <>
+int Tensor4<int64_t>::allocate();
+template <>
+int Tensor5<int64_t>::allocate();
+template <>
+int Tensor6<int64_t>::allocate();
+
+template <>
+int Tensor0<bool>::allocate();
+template <>
+int Tensor1<bool>::allocate();
+template <>
+int Tensor2<bool>::allocate();
+template <>
+int Tensor3<bool>::allocate();
+template <>
+int Tensor4<bool>::allocate();
+template <>
+int Tensor5<bool>::allocate();
+template <>
+int Tensor6<bool>::allocate();
+
+template <>
+int Tensor0<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<float>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<float>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int32_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int32_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<int64_t>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<int64_t>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor1<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor2<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor3<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor4<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor5<bool>::copyValueFrom(Tensor* src);
+template <>
+int Tensor6<bool>::copyValueFrom(Tensor* src);
+
+template <>
+int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+template <>
+int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
+
+template <>
+int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+template <>
+int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
+
+template <>
+int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+template <>
+int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
+
+template <>
+int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+template <>
+int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
+
+template <>
+int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+template <>
+int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
+
+template <>
+int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+template <>
+int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
+
+template <>
+int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+template <>
+int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
+
+template <>
+int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+template <>
+int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
+
+// assume we only dump float type tensor now
+template <>
+int Tensor0<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<float>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int32_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<int64_t>::dumpTensor(FILE* out) const;
+template <>
+int Tensor0<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor1<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor2<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor3<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor4<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor5<bool>::dumpTensor(FILE* out) const;
+template <>
+int Tensor6<bool>::dumpTensor(FILE* out) const;
+
+class TensorFactory
+{
+public:
+ static Tensor* newTensor(std::string tensorName_,
+ DType tensorDtype_,
+ const std::vector<Usage>& tensorUsage_,
+ const std::vector<Format>& tensorFormat_,
+ std::vector<int> shape_,
+ int isConst_,
+ const uint32_t rank)
+ {
+ switch (tensorDtype_)
+ {
+ case DType_FLOAT:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT32:
+ case DType_AINT8:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_INT48:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ case DType_BOOL:
+ switch (rank)
+ {
+ case 0:
+ return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 1:
+ return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 2:
+ return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 3:
+ return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 4:
+ return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 5:
+ return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ case 6:
+ return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_,
+ isConst_);
+ default:
+ goto done;
+ }
+ default:
+ goto done;
+ }
+
+ done:
+ FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_],
+ rank);
+ }
+
+ static Tensor* newTensor(DType type, const std::vector<int> shape);
+};
+}; // namespace TosaReference
+
+#endif