diff options
Diffstat (limited to 'reference_model/src')
48 files changed, 14272 insertions, 0 deletions
diff --git a/reference_model/src/arith_util.h b/reference_model/src/arith_util.h new file mode 100644 index 0000000..554a7a2 --- /dev/null +++ b/reference_model/src/arith_util.h @@ -0,0 +1,194 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Filename: src/arith_util.h + * Description: + * arithmetic utility macro, include: + * fp16 (float16_t ) type alias + * bitwise operation + * fix point arithmetic + * fp16 type conversion(in binary translation) + * fp16 arithmetic (disguised with fp32 now) + */ + +#ifndef ARITH_UTIL_H +#define ARITH_UTIL_H + +#include <fenv.h> +#include <math.h> +#define __STDC_LIMIT_MACROS //enable min/max of plain data type +#include "func_debug.h" +#include "inttypes.h" +#include <cassert> +#include <iostream> +#include <limits> +#include <stdint.h> +#include <typeinfo> + +using namespace std; + +inline size_t _count_one(uint64_t val) +{ + size_t count = 0; + for (; val; count++) + { + val &= val - 1; + } + return count; +} + +template <typename T> +inline size_t _integer_log2(T val) +{ + size_t result = 0; + while (val >>= 1) + { + ++result; + } + return result; +} + +template <typename T> +inline size_t _count_leading_zeros(T val) +{ + size_t size = sizeof(T) * 8; + size_t count = 0; + T msb = static_cast<T>(1) << (size - 1); + for (size_t i = 0; i < size; i++) + { + if (!((val << i) & msb)) + count++; + else + break; + } + return count; +} + +template <typename T> +inline size_t 
_count_leading_ones(T val) +{ + size_t size = sizeof(T) * 8; + size_t count = 0; + T msb = static_cast<T>(1) << (size - 1); + for (size_t i = 0; i < size; i++) + { + if ((val << i) & msb) + count++; + else + break; + } + return count; +} + +#define MAX(a, b) ((a) > (b) ? (a) : (b)) +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +// Compute ceiling of (a/b) +#define DIV_CEIL(a, b) ((a) % (b) ? ((a) / (b) + 1) : ((a) / (b))) + +// Returns a mask of 1's of this size +#define ONES_MASK(SIZE) ((uint64_t)((SIZE) >= 64 ? 0xffffffffffffffffULL : ((uint64_t)(1) << (SIZE)) - 1)) + +// Returns a field of bits from HIGH_BIT to LOW_BIT, right-shifted +// include both side, equivalent VAL[LOW_BIT:HIGH_BIT] in verilog + +#define BIT_FIELD(HIGH_BIT, LOW_BIT, VAL) (((uint64_t)(VAL) >> (LOW_BIT)) & ONES_MASK((HIGH_BIT) + 1 - (LOW_BIT))) + +// Returns a bit at a particular position +#define BIT_EXTRACT(POS, VAL) (((uint64_t)(VAL) >> (POS)) & (1)) + +// Use Brian Kernigahan's way: https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetKernighan +// Does this need to support floating point type? +// Not sure if static_cast is the right thing to do, try to be type safe first +#define ONES_COUNT(VAL) (_count_one((uint64_t)(VAL))) + +#define SHIFT(SHF, VAL) (((SHF) > 0) ? ((VAL) << (SHF)) : ((SHF < 0) ? ((VAL) >> (-(SHF))) : (VAL))) +#define ROUNDTO(A, B) ((A) % (B) == 0 ? (A) : ((A) / (B) + 1) * (B)) +#define ROUNDTOLOWER(A, B) (((A) / (B)) * (B)) +#define BIDIRECTIONAL_SHIFT(VAL, SHIFT) (((SHIFT) >= 0) ? 
((VAL) << (SHIFT)) : ((VAL) >> (-(SHIFT)))) +#define ILOG2(VAL) (_integer_log2(VAL)) + +// Get negative value (2's complement) +#define NEGATIVE_8(VAL) ((uint8_t)(~(VAL) + 1)) +#define NEGATIVE_16(VAL) ((uint16_t)(~(VAL) + 1)) +#define NEGATIVE_32(VAL) ((uint32_t)(~(VAL) + 1)) +#define NEGATIVE_64(VAL) ((uint64_t)(~(VAL) + 1)) +// Convert a bit quanity to the minimum bytes required to hold those bits +#define BITS_TO_BYTES(BITS) (ROUNDTO((BITS), 8) / 8) + +// Count leading zeros/ones for 8/16/32/64-bit operands +// (I don't see an obvious way to collapse this into a size-independent set) +// treated as unsigned +#define LEADING_ZEROS_64(VAL) (_count_leading_zeros((uint64_t)(VAL))) +#define LEADING_ZEROS_32(VAL) (_count_leading_zeros((uint32_t)(VAL))) +#define LEADING_ZEROS_16(VAL) (_count_leading_zeros((uint16_t)(VAL))) +#define LEADING_ZEROS_8(VAL) (_count_leading_zeros((uint8_t)(VAL))) +#define LEADING_ZEROS(VAL) (_count_leading_zeros(VAL)) + +#define LEADING_ONES_64(VAL) _count_leading_ones((uint64_t)(VAL)) +#define LEADING_ONES_32(VAL) _count_leading_ones((uint32_t)(VAL)) +#define LEADING_ONES_16(VAL) _count_leading_ones((uint16_t)(VAL)) +#define LEADING_ONES_8(VAL) _count_leading_ones((uint8_t)(VAL)) +#define LEADING_ONES(VAL) _count_leading_ones(VAL) +// math operation +// sign-extended for signed version +// extend different return type (8, 16, 32) + (S, U) +// Saturate a value at a certain bitwidth, signed and unsigned versions +// Format is as followed: SATURATE_VAL_{saturation_sign}_{return_type} +// for example +// SATURATE_VAL_U_8U(8,300) will return uint8_t with value of 255(0xff) +// SATURATE_VAL_S_32S(5,-48) will return int32_t with value of -16(0x10) +// note that negative value can cast to unsigned return type using native uint(int) cast +// so SATURATE_VAL_S_8U(5,-40) will have value 0'b1110000 which is in turn 224 in uint8_t + +template <typename T> +constexpr T bitmask(const uint32_t width) +{ + ASSERT(width <= sizeof(T) * 8); + return width == 
sizeof(T) * 8 ? static_cast<T>(std::numeric_limits<uintmax_t>::max()) + : (static_cast<T>(1) << width) - 1; +} + +template <typename T> +constexpr T minval(const uint32_t width) +{ + ASSERT(width <= sizeof(T) * 8); + return std::is_signed<T>::value ? -(static_cast<T>(1) << (width - 1)) : 0; +} + +template <typename T> +constexpr T maxval(const uint32_t width) +{ + ASSERT(width <= sizeof(T) * 8); + return bitmask<T>(width - std::is_signed<T>::value); +} + +template <typename T> +constexpr T saturate(const uint32_t width, const intmax_t value) +{ + // clang-format off + return static_cast<T>( + std::min( + std::max( + value, + static_cast<intmax_t>(minval<T>(width)) + ), + static_cast<intmax_t>(maxval<T>(width)) + ) + ); + // clang-format on +} + +#endif /* _ARITH_UTIL_H */ diff --git a/reference_model/src/debug_modes.def b/reference_model/src/debug_modes.def new file mode 100644 index 0000000..51b151d --- /dev/null +++ b/reference_model/src/debug_modes.def @@ -0,0 +1,20 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Defines the debugging printing modes + +DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization +DEBUG_MODE(GT,1) // Graph traverser +DEBUG_MODE(OP,2) // Operation diff --git a/reference_model/src/debug_types.h b/reference_model/src/debug_types.h new file mode 100644 index 0000000..bd93f19 --- /dev/null +++ b/reference_model/src/debug_types.h @@ -0,0 +1,57 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Filename: src/debug_types.h + * Description: + * Defines fundamental debugger datatypes for the functional model + */ + +#ifndef DEBUG_TYPES_H_ +#define DEBUG_TYPES_H_ + +#ifdef __cplusplus +extern "C" +{ +#endif + + // Debug verbosity mask + typedef enum func_debug_verbosity_e + { + DEBUG_VERB_NONE = 0x00, + DEBUG_VERB_INFO = 0x01, // Informational debugging messages + DEBUG_VERB_IFACE = 0x02, // Interface debugging support + DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout + DEBUG_VERB_MED = 0x08, + DEBUG_VERB_HIGH = 0x10 + } func_debug_verbosity_e; + + // Generated debug modes enumeration + typedef enum func_debug_mode_e + { + DEBUG_NONE = 0x0, +#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT), +#include "debug_modes.def" +#undef DEBUG_MODE + DEBUG_ALL = 0xffffffffffffffffUL + } func_debug_mode_e; + +#define DEBUG_INST_ALL 0xffffffffffffffffUL + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/reference_model/src/func_config.cc 
b/reference_model/src/func_config.cc new file mode 100644 index 0000000..bd1ce32 --- /dev/null +++ b/reference_model/src/func_config.cc @@ -0,0 +1,632 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <ctype.h> +#include <signal.h> +#include <stdarg.h> +#include <stdint.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> + +#include "func_config.h" +#include "func_debug.h" + +#define MAX_NAME_LEN 128 +#define MAX_DESC_LEN 128 + +#ifndef ARG_ERROR +#define ARG_ERROR(...) 
\ + fprintf(stderr, "ERROR: "); \ + fprintf(stderr, __VA_ARGS__); \ + fprintf(stderr, "\n"); \ + return 1; +#endif + +// Parameter base name string table +const char* config_base_name_table[] = { +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #NAME, +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #NAME, +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #NAME, +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #NAME, +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR +#undef DEF_UNIT_OPTION +}; + +// Parameter description table +const char* config_param_desc_table[] = { +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) #DESC, +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) #DESC, +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) #DESC, +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) #DESC, +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR +}; + +// String table and enum for the option hierarchy level/sub-levels +// (no leaf options). Attribute at the top level have "BASE" as their +// enum value and an empty string for the value. 
+const char* config_hier_str_table[] = { + "", +#define DEF_UNIT_START(UNIT) #UNIT, +#define DEF_UNIT_END(UNIT) /**/ +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/ +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/ +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/ +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/ +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR +}; + +typedef enum config_hier_enum_t +{ + BASE, +#define DEF_UNIT_START(UNIT) CURRENT_UNIT, +#define DEF_UNIT_END(UNIT) /**/ +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) /**/ +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) /**/ +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) /**/ +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) /**/ +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + + MAX_CONFIG_HIER +} config_hier_enum_t; + +// Mapping from a leaf parameter index to the +// position in the hierarchy. 
+config_hier_enum_t config_hierarchy_map[] = { +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) BASE, +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) BASE, +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) CURRENT_UNIT, +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) CURRENT_UNIT, +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR +}; + +#define CONFIG_PARAMETER_COUNT (sizeof(config_hierarchy_map) / sizeof(config_hier_enum_t)) + +// Dynamically generated at initialization +char** config_param_str_table = nullptr; + +// Initialize the configuration data structures +int func_model_init_config() +{ + // Initialize string table (builds the hierarchical names) + config_param_str_table = (char**)calloc(CONFIG_PARAMETER_COUNT, sizeof(char*)); + ASSERT_MEM(config_param_str_table); + + for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++) + { + size_t len = strlen(config_base_name_table[i]) + 1; + if (config_hierarchy_map[i] != BASE) + { + ASSERT_MSG(config_hierarchy_map[i] <= MAX_CONFIG_HIER, + "Configuration parameter\'s hierarchy is out of bounds"); + len += strlen(config_hier_str_table[config_hierarchy_map[i]]) + 1; + } + config_param_str_table[i] = (char*)calloc(len, 1); + ASSERT_MEM(config_param_str_table[i]); + ASSERT_MSG(len < MAX_NAME_LEN, "option expanded name is too long: %s", config_base_name_table[i]); + + if (config_hierarchy_map[i] != BASE) + { + snprintf(config_param_str_table[i], len, "%s.%s", config_hier_str_table[config_hierarchy_map[i]], + config_base_name_table[i]); + } + else + { + snprintf(config_param_str_table[i], len, "%s", config_base_name_table[i]); + } + } + + return 0; +} + +int func_model_set_default_config(func_config_t* func_config) +{ + // Set default values in the global configuration data structure + bzero(func_config, 
sizeof(*func_config)); + +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) func_config->NAME = (DEFAULT); +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) strncpy(func_config->NAME, (DEFAULT), (LEN)-1); +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) func_config->UNIT.NAME = (DEFAULT); +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) strncpy(func_config->UNIT.NAME, (DEFAULT), (LEN)-1); +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + + return 0; +} + +int func_model_config_cleanup() +{ + uint32_t i; + + if (!config_param_str_table) + return 1; + + for (i = 0; i < CONFIG_PARAMETER_COUNT; i++) + { + free(config_param_str_table[i]); + } + + free(config_param_str_table); + config_param_str_table = nullptr; + + return 0; +} + +int func_model_config_set_option(func_config_t* func_config, const char* name, const char* value) +{ + // Increment an index variable on each parameter position + // so that we can index both the position struct through the macro and the + // array of parameter names through a simple array of strings. + int param_idx = 0; + char* endptr; + + // TODO: does not handle strings yet. 
Can set magic values on FMT to + // choose a string copy vs strtoull +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + func_config->NAME = (uint64_t)strtoll(value, &endptr, 0); \ + if (endptr == value) \ + { \ + ARG_ERROR("Cannot parse option: %s = %s", name, value); \ + } \ + return 0; \ + } \ + param_idx++; + +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + if (strlen(value) >= LEN) \ + { \ + ARG_ERROR("Option value is too long: %s = %s", name, value); \ + } \ + strncpy(func_config->NAME, value, (LEN)-1); \ + return 0; \ + } \ + param_idx++; + +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + func_config->UNIT.NAME = (uint64_t)strtoll(value, &endptr, 0); \ + if (endptr == value) \ + { \ + ARG_ERROR("Cannot parse option: %s = %s", name, value); \ + } \ + return 0; \ + } \ + param_idx++; + +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + if (strlen(value) >= LEN) \ + { \ + ARG_ERROR("Option value is too long: %s = %s", name, value); \ + } \ + strncpy(func_config->UNIT.NAME, value, (LEN)-1); \ + return 0; \ + } \ + param_idx++; + +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + + // No match! + ARG_ERROR("Cannot find option: %s", name); + + return 1; +} + +int func_model_config_get_option_by_name(func_config_t* func_config, const char* name, uint64_t* val) +{ + // Increment an index variable on each parameter position + // so that we can index both the position struct through the macro and the + // array of parameter names through a simple array of strings. 
+ int param_idx = 0; + +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) + +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) param_idx++; + +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, FMT, DEFAULT) param_idx++; + +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + *val = func_config->NAME; \ + return 0; \ + } \ + param_idx++; + +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + *val = func_config->UNIT.NAME; \ + return 0; \ + } \ + param_idx++; + +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + // No match! + return 1; +} +int func_model_config_get_str_option_by_name(func_config_t* func_config, + const char* name, + char* value, + const uint32_t len) +{ + // Increment an index variable on each parameter position + // so that we can index both the position struct through the macro and the + // array of parameter names through a simple array of strings. + int param_idx = 0; + +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + strncpy(value, func_config->NAME, len - 1); \ + return 0; \ + } \ + param_idx++; + +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \ + if (!strcmp(config_param_str_table[param_idx], name)) \ + { \ + strncpy(value, func_config->UNIT.NAME, len - 1); \ + return 0; \ + } \ + param_idx++; + +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) param_idx++; + +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) param_idx++; + +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + // No match! 
+ return 1; +} + +int func_config_print_config_help(FILE* out) +{ + fprintf(out, "%-40s %s\n", "Option", "Description"); + fprintf(out, "%-40s %s\n", "------", "-----------"); + + for (uint32_t i = 0; i < CONFIG_PARAMETER_COUNT; i++) + { + fprintf(out, "-C%-40s %s\n", config_param_str_table[i], config_param_desc_table[i]); + } + + fprintf(out, "\n"); + + return 0; +} + +int func_model_print_config(func_config_t* func_config, FILE* out) +{ +#define DEF_UNIT_START(UNIT) +#define DEF_UNIT_END(UNIT) +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) fprintf(out, "%-40s = " FMT "\n", #NAME, func_config->NAME); +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) \ + fprintf(out, "%-40s = " FMT "\n", #UNIT "." #NAME, func_config->UNIT.NAME); +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) fprintf(out, "%-40s = %s\n", #NAME, func_config->NAME); +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) \ + fprintf(out, "%-40s = %s\n", #UNIT "." #NAME, func_config->UNIT.NAME); + +#define FOF_HEX "0x%llx" +#define FOF_DEC "%" PRIu32 +#define FOF_DECU64 "%" PRIu64 + +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_UNIT_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION_STR + + return 0; +} + +static const char* programname; + +void func_model_print_debug_masks(FILE* out) +{ + fprintf(out, "\t List of components:\n"); +#define DEBUG_MODE(string, value) fprintf(out, "\t\t" #string "\n"); +#include "debug_modes.def" +#undef DEBUG_MODE +} + +int func_model_print_help(FILE* out) +{ + fprintf(out, "TOSA Reference Model help\n\n"); + + fprintf(out, + "Usage: %s [-c] [-C <name=value>] [-d <Debug Mask>] [-h] [-i <uscriptfile>] [-l <verbosity>] [-F " + "<flatconfig>]\n", + programname); + fprintf(out, "\t-c - Print list of config options\n"); + fprintf(out, "\t-C <name=value> - modify config option <name> to <value>\n"); + fprintf(out, "\t-d <Debug Mask - set component debug mask\n"); + 
func_model_print_debug_masks(out); + fprintf(out, "\t-F <flatconfig> - parse <flatconfig> as file of config options\n"); + fprintf(out, "\t-h - show this help message and exit\n"); + fprintf( + out, + "\t-i <input_tensor_name>,<filename> - set input tensor <input_tensor_name> to the values from <filename>\n"); + fprintf(out, "\t-l <verbosity> - set log verbosity\n"); + fprintf(out, "\t-o <debuglog> - set debug log file\n"); + fprintf(out, "\n"); + + func_config_print_config_help(stdout); + + return 0; +} + +static const char* get_arg_text(int& index, const int argc, const char** argv) +{ + if (strlen(argv[index]) > 2) + { + return argv[index] + 2; + } + + if ((index + 1 == argc) || (argv[index + 1][0] == '-')) + { + fprintf(stderr, "No option value found for option %s\n", argv[index]); + return ""; + } + + index++; + return argv[index]; +} + +// Read the command line arguments +int func_model_parse_cmd_line(func_config_t* func_config, func_debug_t* func_debug, const int argc, const char** argv) +{ + int i; + programname = argv[0]; + for (i = 1; i < argc; i++) + { + // All command line arguments must begin with -X where X is a recognized character + if (strlen(argv[i]) < 2 || argv[i][0] != '-') + { + func_model_print_help(stderr); + ARG_ERROR("Command line argument at position %d not valid: %s", i, argv[i]); + } + + switch (argv[i][1]) + { + // Model parameters may be overridden with the -Cname=value switch + case 'c': + func_config_print_config_help(stderr); + return 1; + + case 'C': + { + const char *name = nullptr, *value = nullptr; + + // Break the string into name and value parts + name = get_arg_text(i, argc, argv); + value = strchr(name, '='); + + if (value == nullptr) + { + func_model_print_help(stderr); + ARG_ERROR("Cannot parse -C argument at position %d: %s", i, argv[i]); + } + + *const_cast<char*>(value) = 0; + + if (func_model_config_set_option(func_config, name, value + 1)) + { + func_model_print_help(stderr); + ARG_ERROR("Cannot parse -C argument at 
position %d: %s", i, argv[i]); + } + break; + } + + case 'd': + case 'D': + { + func_debug_set_mask(func_debug, get_arg_text(i, argc, argv)); + break; + } + case 'F': + { + // Read a flat configuration file + if (func_model_parse_flat_config_file(func_config, get_arg_text(i, argc, argv))) + return 1; + + break; + } + case 'h': + func_model_print_help(stderr); + return 1; + + case 'i': + { + // shortcut for '-Cinput_tensor=' + if (func_model_config_set_option(func_config, "input_tensor", get_arg_text(i, argc, argv))) + { + func_model_print_help(stderr); + ARG_ERROR("Cannot set input tensor config value"); + } + break; + } + case 'l': + { + // Debug verbosity/logging level + func_debug_set_verbosity(func_debug, get_arg_text(i, argc, argv)); + break; + } + case 'o': + { + func_debug_set_file(func_debug, get_arg_text(i, argc, argv)); + break; + } + default: + func_model_print_help(stderr); + ARG_ERROR("Unrecognized argument at position %d: %s", i, argv[i]); + } + } + + return 0; +} + +int func_model_parse_flat_config_file(func_config_t* func_config, const char* filename) +{ + const int MAX_LINE_LEN = 1024; + + FILE* infile = nullptr; + char line_buf[MAX_LINE_LEN]; + int line = 1; + + infile = fopen(filename, "r"); + + if (infile == nullptr) + { + ARG_ERROR("Cannot open config file: %s\n", filename); + } + + while (fgets(line_buf, MAX_LINE_LEN - 1, infile) != nullptr) + { + char *name = line_buf, *value = nullptr, *comment = nullptr, *ptr = nullptr; + + // Remove comments + comment = strchr(line_buf, '#'); + + if (comment) + *comment = 0; + + // Break the string into name and value parts + name = line_buf; + + // Remove leading whitespace + while (*name && isspace(*name)) + name++; + + // Empty line? 
+ if (*name == 0) + { + line++; + continue; + } + + value = strchr(name, '='); + + // Missing value + if (value == nullptr) + { + ARG_ERROR("Cannot parse parameter in %s at line %d: %s", filename, line, line_buf); + } + + // Remove the = + *value = 0; + value++; + + // Trim off any whitespace at the end of the value + ptr = value; + while (*ptr != 0 && !isspace(*ptr)) + ptr++; + *ptr = 0; + + // Include a nested file + if (!strcmp(name, "include")) + { + if (func_model_parse_flat_config_file(func_config, value)) + return 1; + line++; + continue; + } + + if (func_model_config_set_option(func_config, name, value)) + { + func_model_print_help(stderr); + ARG_ERROR("Cannot set parameter in %s at line %d: %s", filename, line, line_buf) + } + + line++; + } + + fclose(infile); + + return 0; +} diff --git a/reference_model/src/func_config.def b/reference_model/src/func_config.def new file mode 100644 index 0000000..004cf36 --- /dev/null +++ b/reference_model/src/func_config.def @@ -0,0 +1,90 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Filename: src/func_config.def + * Description: + * Defines the model parameters/options for the functional model. + */ + +// Placeholder values for the Functional model Option Formatting (FOF) fields +// +// FOF_DEC is decimal +// FOF_HEX is hexidecimal +// +// Floating point values are not supported yet, but there is no fundamental reason +// why we can't have them. 
+#ifndef FOF_DEC +#define FOF_DEC 1 +#endif + +#ifndef FOF_HEX +#define FOF_HEX 1 +#endif + +#ifndef FOF_STR_LEN +#define FOF_STR_LEN 1024 +#endif + +// Options are defined as follows: +// DEF_OPTION() defines a top-level option +// Arguments: +// option_field_name: a C-syntax field name in the struct +// description: a short string that describes the purpose of the option (printed out with help) +// C type: the type of the option (typically a uint64_t, uint32_t, etc) +// Format field: the FOF_* type used to figure out how to format/print the option +// Default value: the default value assigned to the option, if it isn't assigned by an configuration file +// or command line override + +// For defining hierarchical options (example hierarchy is 'cle', use the following formula). +// All options within the hierarchical space must be grouped together: +// + +// #define CURRENT_UNIT cle +// DEF_UNIT_START(CURRENT_UNIT) +// DEF_UNIT_OPTION(CURRENT_UNIT,...) +// ... +// DEF_UNIT_END(CURRENT_UNIT) +// #undef CURRENT_UNIT +// +// The CURRENT_UNIT argument is required as a parameter in these definitions because +// macro processing rules only allow stringification of macro parameters. Unfortunately, +// Other tokens that are NOT passed in as macro parameters cannot be stringified. 
+ +DEF_OPTION_STR(operator_fbs, "Flat buffer syntax file", FOF_STR_LEN, "../serialization/tosa.fbs") +DEF_OPTION_STR(subgraph_dir, "Subgraph directory to load", FOF_STR_LEN, ".") +DEF_OPTION_STR(subgraph_file, "Subgraph file to load", FOF_STR_LEN, "") +DEF_OPTION_STR(input_dir, "Input directory path for dumps/files", FOF_STR_LEN, ".") +DEF_OPTION_STR(input_tensor, "A list of pairs <name0>:<npy0>,<name1>:<npy1>", FOF_STR_LEN, "") +DEF_OPTION_STR(output_dir, "Output directory path for output dumps/files", FOF_STR_LEN, ".") +DEF_OPTION(eval, "Evaluate the network (0/1)", uint32_t, FOF_DEC, 1) +DEF_OPTION(validate_only, "Validate the network, but do not read inputs or evaluate (0/1)", uint32_t, FOF_DEC, 0) +DEF_OPTION(output_tensors, "Output tensors to a file (0/1)", uint32_t, FOF_DEC, 1) +DEF_OPTION(tosa_profile, "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)", uint32_t, FOF_DEC, 1) +DEF_OPTION_STR(output_tensor_prefix, "Optional output tensor prefix", FOF_STR_LEN, "output_") +DEF_OPTION(dump_intermediates, "Dump intermediate tensors (0/1)", uint32_t, FOF_DEC, 0) +DEF_OPTION_STR(fp_format, "Floating-point number dump format string (printf-style format, e.g. 0.5)", FOF_STR_LEN, "0.5") +// Example of a hierarchical option +//#define CURRENT_UNIT arch +//DEF_UNIT_START(arch) +//DEF_UNIT_OPTION(arch, ifm_width, "input feature map width(x dim)", uint32_t, FOF_DEC, 10) +//DEF_UNIT_END(CURRENT_UNIT) +///#undef CURRENT_UNIT + +// START Do not delete +// Required for keeping the FOFs clean +#undef FOF_DEC +#undef FOF_HEX +// END Do not delete^^ diff --git a/reference_model/src/func_config.h b/reference_model/src/func_config.h new file mode 100644 index 0000000..f941300 --- /dev/null +++ b/reference_model/src/func_config.h @@ -0,0 +1,55 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUNC_CONFIG_H_ +#define FUNC_CONFIG_H_ + +// Parameter value structure +#define DEF_UNIT_START(UNIT) \ + struct UNIT##_t \ + { +#define DEF_UNIT_END(UNIT) \ + } \ + UNIT; +#define DEF_OPTION(NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME; +#define DEF_OPTION_STR(NAME, DESC, LEN, DEFAULT) char NAME[LEN]; +#define DEF_UNIT_OPTION(UNIT, NAME, DESC, TYPE, FMT, DEFAULT) TYPE NAME; +#define DEF_UNIT_OPTION_STR(UNIT, NAME, DESC, LEN, DEFAULT) char NAME[LEN]; +struct func_config_t +{ +#include "func_config.def" +#undef DEF_UNIT_START +#undef DEF_UNIT_END +#undef DEF_OPTION +#undef DEF_OPTION_STR +#undef DEF_UNIT_OPTION +#undef DEF_UNIT_OPTION_STR +}; + +// Forward declaration +struct func_debug_t; + +int func_model_init_config(); +int func_model_set_default_config(func_config_t*); +int func_model_config_set_option(func_config_t*, const char* name, const char* value); +int func_model_print_config(func_config_t*, FILE* out); +int func_model_parse_cmd_line(func_config_t*, func_debug_t* func_debug, const int argc, const char** argv); +int func_model_parse_flat_config_file(func_config_t*, const char* filename); +int func_model_config_cleanup(); +int func_model_config_get_str_option_by_name(func_config_t*, const char* name, char* value, const uint32_t len); +int func_model_config_get_option_by_name(func_config_t*, const char* name, uint64_t* val); +int func_model_print_help(FILE* out); + +#endif diff --git a/reference_model/src/func_debug.cc b/reference_model/src/func_debug.cc new file mode 100644 index 0000000..f5f045e --- /dev/null +++ 
b/reference_model/src/func_debug.cc @@ -0,0 +1,436 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <ctype.h> +#include <signal.h> +#include <stdarg.h> +#include <stdio.h> +#include <stdlib.h> +#include <string.h> +#include <sys/types.h> + +#ifndef _MSC_VER +#include <execinfo.h> +#include <sys/prctl.h> +#include <sys/ptrace.h> +#include <sys/wait.h> +#include <unistd.h> +#endif + +#include "func_debug.h" + +#define MAX_FRAMES 100 + +#ifndef _MSC_VER +pid_t func_print_backtrace_helper(int num_tries, int sig); +#endif + +void func_print_backtrace(FILE* out, int sig) +{ +#ifndef _MSC_VER + for (int i = 0; i < 2; i++) + { + const pid_t child_pid = func_print_backtrace_helper(i, sig); + if (child_pid < 0) + { + perror("Backtrace generation failed on fork"); + break; + } + + int status = 0; + waitpid(child_pid, &status, 0); + if (WEXITSTATUS(status) == 0) + { + break; + } + } +#endif +} + +#ifndef _MSC_VER +pid_t func_print_backtrace_helper(int num_tries, int sig) +{ + const pid_t child_pid = fork(); + + if (child_pid) + { + return 0; + } + + const pid_t ppid = getppid(); + + printf("Attaching debugger to pid %d\n", ppid); + // Check if we're in a debugger + if (ptrace(PTRACE_ATTACH, ppid, 0, 0) == 0) + { + // If we reach this point, no debugger is present + // Undo effects of PTRACE_ATTACH + waitpid(ppid, NULL, 0); + ptrace(PTRACE_CONT, 0, 0, 0); + ptrace(PTRACE_DETACH, ppid, 0, 0); + + dup2(STDERR_FILENO, 
STDOUT_FILENO); + + char parent_pid[20]; + snprintf(parent_pid, sizeof(parent_pid), "attach %d", ppid); + fprintf(stdout, "Caught signal %d (%s)\n", sig, strsignal(sig)); + + execlp("gdb", "gdb", "--batch", "-n", "-ex", + // Don't print startup messages for each thread + "-ex", "set print thread-events off", "-ex", parent_pid, + // Turn off pagination + "-ex", "set height 0", + // Print a backtrace for the current thread + "-ex", "thread $_thread", "-ex", "bt", + // Print a backtrace for the main thread (uncomment the next two lines, if desired) + //"-ex", "thread 1", + //"-ex", "bt", + // Print a backtrace for all thread (TMI) + //"-ex", "thread apply all bt", + NULL); + + // If we reach this point, it is bad. Attempt to print an error before exiting. + perror("Backtrace generation failed to invoke gdb"); + exit(1); + } + + // Debugger present. Exit here. + exit(0); + + return 0; +} +#endif + +void func_backtrace_signal_handler(int sig) +{ + func_print_backtrace(NULL, sig); + exit(1); +} + +// Note: this overwrites other signal handlers. 
May want to make this +// more friendly sometime +void func_enable_signal_handlers() +{ + static const int sig_list[] = { SIGABRT, SIGSEGV, SIGILL, SIGFPE }; + + if (getenv("FUNC_NO_SIG_HANDLERS")) + { + return; + } + + for (size_t i = 0; i < sizeof(sig_list) / sizeof(int); i++) + { + struct sigaction act; + + bzero(&act, sizeof(act)); + act.sa_handler = func_backtrace_signal_handler; + + if (sigaction(sig_list[i], &act, NULL)) + { + perror("Error calling sigaction"); + } + } +} + +const char* func_debug_mode_str_table[] = { +#define DEBUG_MODE(NAME, BIT) #NAME, +#include "debug_modes.def" +#undef DEBUG_MODE +}; + +#define DEBUG_MASK_COUNT (sizeof(func_debug_mode_str_table) / sizeof(const char*)) + +const char* func_debug_verbosity_str_table[] = { "NONE", "INFO", "IFACE", "LOW", "MED", "HIGH" }; + +const uint32_t func_debug_verbosity_mask_table[] = { DEBUG_VERB_NONE, DEBUG_VERB_INFO, DEBUG_VERB_IFACE, + DEBUG_VERB_LOW, DEBUG_VERB_MED, DEBUG_VERB_HIGH }; + +#define DEBUG_VERBOSITY_COUNT (sizeof(func_debug_verbosity_str_table) / sizeof(const char*)) + +// Initialize the debug mode +int func_init_debug(func_debug_t* func_debug, uint64_t inst_id) +{ + // Set the default debug settings + bzero(func_debug, sizeof(func_debug_t)); + func_debug_set_mask(func_debug, DEBUG_NONE); + func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE); + func_debug_set_inst_mask(func_debug, DEBUG_INST_ALL); + func_debug->func_debug_file = stderr; + func_debug_set_captured_warnings(func_debug, 0); + func_debug_set_output_unbuffered(func_debug, false); + func_debug->inst_id = inst_id; + + return 0; +} + +int func_fini_debug(func_debug_t* func_debug) +{ + if (func_debug->record_warnings) + { + func_debug_set_captured_warnings(func_debug, 0); + } + +#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H + if (func_debug->is_gzip && func_debug->func_debug_file) + { + pclose(func_debug->func_debug_file); + func_debug->func_debug_file = NULL; + } +#endif + + return 0; +} + +int func_debug_set_file(func_debug_t* 
func_debug, const char* filename) +{ + int filenameLen = strlen(filename); + + // Open the debug output file + ASSERT(filename != NULL); +#ifndef _FUNC_INCLUDE_WINDOWS_SUPPORT_H + if (filenameLen > 3 && strcmp(filename + filenameLen - 3, ".gz") == 0) + { + char cmd[256]; + + snprintf(cmd, sizeof(cmd), "gzip > %s", filename); + func_debug->func_debug_file = popen(cmd, "w"); + func_debug->is_gzip = 1; + } + else + { +#else + { +#endif + func_debug->func_debug_file = fopen(filename, "w"); + } + + if (!func_debug->func_debug_file) + { + perror(NULL); + FATAL_ERROR("Cannot open debug output file: %s\n", filename); + return 1; + } + if (func_debug->is_output_unbuffered) + { + setvbuf(func_debug->func_debug_file, nullptr, _IONBF, 0); + } + + return 0; +} + +void func_debug_set_verbosity(func_debug_t* func_debug, const char* str) +{ + if (!strcasecmp(str, "RESET")) + { + func_debug_set_verbosity(func_debug, DEBUG_VERB_NONE); + return; + } + + for (size_t i = 0; i < DEBUG_VERBOSITY_COUNT; i++) + { + if (!strcasecmp(str, func_debug_verbosity_str_table[i])) + { + func_debug_set_verbosity(func_debug, func_debug_verbosity_mask_table[i]); + return; + } + } + + FATAL_ERROR("Invalid debug verbosity: %s", str); +} + +void func_debug_set_verbosity(func_debug_t* func_debug, const uint32_t verb) +{ + uint32_t new_mask = verb; + + switch (verb) + { + case DEBUG_VERB_NONE: + new_mask = DEBUG_VERB_NONE; + break; + case DEBUG_VERB_INFO: + new_mask = DEBUG_VERB_INFO; + break; + case DEBUG_VERB_IFACE: + new_mask = DEBUG_VERB_IFACE; + break; + case DEBUG_VERB_HIGH: + new_mask |= DEBUG_VERB_HIGH; + // Intentional fallthrough + case DEBUG_VERB_MED: + new_mask |= DEBUG_VERB_MED; + // Intentional fallthrough + case DEBUG_VERB_LOW: + new_mask |= DEBUG_VERB_LOW; + new_mask |= DEBUG_VERB_INFO; + new_mask |= DEBUG_VERB_IFACE; + break; + } + + func_debug->func_debug_verbosity = new_mask; +} + +void func_debug_set_suppress_arch_error_mask(func_debug_t* func_debug, const uint32_t suppress) +{ + 
func_debug->func_suppress_arch_error_mask = suppress; +} + +void func_debug_set_mask(func_debug_t* func_debug, const uint64_t mask) +{ + if (mask == DEBUG_NONE) + func_debug->func_debug_mask = mask; + else + func_debug->func_debug_mask |= mask; + + // Set a minimum verbosity level + if (func_debug->func_debug_verbosity == DEBUG_VERB_NONE) + func_debug->func_debug_verbosity = DEBUG_VERB_INFO; +} + +void func_debug_set_inst_mask(func_debug_t* func_debug, const char* mask) +{ + uint64_t val; + + val = strtoul(mask, NULL, 0); + + return func_debug_set_inst_mask(func_debug, val); +} + +void func_debug_set_inst_mask(func_debug_t* func_debug, const uint64_t mask) +{ + if (mask == 0) + func_debug->func_debug_inst_mask = DEBUG_INST_ALL; + else + func_debug->func_debug_inst_mask = mask; +} + +void func_debug_set_mask(func_debug_t* func_debug, const char* str) +{ + if (!strcasecmp(str, "all")) + { + func_debug_set_mask(func_debug, UINT64_MAX - 1); + return; + } + + size_t i; + for (i = 0; i < DEBUG_MASK_COUNT; i++) + { + if (!strcasecmp(str, func_debug_mode_str_table[i])) + { + func_debug_set_mask(func_debug, 1ULL << i); + return; + } + } + + func_debug_print_masks(stderr); + + FATAL_ERROR("Invalid debug mask: %s", str); +} + +void func_debug_print_masks(FILE* out) +{ + uint32_t i; + + fprintf(out, "Available debug masks:\n"); + + for (i = 0; i < DEBUG_MASK_COUNT; i++) + { + fprintf(out, "[%d] %s\n", i, func_debug_mode_str_table[i]); + } +} + +void func_debug_set_output_unbuffered(func_debug_t* func_debug, const bool is_unbuffered) +{ + func_debug->is_output_unbuffered = is_unbuffered; +} + +// Print warnings to the debug file or optionally store them in a buffer instead +// Note that the buffer is circular and can be overwritten if enough messages are +// written before removing a warning from the front. +void func_debug_warning( + func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...) 
+{ + va_list args; + va_start(args, fmt); + + if (func_debug->record_warnings) + { + // Record to the circular buffer + uint32_t len; + + len = snprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail], WARNING_BUFFER_ENTRY_LENGTH, + "WARNING AT %s:%d %s(): ", file, line, func); + vsnprintf(func_debug->warning_buffer[func_debug->warning_buffer_tail] + len, WARNING_BUFFER_ENTRY_LENGTH - len, + fmt, args); + func_debug->warning_buffer_tail = (func_debug->warning_buffer_tail + 1) % WARNING_BUFFER_SIZE; + } + else + { + // Print to the debug file (e.g., stderr) + fprintf(func_debug->func_debug_file, "WARNING AT %s:%d %s():\n", file, line, func); + vfprintf(func_debug->func_debug_file, fmt, args); + fprintf(func_debug->func_debug_file, "\n"); + } + va_end(args); +} + +// Initialize the warning buffer capture +int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture) +{ + uint32_t i; + func_debug->record_warnings = capture; + if (capture) + { + func_debug->warning_buffer_head = 0; + func_debug->warning_buffer_tail = 0; + + for (i = 0; i < WARNING_BUFFER_SIZE; i++) + { + func_debug->warning_buffer[i] = (char*)calloc(1, WARNING_BUFFER_ENTRY_LENGTH); + } + } + else + { + for (i = 0; i < WARNING_BUFFER_SIZE; i++) + { + if (func_debug->warning_buffer[i]) + { + free(func_debug->warning_buffer[i]); + func_debug->warning_buffer[i] = NULL; + } + } + } + + return 0; +} + +int func_debug_has_captured_warning(func_debug_t* func_debug) +{ + if (func_debug->record_warnings && func_debug->warning_buffer_head != func_debug->warning_buffer_tail) + return 1; + else + return 0; +} + +int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len) +{ + if (!func_debug_has_captured_warning(func_debug)) + return 1; + + strncpy(buf_ptr, func_debug->warning_buffer[func_debug->warning_buffer_head], buf_len); + + func_debug->warning_buffer_head = (func_debug->warning_buffer_head + 1) % WARNING_BUFFER_SIZE; + + return 0; +} 
diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h new file mode 100644 index 0000000..2d47462 --- /dev/null +++ b/reference_model/src/func_debug.h @@ -0,0 +1,255 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUNC_DEBUG_H +#define FUNC_DEBUG_H + +#include "debug_types.h" +#include <assert.h> +#include <cinttypes> +#include <signal.h> +#include <stdio.h> + +void func_print_backtrace(FILE* out, int sig = SIGABRT); + +void func_enable_signal_handlers(); + +// Debug content container +#define WARNING_BUFFER_SIZE 16 +#define WARNING_BUFFER_ENTRY_LENGTH 1024 + +// STRINGIFY2 is needed expand expression passed to STRINGIFY +#define STRINGIFY2(s) #s +#define STRINGIFY(s) STRINGIFY2(s) + +// If TRACED_LOG is defined, add file:line to log messages +#if defined(TRACED_LOG) +#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__) +#else +#define WHERE +#endif + +#if defined(COLORIZED_LOG) +#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m" +#define COL_FATAL(fmt) COL("1;41", fmt) +#define COL_WARN(fmt) COL("1;43", fmt) +#define COL_INFO(fmt) COL("2", fmt) +#define COL_IFACE(fmt) fmt +#define COL_LOW(fmt) COL("35", fmt) +#define COL_MED(fmt) COL("2;33", fmt) +#define COL_HIGH(fmt) COL("2;32", fmt) +#else +#define COL_FATAL(fmt) fmt +#define COL_WARN(fmt) fmt +#define COL_INFO(fmt) fmt +#define COL_IFACE(fmt) fmt +#define COL_LOW(fmt) fmt +#define COL_MED(fmt) fmt +#define COL_HIGH(fmt) fmt 
+#endif + +struct func_debug_t +{ + uint32_t func_debug_verbosity; // What verbosity level is set? (bitmask) + uint64_t func_debug_mask; // Which units have debugging enabled? (bitmask) + uint64_t func_debug_inst_mask; // Which instances have debugging enabled (bitmask) + uint64_t inst_id; // The instance id for multiple model instances + uint32_t func_suppress_arch_error_mask; // Which architecture error should be suppressed? (bitmask) + FILE* func_debug_file; // Output file + uint32_t record_warnings; + char* warning_buffer[WARNING_BUFFER_SIZE]; + uint32_t warning_buffer_head; // next unread message + uint32_t warning_buffer_tail; // next message to write + uint32_t is_gzip; + bool is_output_unbuffered; // should log files be opened with unbuffered I/O. +}; + +#ifndef ASSERT +#define ASSERT(COND) \ + if (!(COND)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ + func_print_backtrace(stderr); \ + assert(COND); \ + } +#endif + +#ifndef ASSERT_MSG +#define ASSERT_MSG(COND, fmt, ...) \ + if (!(COND)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ + fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + func_print_backtrace(stderr); \ + assert(COND); \ + } +#endif + +#ifndef ASSERT_MSG_NODE +#define ASSERT_MSG_NODE(COND, fmt, ...) 
\ + if (!(COND)) \ + { \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ + __func__, #COND); \ + fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + this->dumpNode(g_func_debug.func_debug_file); \ + func_print_backtrace(g_func_debug.func_debug_file); \ + assert(COND); \ + } +#endif + +// Assertion specific to allocating memory +#ifndef ASSERT_MEM +#define ASSERT_MEM(OBJ) \ + if (!(OBJ)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \ + __func__); \ + func_print_backtrace(stderr); \ + assert(OBJ); \ + } +#endif + +#ifndef FATAL_ERROR +#define FATAL_ERROR(fmt, ...) \ + fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ + fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + func_print_backtrace(stderr); \ + abort(); +#endif + +#ifndef FATAL_ERROR_NODE +#define FATAL_ERROR_NODE(fmt, ...) \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ + fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + this->dumpNode(g_func_debug.func_debug_file); \ + func_print_backtrace(g_func_debug.func_debug_file); \ + abort(); +#endif +#ifndef SIMPLE_FATAL_ERROR +#define SIMPLE_FATAL_ERROR(fmt, ...) \ + fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + exit(1); +#endif + +void func_debug_warning( + func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...); +#ifndef WARNING +#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__) +#endif + +#ifndef WARNING_STDERR +#define WARNING_STDERR(fmt, ...) 
\ + fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ + fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__); +#endif + +int func_debug_set_captured_warnings(func_debug_t* func_debug, uint32_t capture); + +int func_debug_has_captured_warning(func_debug_t* func_debug); + +int func_debug_get_captured_warning(func_debug_t* func_debug, char* buf_ptr, const uint32_t buf_len); + +// Is this debug verbosity and unit level enabled? +// Provide compiler hints that this is unlikely +// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not +// +// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evalute +// to the instance ID variable. The use of this define in header files is discouraged. + +#ifdef DEBUG_INSTANCE_EXPR +// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts +#ifdef DEBUG_INSTANCE_EXPR_2 +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) 
\ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \ + (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) \ + fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2)) + +#else // !DEBUG_INSTANCE_EXPR_2 + +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) \ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR)) + +#endif // DEBUG_INSTANCE_EXPR_2 + +#else // !DEBUG_INSTANCE_EXPR + +// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) 
\ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ + ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ") + +#endif + +// Macros for different verbosity levels +#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__) +#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__) +#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__) +#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__) +#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__) + +int func_init_debug(func_debug_t*, uint64_t inst_id); +int func_fini_debug(func_debug_t*); +int func_debug_set_file(func_debug_t*, const char* filename); +void func_debug_set_mask(func_debug_t*, const char* str); +void func_debug_set_mask(func_debug_t*, const uint64_t mask); +void func_debug_print_masks(FILE* out); +void func_debug_set_verbosity(func_debug_t*, const char* str); +void func_debug_set_verbosity(func_debug_t*, const uint32_t verb); +void func_debug_set_suppress_arch_error_mask(func_debug_t*, const uint32_t suppress); +void func_debug_set_inst_mask(func_debug_t*, const char* mask); +void func_debug_set_inst_mask(func_debug_t*, const uint64_t mask); +void func_debug_set_output_unbuffered(func_debug_t*, const bool is_unbuffered); + +#endif diff --git a/reference_model/src/graph_node.cc b/reference_model/src/graph_node.cc new file mode 100644 index 0000000..b57b9dd --- /dev/null +++ b/reference_model/src/graph_node.cc @@ -0,0 +1,226 @@ + +// Copyright (c) 2020, ARM Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "graph_node.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +GraphNode::GraphNode(const Op& nodeType_, const uint64_t id_) +{ + nodeType = nodeType_; + nodeId = id_; + inputs.clear(); + outputs.clear(); + inputNames.clear(); + outputNames.clear(); + clearNodeMarked(); + evalCount = 0; + clearOnNextNodeList(); + setRequiredOperands(-1, -1); + setRequiredRank(-1); +} + +GraphNode::~GraphNode() +{} + +int GraphNode::addInputName(std::string& name) +{ + inputNames.push_back(name); + return 0; +} + +int GraphNode::addOutputName(std::string& name) +{ + outputNames.push_back(name); + return 0; +} + +int GraphNode::addInputTensor(Tensor* tens) +{ + ASSERT_MSG(tens, "GraphNode::addInputTensor: no tensor provided"); + inputs.push_back(tens); + return 0; +} + +int GraphNode::addOutputTensor(Tensor* tens) +{ + ASSERT_MSG(tens, "GraphNode::addOutputTensor: no tensor provided"); + outputs.push_back(tens); + return 0; +} + +int GraphNode::checkTensorAttributes() +{ + // Placeholder + return 0; +} + +int GraphNode::eval() +{ + // Placeholder evaluation function + evalCount++; + + // this should be set by derived op + for (auto ct : getOutputs()) + { + ct->setIsValid(); + } + + return 0; +} + +int GraphNode::hasAllInputsReady() const +{ + for (size_t i = 0; i < inputs.size(); i++) + { + if (!inputs[i]->getIsValid()) + return false; + } + + return true; +} + +int 
GraphNode::hasAllOutputsReady() const +{ + for (size_t i = 0; i < outputs.size(); i++) + { + if (!outputs[i]->getIsValid()) + return false; + } + + return true; +} + +int GraphNode::dumpNode(FILE* out) +{ + int i; + fprintf(out, "Node type: %s ID: %lu Eval Count: %d On next node list: %d Is marked: %d\n", EnumNamesOp()[nodeType], + nodeId, evalCount, onNextNodeList, isMarked); + + i = 0; + for (Tensor* ins : inputs) + { + fprintf(out, " Input[%d] ", i++); + ins->dumpTensorParams(out); + } + + i = 0; + for (Tensor* outs : outputs) + { + fprintf(out, " Output[%d] ", i++); + outs->dumpTensorParams(out); + } + + return 0; +} + +int GraphNode::dumpNode(std::ostream& out) +{ + int i; + + out << "Node type: " << EnumNamesOp()[nodeType] << " ID: " << nodeId << " Eval count: " << evalCount + << " On next node list: " << onNextNodeList << " Is marked: " << isMarked << std::endl; + + out << " Inputs:"; + for (std::string& name : inputNames) + { + out << " " << name; + } + out << std::endl; + + i = 0; + for (Tensor* ins : inputs) + { + out << " Input[" << i++ << "]: "; + ins->dumpTensorParams(out); + } + + out << " Outputs:"; + for (std::string& name : outputNames) + { + out << " " << name; + } + out << std::endl; + + i = 0; + for (Tensor* outs : outputs) + { + out << " Output[" << i++ << "]: "; + outs->dumpTensorParams(out); + } + return 0; +} + +int GraphNode::printNodeValidationError(const std::string& msg) +{ + std::cout << "Operator validation error: " << msg << std::endl; + ; + dumpNode(std::cout); + + return 0; +} + +int GraphNode::validateRequiredOperands() +{ + if (requiredInputCount >= 0 && inputs.size() != (size_t)requiredInputCount) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator must have " + + std::to_string(requiredInputCount) + " input(s)"); + return 1; + } + + if (requiredOutputCount >= 0 && outputs.size() != (size_t)requiredOutputCount) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + " operator output 
must have exactly " + + std::to_string(requiredOutputCount) + " output(s)"); + return 1; + } + + return 0; +} + +int GraphNode::validateRequiredRank(const Tensor* t) +{ + if (requiredRankMin >= 0 && requiredRankMax >= 0) + { + if (t->checkRequiredRank(requiredRankMin, requiredRankMax)) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not in range [" + + std::to_string(requiredRankMin) + "," + std::to_string(requiredRankMax) + + "]. tensorName: " + t->getName()); + return 1; + } + else + { + return 0; + } + } + + if (requiredRankMin >= 0) + { + if (t->checkRequiredRank(requiredRankMin)) + { + printNodeValidationError(std::string(EnumNamesOp()[nodeType]) + + " operand has illegal rank=" + std::to_string(t->getRank()) + " not equal to " + + std::to_string(requiredRankMin) + ". tensorName: " + t->getName()); + return 1; + } + } + + return 0; +} diff --git a/reference_model/src/graph_node.h b/reference_model/src/graph_node.h new file mode 100644 index 0000000..5b4a767 --- /dev/null +++ b/reference_model/src/graph_node.h @@ -0,0 +1,354 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef GRAPH_NODE_H +#define GRAPH_NODE_H + +#include "attribute.h" +#include "quant_info.h" +#include "tensor.h" +#include "tosa_generated.h" +#include <iostream> + +#define DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) template class TosaReference::OP<RANK, DType_##DTYPE>; + +#define DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ + template class TosaReference::OP<RANK, DType_##DTYPE1, DType_##DTYPE2>; + +#define DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ + template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE>; + +#define DEF_INSTANTIATE_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ + template class TosaReference::OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>; + +#define DEF_INSTANTIATE_ONE_RANK_0_6(OP) \ + template class TosaReference::OP<0>; \ + template class TosaReference::OP<1>; \ + template class TosaReference::OP<2>; \ + template class TosaReference::OP<3>; \ + template class TosaReference::OP<4>; \ + template class TosaReference::OP<5>; \ + template class TosaReference::OP<6>; + +#define DEF_INSTANTIATE_ONE_TYPE(OP, DTYPE) template class TosaReference::OP<DType_##DTYPE>; + +#define DEF_INSTANTIATE_TWO_TYPE(OP, DTYPE1, DTYPE2) template class TosaReference::OP<DType_##DTYPE1, DType_##DTYPE2>; + +#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) + +#define DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + 
DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_INSTANTIATE_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) + +#define DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \ + DEF_INSTANTIATE_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) + +#define DEF_INSTANTIATE_RESHAPE(OP, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + 
DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) + +#define DEF_INSTANTIATE_GATHER(OP, DTYPE) \ + /* gather op takes input and index rank as template argument */ \ + /* note output rank = input rank - 1 + index rank */ \ + /* and max rank allowed in tosa_reference is 6 */ \ + /* so only specific input and index pair is instantiated */ \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + 
DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_INSTANTIATE_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) + +#define INIT_ATTRIBUTE(ATTRIBUTE_NAME) \ + if (auto p = dynamic_cast<Tosa##ATTRIBUTE_NAME##Attribute*>(attribute_)) \ + { \ + attribute = new Tosa##ATTRIBUTE_NAME##Attribute(p); \ + ASSERT_MEM(attribute); \ + } \ + else \ + { \ + FATAL_ERROR("Can't initialize Tosa" #ATTRIBUTE_NAME "Attribute"); \ + } + +#define INIT_QINFO(QINFO_NAME) \ + if (auto p = dynamic_cast<Tosa##QINFO_NAME##QuantInfo*>(qinfo_)) \ + { \ + qinfo = new Tosa##QINFO_NAME##QuantInfo(p); \ + ASSERT_MEM(qinfo); \ + } \ + else \ + { \ + qinfo = nullptr; \ + } + +namespace TosaReference +{ + +// Nodes in the graph (e.g., tosa operators) are defined with this base +// class. +class GraphNode +{ +public: + GraphNode(const tosa::Op& nodeType, const uint64_t id_); + virtual ~GraphNode(); + + int addInputName(std::string& name); + int addOutputName(std::string& name); + + int addInputTensor(Tensor* tens); + int addOutputTensor(Tensor* tens); + + // Validate that the input tensors match properly + // in their types, attributes, rank, etc well enough to be + // processed. 
+ // + // This function should be pure virtual (eventually) in order to force + // derivative operators to implement the check, but we'll initially + // provide a default function so that GraphNode can be instantiated + // directly for testing purposes. + virtual int checkTensorAttributes(); + + // Evalute the node/operator + virtual int eval(); + + int hasAllInputsReady() const; + int hasAllOutputsReady() const; + + int dumpNode(FILE* out); + int dumpNode(std::ostream& out); + + int setNodeMarked() + { + isMarked = true; + return 0; + } + + int getNodeMarked() const + { + return isMarked; + } + + int clearNodeMarked() + { + isMarked = false; + return 0; + } + + int getEvalCount() const + { + return evalCount; + } + + uint64_t getID() const + { + return nodeId; + } + + std::vector<std::string>& getInputNames() + { + return inputNames; + } + + std::vector<std::string>& getOutputNames() + { + return outputNames; + } + + std::vector<Tensor*>& getOutputs() + { + return outputs; + } + + std::vector<Tensor*>& getInputs() + { + return inputs; + } + + int getOnNextNodeList() const + { + return onNextNodeList; + } + + int setOnNextNodeList() + { + onNextNodeList = true; + return 0; + } + + int clearOnNextNodeList() + { + onNextNodeList = false; + return 0; + } + + tosa::Op getOp() const + { + return nodeType; + } + +protected: + // Print out a node validation error + int printNodeValidationError(const std::string& msg); + + int setRequiredOperands(const int in, const int out) + { + requiredInputCount = in; + requiredOutputCount = out; + return 0; + } + + int setRequiredRank(const int min, const int max = -1) + { + if (max == -1) + { + requiredRankMin = requiredRankMax = min; + } + else + { + requiredRankMin = min; + requiredRankMax = max; + } + + ASSERT_MSG(requiredRankMin <= requiredRankMax, + "GraphNode::setRequiredRank: requiredRankMin %d must be <= requiredRankMax %d", requiredRankMin, + requiredRankMax); + + return 0; + } + + int validateRequiredOperands(); + int 
validateRequiredRank(const Tensor* t); + + // Description of the node type (e.g., CONST, CONV2D, etc...) + tosa::Op nodeType; + + // A list of input tensor names + std::vector<std::string> inputNames; + + // A list of the output tensor names + std::vector<std::string> outputNames; + + // A list of the input tensors (after names have been matched up) + std::vector<Tensor*> inputs; + + // A list of the output tensors (after names have been matched up) + std::vector<Tensor*> outputs; + + // Unique node ID for debugging + uint64_t nodeId; + + // Flag used for graph analysis + int isMarked; + + // Number of times eval() has been called for this node + int evalCount; + + // Flag indicating that this node is ready and is on the + // next-node list. + int onNextNodeList; + + // Required input/output tensor counts for node validation + // -1 means any number is allowed + int requiredInputCount; + int requiredOutputCount; + + // Required rank ranges for input/output tensors + // -1 means n/a + int requiredRankMin; + int requiredRankMax; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp new file mode 100644 index 0000000..ec2fdc9 --- /dev/null +++ b/reference_model/src/main.cpp @@ -0,0 +1,295 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <stdio.h> + +#include "flatbuffers/idl.h" +#include "flatbuffers/util.h" +#include "model_common.h" +#include "ops/op_factory.h" +#include "subgraph_traverser.h" +#include "tosa_serialization_handler.h" +#include <Eigen/CXX11/Tensor> +#include <iostream> + +using namespace TosaReference; +using namespace tosa; + +// Global instantiation of configuration and debug objects +func_config_t g_func_config; +func_debug_t g_func_debug; + +int readInputTensors(SubgraphTraverser& gt); +int writeFinalTensors(SubgraphTraverser& gt); +int loadGraph(TosaSerializationHandler& tsh); + +int main(int argc, const char** argv) +{ + // Initialize configuration and debug subsystems + func_model_init_config(); + func_model_set_default_config(&g_func_config); + func_init_debug(&g_func_debug, 0); + TosaSerializationHandler tsh; + + if (func_model_parse_cmd_line(&g_func_config, &g_func_debug, argc, argv)) + { + return 1; + } + + if (loadGraph(tsh)) + { + SIMPLE_FATAL_ERROR("Unable to load graph"); + } + + // load json first since it's easier debugging + SubgraphTraverser main_gt(tsh.GetMainBlock(), &tsh); + + if (main_gt.initializeGraph()) + { + SIMPLE_FATAL_ERROR("Unable to initialize graph traverser: \"main\""); + } + + if (main_gt.linkTensorsAndNodes()) + { + SIMPLE_FATAL_ERROR("Failed to link tensors and nodes"); + } + + if (main_gt.validateGraph()) + { + SIMPLE_FATAL_ERROR("Failed to validate graph"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + if (readInputTensors(main_gt)) + { + SIMPLE_FATAL_ERROR("Unable to read input tensors"); + } + + if (g_func_config.eval) + { + + if (main_gt.evaluateAll()) + { + SIMPLE_FATAL_ERROR("Error evaluating network. 
Giving up."); + } + + // make sure output tensor is evaluated and show its value + int num_output_tensors = main_gt.getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = main_gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + main_gt.dumpGraph(g_func_debug.func_debug_file); + SIMPLE_FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + + if (g_func_config.output_tensors) + { + if (writeFinalTensors(main_gt)) + { + WARNING("Errors encountered in saving output tensors"); + } + } + } + +done: + func_fini_debug(&g_func_debug); + func_model_config_cleanup(); + + return 0; +} + +int loadGraph(TosaSerializationHandler& tsh) +{ + char graph_fullname[1024]; + + snprintf(graph_fullname, sizeof(graph_fullname), "%s/%s", g_func_config.subgraph_dir, g_func_config.subgraph_file); + + if (strlen(graph_fullname) <= 2) + { + func_model_print_help(stderr); + SIMPLE_FATAL_ERROR("Missing required argument: Check -Csubgraph_file="); + } + + const char JSON_EXT[] = ".json"; + int is_json = 0; + { + // look for JSON file extension + size_t suffix_len = strlen(JSON_EXT); + size_t str_len = strlen(graph_fullname); + + if (str_len > suffix_len && strncasecmp(graph_fullname + (str_len - suffix_len), JSON_EXT, suffix_len) == 0) + { + is_json = 1; + } + } + + if (is_json) + { + if (tsh.LoadFileSchema(g_func_config.operator_fbs)) + { + SIMPLE_FATAL_ERROR( + "\nJSON file detected. 
Unable to load TOSA flatbuffer schema from: %s\nCheck -Coperator_fbs=", + g_func_config.operator_fbs); + } + + if (tsh.LoadFileJson(graph_fullname)) + { + SIMPLE_FATAL_ERROR("\nError loading JSON graph file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", + graph_fullname); + } + } + else + { + if (tsh.LoadFileTosaFlatbuffer(graph_fullname)) + { + SIMPLE_FATAL_ERROR("\nError loading TOSA flatbuffer file: %s\nCheck -Csubgraph_file= and -Csubgraph_dir=", + graph_fullname); + } + } + + return 0; +} + +int readInputTensors(SubgraphTraverser& gt) +{ + int tensorCount = gt.getNumInputTensors(); + Tensor* tensor; + char filename[1024]; + + // assuming filename doesn't have colons(:) + std::map<std::string, std::string> input_tensor_map; + std::string raw_str(g_func_config.input_tensor); + std::string name, npy; + bool last_pair = false; + + std::string::size_type pair_start = 0, pair_end, colons_pos; + do + { + pair_end = raw_str.find(',', pair_start); + if (pair_end == std::string::npos) + last_pair = true; + + colons_pos = raw_str.find(':', pair_start); + + name = raw_str.substr(pair_start, colons_pos - pair_start); + npy = raw_str.substr(colons_pos + 1, pair_end - colons_pos - 1); + + // Empty strings can make it to here + if (name.length() == 0 || npy.length() == 0) + break; + + input_tensor_map[name] = npy; + + pair_start = pair_end + 1; // skip colons + } while (!last_pair); + + if ((size_t)tensorCount != input_tensor_map.size()) + { + WARNING("graph has %lu input placeholders, but %lu initialized", tensorCount, input_tensor_map.size()); + return 1; + } + + for (auto& tensor_pair : input_tensor_map) + { + tensor = gt.getInputTensorByName(tensor_pair.first); + if (!tensor) + { + WARNING("Unable to find input tensor %s", tensor_pair.first.c_str()); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s", g_func_config.input_dir, tensor_pair.second.c_str()); + + DEBUG_MED(GT, "Loading input tensor %s from filename: %s", tensor->getName().c_str(), filename); + 
+ if (tensor->allocate()) + { + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; + } + + if (tensor->readFromNpyFile(filename)) + { + WARNING("Unable to read input tensor %s from filename: %s", tensor->getName().c_str(), filename); + tensor->dumpTensorParams(g_func_debug.func_debug_file); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + gt.addToNextNodeList(gn); + } + } + } + + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + gt.dumpNextNodeList(g_func_debug.func_debug_file); + } + + return 0; +} + +int writeFinalTensors(SubgraphTraverser& gt) +{ + int tensorCount = gt.getNumOutputTensors(); + const Tensor* tensor; + char filename[1024]; + + for (int i = 0; i < tensorCount; i++) + { + tensor = gt.getOutputTensor(i); + if (!tensor) + { + WARNING("Unable to find output tensor[%d]", i); + return 1; + } + + snprintf(filename, sizeof(filename), "%s/%s%s.npy", g_func_config.output_dir, + g_func_config.output_tensor_prefix, tensor->getName().c_str()); + + DEBUG_MED(GT, "Writing output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + + if (tensor->writeToNpyFile(filename)) + { + WARNING("Unable to write output tensor[%d] %s to filename: %s", i, tensor->getName().c_str(), filename); + return 1; + } + } + + return 0; +} diff --git a/reference_model/src/model_common.h b/reference_model/src/model_common.h new file mode 100644 index 0000000..d6dab6d --- /dev/null +++ b/reference_model/src/model_common.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_COMMON_H +#define MODEL_COMMON_H + +#include <iostream> +#include <stdio.h> + +#include "func_config.h" +#include "func_debug.h" + +extern func_config_t g_func_config; +extern func_debug_t g_func_debug; + +#endif diff --git a/reference_model/src/ops/activation_funcs.cc b/reference_model/src/ops/activation_funcs.cc new file mode 100644 index 0000000..bca9507 --- /dev/null +++ b/reference_model/src/ops/activation_funcs.cc @@ -0,0 +1,118 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "activation_funcs.h" +#include "quant_util.h" +#include "template_types.h" +#include <cmath> + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +int OpClamp<Rank, Dtype>::register_fcn() +{ + + switch (Dtype) + { + case DType_FLOAT: + { + InEigenType min = (InEigenType)attribute->min_fp(); + InEigenType max = (InEigenType)attribute->max_fp(); + this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? 
min : a >= max ? max : a; }; + } + break; + case DType_AINT8: + case DType_INT16: + { + InEigenType min = (InEigenType)attribute->min_int(); + InEigenType max = (InEigenType)attribute->max_int(); + this->fcn = [min, max](InEigenType a) -> OutEigenType { return a <= min ? min : a >= max ? max : a; }; + } + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpReluN<Rank, Dtype>::register_fcn() +{ + + switch (Dtype) + { + case DType_FLOAT: + { + InEigenType N = (InEigenType)attribute->max_fp(); + this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; }; + } + break; + case DType_INT32: + { + InEigenType N = (InEigenType)attribute->max_int(); + this->fcn = [N](InEigenType a) -> OutEigenType { return a >= 0 ? (a <= N ? a : N) : 0; }; + } + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpSigmoid<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return (1.0 / (1.0 + (expf(-1.0 * a)))); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpTanh<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return tanhf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32); + 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); diff --git a/reference_model/src/ops/activation_funcs.h b/reference_model/src/ops/activation_funcs.h new file mode 100644 index 0000000..b051b9d --- /dev/null +++ b/reference_model/src/ops/activation_funcs.h @@ -0,0 +1,101 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_ACTIVATION_FUNCS_H +#define OPS_ACTIVATION_FUNCS_H + +#include "ewise_unary.h" +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <int Rank, DType Dtype> +class OpClamp : public UnaryNode<Rank, Dtype> +{ +public: + OpClamp(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(Op_CLAMP, id_) + { + INIT_ATTRIBUTE(Clamp); + register_fcn(); + } + static constexpr int32_t QMin = GetQMin<Dtype>::value; + static constexpr int32_t QMax = GetQMax<Dtype>::value; + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); + +protected: + TosaClampAttribute* attribute; +}; + +template <int Rank, DType Dtype> +class OpReluN : public UnaryNode<Rank, Dtype> +{ +public: + OpReluN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(Op_RELUN, id_) + { + INIT_ATTRIBUTE(ReluN); + register_fcn(); + } + static constexpr 
int32_t QMin = GetQMin<Dtype>::value; + static constexpr int32_t QMax = GetQMax<Dtype>::value; + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); + +protected: + TosaReluNAttribute* attribute; +}; + +template <int Rank, DType Dtype> +class OpSigmoid : public UnaryNode<Rank, Dtype> +{ +public: + OpSigmoid(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(Op_SIGMOID, id_) + { + register_fcn(); + } + static constexpr int32_t QMin = GetQMin<Dtype>::value; + static constexpr int32_t QMax = GetQMax<Dtype>::value; + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); +}; + +template <int Rank, DType Dtype> +class OpTanh : public UnaryNode<Rank, Dtype> +{ +public: + OpTanh(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : UnaryNode<Rank, Dtype>(Op_TANH, id_) + { + register_fcn(); + } + static constexpr int32_t QMin = GetQMin<Dtype>::value; + static constexpr int32_t QMax = GetQMax<Dtype>::value; + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + virtual int register_fcn(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/comparison.cc b/reference_model/src/ops/comparison.cc new file mode 100644 index 0000000..402e152 --- /dev/null +++ b/reference_model/src/ops/comparison.cc @@ -0,0 +1,81 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "comparison.h" +#include "arith_util.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +int OpEqual<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a == b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpGreater<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpGreaterEqual<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a >= b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); diff --git a/reference_model/src/ops/comparison.h b/reference_model/src/ops/comparison.h new file mode 100644 index 0000000..e75b1a6 --- /dev/null +++ b/reference_model/src/ops/comparison.h @@ -0,0 +1,71 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_COMPARISON_H +#define OPS_COMPARISON_H + +#include "ewise_binary.h" +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <int Rank, DType Dtype> +class OpEqual : public BinaryNode<Rank, Dtype, DType_BOOL> +{ +public: + OpEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_) + { + register_fcn(); + } + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<DType_BOOL>::type; + virtual int register_fcn(); +}; + +template <int Rank, DType Dtype> +class OpGreater : public BinaryNode<Rank, Dtype, DType_BOOL> +{ +public: + OpGreater(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(Op_GREATER, qinfo_, id_) + { + register_fcn(); + } + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<DType_BOOL>::type; + virtual int register_fcn(); +}; + +template <int Rank, DType Dtype> +class OpGreaterEqual : public 
BinaryNode<Rank, Dtype, DType_BOOL> +{ +public: + OpGreaterEqual(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : BinaryNode<Rank, Dtype, DType_BOOL>(Op_EQUAL, qinfo_, id_) + { + register_fcn(); + } + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<DType_BOOL>::type; + virtual int register_fcn(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/control_flow.cc b/reference_model/src/ops/control_flow.cc new file mode 100644 index 0000000..9d5db40 --- /dev/null +++ b/reference_model/src/ops/control_flow.cc @@ -0,0 +1,353 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "control_flow.h" +#include "subgraph_traverser.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +OpControlFlow::OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_) + : GraphNode(op_, id_) +{ + tsh = tsh_; +} + +OpControlFlow::~OpControlFlow() +{} + +int OpControlFlow::evalBlock(TosaSerializationBasicBlock* block, + std::vector<TosaReference::Tensor*>& block_inputs, + std::vector<TosaReference::Tensor*>& block_outputs) +{ + std::string block_name = block->GetName(); + + DEBUG_MED(OP, "Evaluating block %s", block_name.c_str()); + + SubgraphTraverser gt(block, tsh); + + if (gt.initializeGraph()) + { + FATAL_ERROR("Unable to initialize graph traverser for block %s", block_name.c_str()); + } + + if (gt.linkTensorsAndNodes()) + { + FATAL_ERROR("Failed to link tensors and nodes for block %s", block_name.c_str()); + } + + if (gt.validateGraph()) + { + FATAL_ERROR("Failed to validate subgraph for block %s", block_name.c_str()); + } + + int num_input_tensors = gt.getNumInputTensors(); + int num_output_tensors = gt.getNumOutputTensors(); + + for (size_t i = 0; i < block_inputs.size(); i++) + { + DEBUG_HIGH(OP, "Input[%ld]: %s", i, block_inputs[i]->getName().c_str()); + } + for (size_t i = 0; i < block_outputs.size(); i++) + { + DEBUG_HIGH(OP, "Output[%ld]: %s", i, block_outputs[i]->getName().c_str()); + } + + ASSERT_MSG((size_t)num_input_tensors == block_inputs.size(), + "op block %s inputs[%lu] does not match with graph traverser's inputs[%d]", block_name.c_str(), + block_inputs.size(), num_input_tensors); + ASSERT_MSG((size_t)num_output_tensors == block_outputs.size(), + "op block %s outputs[%lu] does not match with graph traverser's outputs[%d]", block_name.c_str(), + block_outputs.size(), num_output_tensors); + + // set graph traverser's input = basic block's input + for (int i = 0; i < num_input_tensors; i++) + { + TosaReference::Tensor* tensor = gt.getInputTensor(i); + ASSERT_MSG(!tensor->is_allocated(), 
"block %s input tensors are unexpectedly initialized before", + block_name.c_str()); + + if (tensor->allocate()) + { + WARNING("Fail to allocate tensor %s", tensor->getName().c_str()); + return 1; + } + + if (tensor->copyValueFrom(block_inputs[i])) + { + WARNING("Fail to copy tensor value %s -> %s", block_inputs[i]->getName().c_str(), + tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + gt.addToNextNodeList(gn); + } + } + } + + if (gt.evaluateAll()) + { + FATAL_ERROR("Error evaluating network. Giving up."); + } + + // make sure output tensor is evaluated and show its value + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const TosaReference::Tensor* ct = gt.getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + gt.dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR("SubgraphTraverser \"%s\" error: Output tensors are not all valid at the end of evaluation.", + block_name.c_str()); + } + + // set basic block's output = subgraph_traverser's output + for (int i = 0; i < num_output_tensors; i++) + { + TosaReference::Tensor* tensor = gt.getOutputTensor(i); + ASSERT_MSG(tensor->is_allocated(), "tensor %s is not allocated", tensor->getName().c_str()); + + if (block_outputs[i]->copyValueFrom(tensor)) + { + WARNING("Fail to copy tensor value %s -> %s", tensor->getName().c_str(), outputs[i]->getName().c_str()); + return 1; + } + } + return 0; +} + +OpCondIf::OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(tsh_, Op_COND_IF, id_) +{ + INIT_ATTRIBUTE(CondIf); +} + +OpCondIf::~OpCondIf() +{ + if (attribute) + 
delete attribute; +} + +int OpCondIf::checkTensorAttributes() +{ + if (getInputs().size() < 1) + { + WARNING("OpCondIf: must have at least 1 operand"); + return 1; + } + + if (inputs[0]->getDtype() != DType_BOOL || inputs[0]->getRank() != 0) + { + WARNING("OpCondIf: invalid tensor dtype=%s, rank=%d", EnumNamesDType()[inputs[0]->getDtype()], + inputs[0]->getRank()); + return 1; + } + + cond = dynamic_cast<TosaReference::Tensor0<bool>*>(inputs[0]); + ASSERT_MEM(cond); + + then_block = tsh->GetBlockByName(attribute->then_branch()); + else_block = tsh->GetBlockByName(attribute->else_branch()); + + if (!then_block) + { + WARNING("OpCondIf: fail to resolve then_branch %s", attribute->then_branch().c_str()); + return 1; + } + + if (!else_block) + { + WARNING("OpCondIf: fail to resolve else_branch %s", attribute->else_branch().c_str()); + return 1; + } + + return 0; +} + +int OpCondIf::eval() +{ + bool cond_val = cond->getTensor()(0); + std::vector<TosaReference::Tensor*> block_inputs(getInputs().begin() + 1, getInputs().end()); + + if (cond_val) + { + if (evalBlock(then_block, block_inputs, getOutputs())) + { + WARNING("OpCondIf: Fail to evaluate then branch block %s", attribute->then_branch().c_str()); + return 1; + } + } + else + { + if (evalBlock(else_block, block_inputs, getOutputs())) + { + WARNING("OpCondIf: Fail to evaluate else branch block %s", attribute->else_branch().c_str()); + return 1; + } + } + + return GraphNode::eval(); +} + +OpWhileLoop::OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_) + : OpControlFlow(tsh_, Op_WHILE_LOOP, id_) +{ + INIT_ATTRIBUTE(WhileLoop); +} + +OpWhileLoop::~OpWhileLoop() +{ + if (attribute) + delete attribute; +} + +int OpWhileLoop::checkTensorAttributes() +{ + if (getInputs().size() <= 0) + { + WARNING("OpWhileLoop: must have at least 1 operands"); + return 1; + } + + if (getInputs().size() != getOutputs().size()) + { + WARNING("OpWhileLoop: inputs and outputs size must match"); + return 1; 
}

    // Resolve the condition and body blocks by name.
    cond_block = tsh->GetBlockByName(attribute->cond_branch());
    body_block = tsh->GetBlockByName(attribute->body_branch());

    if (!cond_block)
    {
        WARNING("OpWhileLoop: fail to resolve cond_branch %s", attribute->cond_branch().c_str());
        return 1;
    }

    if (!body_block)
    {
        WARNING("OpWhileLoop: fail to resolve body_branch %s", attribute->body_branch().c_str());
        return 1;
    }

    // The condition block must produce exactly one rank-0 boolean tensor.
    if (cond_block->GetOutputs().size() != 1)
    {
        WARNING("OpWhileLoop: invalid cond_block output size %lu", cond_block->GetOutputs().size());
        return 1;
    }

    TosaSerializationTensor* cond_output_tensor = cond_block->GetTensorByName(cond_block->GetOutputs()[0]);

    if (!cond_output_tensor)
    {
        WARNING("OpWhileLoop: fail to resolve cond_block's output tensor %s", cond_block->GetOutputs()[0].c_str());
        return 1;
    }

    if (cond_output_tensor->GetDtype() != DType_BOOL)
    {
        WARNING("OpWhileLoop: invalid cond_block's output tensor data type %s",
                EnumNamesDType()[cond_output_tensor->GetDtype()]);
        return 1;
    }
    if (cond_output_tensor->GetShape().size() != 0)
    {
        WARNING("OpWhileLoop: invalid cond_block's output rank %lu", cond_output_tensor->GetShape().size());
        return 1;
    }

    return 0;
}

// Run the loop: evaluate cond_block on the current state; while it yields
// true, evaluate body_block and feed its outputs back as the next state.
// Iteration is capped at MAX_WHILE_LOOP_ITERATION to guard against
// non-terminating graphs.
int OpWhileLoop::eval()
{

    // Scratch rank-0 bool tensor that receives the condition block's result
    // on each iteration.
    TosaReference::Tensor0<bool> cond_output_ctensor(
        std::string("cond_output"), DType_BOOL, std::vector<Usage>({ Usage_ACTIVATION }),
        std::vector<Format>({ Format_UNKNOWN }), std::vector<int32_t>({}), false);

    cond_output_ctensor.allocate();
    std::vector<TosaReference::Tensor*> cond_block_outputs;
    cond_block_outputs.push_back(&cond_output_ctensor);

    size_t num_input_output = getInputs().size();
    size_t eval_count       = 0;

    while (eval_count++ < MAX_WHILE_LOOP_ITERATION)
    {
        if (evalBlock(cond_block, getInputs(), cond_block_outputs))
        {
            WARNING("OpWhileLoop: Fail to evaluate cond block %s", attribute->cond_branch().c_str());
            return 1;
        }
        bool cond_val = cond_output_ctensor.getTensor()(0);
        DEBUG_HIGH(OP, "Conditional block value: %d", cond_val);

        if (cond_val)
        {
            // Condition true: run the body, then copy the body's outputs
            // back onto the inputs so they become the next iteration's state.
            if (evalBlock(body_block, getInputs(), getOutputs()))
            {
                WARNING("OpWhileLoop: Fail to evaluate body block %s", attribute->body_branch().c_str());
                return 1;
            }

            // assigning output tensors value back to input tensors value for next iteration
            for (size_t i = 0; i < num_input_output; i++)
            {
                if (getInputs()[i]->copyValueFrom(getOutputs()[i]))
                {
                    WARNING("Fail to copy tensor value %s -> %s", getOutputs()[i]->getName().c_str(),
                            getInputs()[i]->getName().c_str());
                    return 1;
                }
            }
        }
        else
        {
            // in last iteration or the case it never evaluates body block
            // assign input tensors value to output tensors
            for (size_t i = 0; i < num_input_output; i++)
            {
                if (getOutputs()[i]->copyValueFrom(getInputs()[i]))
                {
                    WARNING("Fail to copy tensor value %s -> %s", getInputs()[i]->getName().c_str(),
                            getOutputs()[i]->getName().c_str());
                    return 1;
                }
            }
            break;
        }
    }

    return GraphNode::eval();
}
diff --git a/reference_model/src/ops/control_flow.h b/reference_model/src/ops/control_flow.h
new file mode 100644
index 0000000..14c11bc
--- /dev/null
+++ b/reference_model/src/ops/control_flow.h
@@ -0,0 +1,72 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OPS_CONTROL_FLOW_H
#define OPS_CONTROL_FLOW_H

#include "graph_node.h"

// Safety cap on while-loop trip count so a malformed graph cannot hang the
// reference model.
#define MAX_WHILE_LOOP_ITERATION 10000

namespace TosaReference
{
// Common base for the control-flow operators (COND_IF, WHILE_LOOP).  Holds
// the serialization handler used to resolve basic blocks by name and
// provides the shared block-evaluation helper.
class OpControlFlow : public GraphNode
{
public:
    OpControlFlow(TosaSerializationHandler* tsh_, Op op_, uint64_t id_);
    ~OpControlFlow();

    // Evaluate `block` with the given inputs, writing its results into
    // block_outputs.  Returns 0 on success, non-zero on failure.
    virtual int evalBlock(TosaSerializationBasicBlock* block,
                          std::vector<TosaReference::Tensor*>& block_inputs,
                          std::vector<TosaReference::Tensor*>& block_outputs);

protected:
    TosaSerializationHandler* tsh;    // not owned; used for block lookup
};

// COND_IF: operand 0 is a rank-0 bool selecting one of two named blocks;
// the remaining operands are forwarded to the chosen block.
class OpCondIf : public OpControlFlow
{
public:
    OpCondIf(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
    virtual ~OpCondIf();

    virtual int checkTensorAttributes();
    virtual int eval();

protected:
    TosaCondIfAttribute* attribute;         // owned; deleted in destructor
    TosaReference::Tensor0<bool>* cond;     // scalar condition (operand 0)
    TosaSerializationBasicBlock* then_block;
    TosaSerializationBasicBlock* else_block;
};

// WHILE_LOOP: repeatedly runs body_block while cond_block yields true.
// The operand list is the loop-carried state, so inputs and outputs pair
// up one-to-one.
class OpWhileLoop : public OpControlFlow
{
public:
    OpWhileLoop(TosaSerializationHandler* tsh_, TosaAttributeBase* attribute_, uint64_t id_);
    virtual ~OpWhileLoop();

    virtual int checkTensorAttributes();
    virtual int eval();

protected:
    TosaWhileLoopAttribute* attribute;      // owned; deleted in destructor
    TosaSerializationBasicBlock* cond_block;
    TosaSerializationBasicBlock* body_block;
};

};    // namespace TosaReference

#endif
diff --git a/reference_model/src/ops/custom.cc b/reference_model/src/ops/custom.cc
new file mode 100644
index 0000000..5c4f29b
--- /dev/null
+++ b/reference_model/src/ops/custom.cc
@@ -0,0 +1,40 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "custom.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

OpCustom::OpCustom(uint64_t id_)
    : GraphNode(Op_CUSTOM, id_)
{}

OpCustom::~OpCustom()
{}

// CUSTOM carries opaque, implementation-defined payloads, so there is
// nothing the reference model can validate here.
int OpCustom::checkTensorAttributes()
{
    return 0;
}

// CUSTOM is not implemented in the reference model: evaluation aborts via
// FATAL_ERROR_NODE, so the trailing return is unreachable in practice.
int OpCustom::eval()
{
    FATAL_ERROR_NODE("not supported yet");

    // Not reached after the fatal error above.
    return GraphNode::eval();
}
diff --git a/reference_model/src/ops/custom.h b/reference_model/src/ops/custom.h
new file mode 100644
index 0000000..b1085a5
--- /dev/null
+++ b/reference_model/src/ops/custom.h
@@ -0,0 +1,38 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef OPS_CUSTOM_H
#define OPS_CUSTOM_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

// CUSTOM: placeholder node for implementation-defined operators.  The
// reference model performs no attribute validation and aborts on eval()
// (see custom.cc).
class OpCustom : public GraphNode
{
public:
    OpCustom(uint64_t id_);
    virtual ~OpCustom();

    virtual int checkTensorAttributes();
    virtual int eval();
};

};    // namespace TosaReference

#endif
diff --git a/reference_model/src/ops/data_layout.cc b/reference_model/src/ops/data_layout.cc
new file mode 100644
index 0000000..32029b9
--- /dev/null
+++ b/reference_model/src/ops/data_layout.cc
@@ -0,0 +1,644 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "data_layout.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +OpConcat<Rank, Dtype>::OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_CONCAT, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template <int Rank, DType Dtype> +OpConcat<Rank, Dtype>::~OpConcat() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int OpConcat<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types and rank + // inputs[0] and inputs[1] should also match type and rank + if (inputs[0]->matchRankType(*outputs[0]) || inputs[1]->matchRankType(*outputs[0])) + { + printNodeValidationError("Concat operator input ranks and types must match"); + return 1; + } + + lhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + rhs = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (attribute->axis() < 0 || (size_t)attribute->axis() >= rhs->getShape().size()) + { + printNodeValidationError("Axis is beyond input tensor rank"); + return 1; + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpConcat<Rank, Dtype>::eval() +{ + + int32_t reversed_axis = Rank - 1 - attribute->axis(); + + for (int32_t d = 0; d < Rank; d++) + { + reverser[d] = Rank - 1 - d; + } + + TIn lhs_reversed = lhs->getTensor().shuffle(reverser); + TIn rhs_reversed = rhs->getTensor().shuffle(reverser); + + TIn reversed_result = lhs_reversed.concatenate(rhs_reversed, reversed_axis); + out->getTensor() = reversed_result.shuffle(reverser); + // out->getTensor() = lhs->getTensor().concatenate(rhs->getTensor(), axis); + + return 
GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpPad<Rank, Dtype>::OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_PAD, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); + + INIT_QINFO(Pad); +} + +template <int Rank, DType Dtype> +OpPad<Rank, Dtype>::~OpPad() +{ + if (qinfo) + delete qinfo; +} + +template <int Rank, DType Dtype> +int OpPad<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type and rank"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + TosaReference::TensorTemplate<ETensor2<int32_t>>* paddings = + dynamic_cast<TosaReference::TensorTemplate<ETensor2<int32_t>>*>(inputs[1]); + + for (int i = 0; i < Rank; i++) + { + paddings_array[i] = std::make_pair(paddings->getTensor()(i, 0), paddings->getTensor()(i, 1)); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpPad<Rank, Dtype>::eval() +{ + InEigenType pad_value = 0; + if (this->qinfo) + { + pad_value = (InEigenType)this->qinfo->input_zp(); + } + + this->out->getTensor() = this->in->getTensor().pad(this->paddings_array, pad_value); + + return GraphNode::eval(); +} + +template <int InRank, int OutRank, DType Dtype> +OpReshape<InRank, OutRank, Dtype>::OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_RESHAPE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Reshape); +} + +template <int InRank, int OutRank, DType Dtype> +OpReshape<InRank, OutRank, Dtype>::~OpReshape() +{ + if (attribute) + delete attribute; +} + +template <int InRank, int OutRank, 
DType Dtype> +int OpReshape<InRank, OutRank, Dtype>::checkTensorAttributes() +{ + uint32_t minusOneCount = 0; + + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("OpReshape: Input and output types must match"); + return 1; + } + + for (uint32_t d = 0; d < OutRank; d++) + { + if (attribute->shape()[d] == -1) + { + minusOneCount++; + } + } + + if (minusOneCount > 1) + { + printNodeValidationError("OpReshape: new shape has more than one -1 dimension"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + return 0; +} + +template <int InRank, int OutRank, DType Dtype> +int OpReshape<InRank, OutRank, Dtype>::eval() +{ + uint32_t remainingSize = in->getElementCount(); + + // If there is a -1 dimension, find the remainder in one pass over the output shape + for (int32_t d = 0; d < OutRank; d++) + { + if (attribute->shape()[d] != -1) + { + remainingSize = remainingSize / attribute->shape()[d]; + } + } + + for (int32_t d = 0; d < OutRank; d++) + { + array_shape[d] = attribute->shape()[OutRank - 1 - d]; + out_reverser[d] = OutRank - 1 - d; + + // Jam in the remainder here + if (array_shape[d] == -1) + { + array_shape[d] = remainingSize; + } + } + + for (int32_t d = 0; d < InRank; d++) + { + in_reverser[d] = InRank - 1 - d; + } + + // Eigen Tensor is col-major, and we're referencing row-major result + // need to reverse it to row-major before reshape, and perform another reverse afterward + + // input tensor rank 0 can't do .shuffle(), need to be handled otherwise + TIn in_reversed; + if (InRank > 1) + { + in_reversed = in->getTensor().shuffle(in_reverser); + } + else + { + in_reversed = in->getTensor(); + } + + TOut in_reshaped = 
in_reversed.reshape(array_shape); + + // output tensor can be rank 0, .reshape() and .shuffle() don't work, need to be handled otherwise + if (OutRank > 1) + { + out->getTensor() = in_reshaped.shuffle(out_reverser); + } + else + { + out->getTensor() = in_reshaped; + } + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpReverse<Rank, Dtype>::OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_REVERSE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template <int Rank, DType Dtype> +OpReverse<Rank, Dtype>::~OpReverse() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int OpReverse<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankTypeShape(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank/type/shape"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && out); + + if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank()) + { + printNodeValidationError("Reverse axis must between [0, input_rank - 1]"); + return 1; + } + + // transform list of axis into true or false list + // e.g. 
rank=4, axis=[1,2], reverse array would be [false, true, true, false] + for (int i = 0; i < Rank; i++) + { + reverse_array[i] = false; + } + reverse_array[attribute->axis()] = true; + + return 0; +} + +template <int Rank, DType Dtype> +int OpReverse<Rank, Dtype>::eval() +{ + out->getTensor() = in->getTensor().reverse(reverse_array); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpSlice<Rank, Dtype>::OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SLICE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Slice); +} + +template <int Rank, DType Dtype> +OpSlice<Rank, Dtype>::~OpSlice() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int OpSlice<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + for (size_t i = 0; i < attribute->begin().size(); i++) + { + begin_array[i] = attribute->begin()[i]; + } + + for (size_t i = 0; i < attribute->size().size(); i++) + { + if (attribute->size()[i] != 0) + { + size_array[i] = attribute->size()[i]; + } + else + { + // Tensorflow assigns a zero size to dimensions that are kept + // Eigen expects size to be the full size of the dimension + size_array[i] = in->getTensor().dimension(0); + } + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpSlice<Rank, Dtype>::eval() +{ + out->getTensor() = in->getTensor().slice(begin_array, size_array); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpTileBase<Rank, 
Dtype>::OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_TILE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Tile); +} + +template <int Rank, DType Dtype> +OpTileBase<Rank, Dtype>::~OpTileBase() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int OpTileBase<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same ranks and types + if (inputs[0]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank or type"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (attribute->multiples().size() != Rank) + { + printNodeValidationError("1D list 'multiples' must have size equal to input rank"); + return 1; + } + + for (int32_t d = 0; d < Rank; d++) + { + if (in->getShape()[d] * attribute->multiples()[d] != out->getShape()[d]) + { + printNodeValidationError("unexpected output shape"); + return 1; + } + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpTile<Rank, Dtype>::eval() +{ + // primary template shouldn't be called + FATAL_ERROR_NODE("OpTile rank=%i, dtype=%s: not implemented yet", Rank, EnumNamesDType()[Dtype]); +} + +template <DType Dtype> +int OpTile<1, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + this->out->getTensor()(od0) = this->in->getTensor()(id0); + } + + return GraphNode::eval(); +} + +template <DType Dtype> +int OpTile<2, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + 
int32_t id1 = od1 % this->in->getShape()[1]; + this->out->getTensor()(od0, od1) = this->in->getTensor()(id0, id1); + } + } + + return GraphNode::eval(); +} + +template <DType Dtype> +int OpTile<3, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + this->out->getTensor()(od0, od1, od2) = this->in->getTensor()(id0, id1, id2); + } + } + } + + return GraphNode::eval(); +} + +template <DType Dtype> +int OpTile<4, Dtype>::eval() +{ + for (int32_t od0 = 0; od0 < this->out->getShape()[0]; od0++) + { + int32_t id0 = od0 % this->in->getShape()[0]; + for (int32_t od1 = 0; od1 < this->out->getShape()[1]; od1++) + { + int32_t id1 = od1 % this->in->getShape()[1]; + for (int32_t od2 = 0; od2 < this->out->getShape()[2]; od2++) + { + int32_t id2 = od2 % this->in->getShape()[2]; + for (int32_t od3 = 0; od3 < this->out->getShape()[3]; od3++) + { + int32_t id3 = od3 % this->in->getShape()[3]; + this->out->getTensor()(od0, od1, od2, od3) = this->in->getTensor()(id0, id1, id2, id3); + } + } + } + } + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpTranspose<Rank, Dtype>::OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_TRANSPOSE, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpTranspose<Rank, Dtype>::~OpTranspose() +{} + +template <int Rank, DType Dtype> +int OpTranspose<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankType(*outputs[0])) + { + 
printNodeValidationError("Failure to match input and output rank and type"); + return 1; + } + + if (inputs[0]->getElementCount() != outputs[0]->getElementCount()) + { + printNodeValidationError("Failure to match input and output total element count"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + perm_tensor = dynamic_cast<TosaReference::TensorTemplate<ETensor1<int32_t>>*>(inputs[1]); + + return 0; +} + +template <int Rank, DType Dtype> +int OpTranspose<Rank, Dtype>::eval() +{ + for (int32_t d = 0; d < Rank; d++) + { + perm_array[d] = this->perm_tensor->getTensor().data()[d]; + } + + out->getTensor() = in->getTensor().shuffle(perm_array); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32) +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL) + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + +DEF_INSTANTIATE_RESHAPE(OpReshape, FLOAT); +DEF_INSTANTIATE_RESHAPE(OpReshape, AINT8); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT8); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT16); +DEF_INSTANTIATE_RESHAPE(OpReshape, INT32); +DEF_INSTANTIATE_RESHAPE(OpReshape, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); 
+DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); diff --git a/reference_model/src/ops/data_layout.h b/reference_model/src/ops/data_layout.h new file mode 100644 index 0000000..100bd6b --- /dev/null +++ b/reference_model/src/ops/data_layout.h @@ -0,0 +1,216 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_DATA_LAYOUT_H +#define OPS_DATA_LAYOUT_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <int Rank, DType Dtype> +class OpConcat : public GraphNode +{ +public: + OpConcat(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpConcat(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + Eigen::array<int, Rank> reverser; + TosaReference::TensorTemplate<TIn>* lhs; + TosaReference::TensorTemplate<TIn>* rhs; + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate<TOut>* out; +}; + +template <int Rank, DType Dtype> +class OpPad : public GraphNode +{ +public: + OpPad(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpPad(); + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + Eigen::array<std::pair<ptrdiff_t, ptrdiff_t>, Rank> paddings_array; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; + TosaPadQuantInfo* qinfo; +}; + +template <int InRank, int OutRank, DType Dtype> +class OpReshape : public GraphNode +{ +public: + OpReshape(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReshape(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = 
Eigen::Tensor<InEigenType, InRank>; + using TOut = Eigen::Tensor<OutEigenType, OutRank>; + +protected: + Eigen::array<Eigen::Index, OutRank> array_shape; + Eigen::array<Eigen::Index, InRank> in_reverser; + Eigen::array<Eigen::Index, OutRank> out_reverser; + TosaReference::TensorTemplate<TIn>* in; + TosaReshapeAttribute* attribute; + TosaReference::TensorTemplate<TOut>* out; +}; + +template <int Rank, DType Dtype> +class OpReverse : public GraphNode +{ +public: + OpReverse(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpReverse(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + TosaAxisAttribute* attribute; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; + Eigen::array<bool, Rank> reverse_array; +}; + +template <int Rank, DType Dtype> +class OpSlice : public GraphNode +{ +public: + OpSlice(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpSlice(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + TosaSliceAttribute* attribute; + Eigen::array<Eigen::Index, Rank> begin_array; + Eigen::array<Eigen::Index, Rank> size_array; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; +}; + +template <int Rank, DType Dtype> +class OpTileBase : public GraphNode +{ +public: + OpTileBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTileBase(); + + virtual int checkTensorAttributes(); + + using InEigenType = 
typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + TosaTileAttribute* attribute; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; +}; + +// primary template for op tile +template <int Rank, DType Dtype> +class OpTile : public OpTileBase<Rank, Dtype> +{ +public: + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : OpTileBase<Rank, Dtype>(attribute_, qinfo_, id_) + {} + +protected: + virtual int eval(); +}; + +// partial specialization for specific rank +#define DEF_OP_TILE_RANK(N) \ + template <DType Dtype> \ + class OpTile<N, Dtype> : public OpTileBase<N, Dtype> \ + { \ + public: \ + OpTile(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : OpTileBase<N, Dtype>(attribute_, qinfo_, id_) \ + {} \ + \ + protected: \ + virtual int eval(); \ + }; + +DEF_OP_TILE_RANK(1) +DEF_OP_TILE_RANK(2) +DEF_OP_TILE_RANK(3) +DEF_OP_TILE_RANK(4) + +#undef DEF_OP_TILE_RANK + +template <int Rank, DType Dtype> +class OpTranspose : public GraphNode +{ +public: + OpTranspose(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpTranspose(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + Eigen::array<int, Rank> perm_array; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<ETensor1<int32_t>>* perm_tensor; + TosaReference::TensorTemplate<TOut>* out; +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/data_nodes.cc b/reference_model/src/ops/data_nodes.cc new file mode 100644 index 0000000..2ee4935 --- /dev/null +++ 
b/reference_model/src/ops/data_nodes.cc @@ -0,0 +1,172 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "data_nodes.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +OpConst::OpConst(uint64_t id_) + : GraphNode(Op_CONST, id_) +{ + setRequiredOperands(0, 1); +} + +OpConst::~OpConst() +{} + +int OpConst::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpConst::eval() +{ + // Evaluation is trivial for constants + return GraphNode::eval(); +} + +OpPlaceholder::OpPlaceholder(uint64_t id_) + : GraphNode(Op_PLACEHOLDER, id_) +{ + setRequiredOperands(0, 1); +} + +OpPlaceholder::~OpPlaceholder() +{} + +int OpPlaceholder::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + return 0; +} + +int OpPlaceholder::eval() +{ + // Evaluation is trivial for placeholders + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpIdentity<Rank, Dtype>::OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITY, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpIdentity<Rank, Dtype>::~OpIdentity() +{} + +template <int Rank, DType Dtype> +int OpIdentity<Rank, Dtype>::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are 
not equal"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (in->matchRankTypeShape(*out)) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpIdentity<Rank, Dtype>::eval() +{ + out->getTensor() = in->getTensor(); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +OpIdentityN<Rank, Dtype>::OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_IDENTITYN, id_) +{ + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpIdentityN<Rank, Dtype>::~OpIdentityN() +{} + +template <int Rank, DType Dtype> +int OpIdentityN<Rank, Dtype>::checkTensorAttributes() +{ + + if (inputs.size() != outputs.size()) + { + printNodeValidationError("Input and output tensor list lengths are not equal"); + return 1; + } + + for (size_t i = 0; i < inputs.size(); i++) + { + ins.push_back(dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[i])); + outs.push_back(dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[i])); + + if (ins[i]->matchRankTypeShape(*outs[i])) + { + printNodeValidationError("Input and output tensor rank, type, or shape do not match"); + return 1; + } + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpIdentityN<Rank, Dtype>::eval() +{ + for (size_t i = 0; i < ins.size(); i++) + { + outs[i]->getTensor() = ins[i]->getTensor(); + } + + return GraphNode::eval(); +} + +// template explicit instantiation +// note OpConst and OpPlaceholder are not templated + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL); diff --git a/reference_model/src/ops/data_nodes.h b/reference_model/src/ops/data_nodes.h new file mode 100644 index 0000000..bec4669 --- /dev/null +++ b/reference_model/src/ops/data_nodes.h @@ -0,0 +1,86 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OPS_DATA_NODES_H +#define OPS_DATA_NODES_H + +#include "graph_node.h" + +namespace TosaReference +{ + +class OpConst : public GraphNode +{ +public: + OpConst(uint64_t id_); + virtual ~OpConst(); + + virtual int checkTensorAttributes(); + virtual int eval(); +}; + +class OpPlaceholder : public GraphNode +{ +public: + OpPlaceholder(uint64_t id_); + virtual ~OpPlaceholder(); + + virtual int checkTensorAttributes(); + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpIdentity : public GraphNode +{ +public: + OpIdentity(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpIdentity(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; +}; + +template <int Rank, DType Dtype> +class OpIdentityN : public GraphNode +{ +public: + OpIdentityN(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpIdentityN(); + + virtual int checkTensorAttributes(); + virtual int eval(); + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + std::vector<TosaReference::TensorTemplate<TIn>*> ins; + std::vector<TosaReference::TensorTemplate<TOut>*> outs; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/ewise_binary.cc b/reference_model/src/ops/ewise_binary.cc new file mode 100644 index 0000000..4d4f8b9 --- /dev/null +++ b/reference_model/src/ops/ewise_binary.cc @@ -0,0 +1,586 @@ + +// Copyright (c) 2020, ARM Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "ewise_binary.h" +#include "arith_util.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType InDtype, DType OutDtype> +BinaryNodeBase<Rank, InDtype, OutDtype>::BinaryNodeBase(const Op& op_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); + + a_rank = b_rank = max_input_rank = -1; + a = b = nullptr; + a_rank0 = b_rank0 = nullptr; + result = nullptr; + + fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return OutEigenType(); }; +} + +template <int Rank, DType InDtype, DType OutDtype> +BinaryNodeBase<Rank, InDtype, OutDtype>::~BinaryNodeBase() +{} + +template <int Rank, DType InDtype, DType OutDtype> +int BinaryNodeBase<Rank, InDtype, OutDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + a_rank = inputs[0]->getRank(); + b_rank = inputs[1]->getRank(); + if (a_rank != 0 && b_rank != 0 && a_rank != b_rank) + { + printNodeValidationError("Binary operator input ranks must match"); + return 1; + } + + max_input_rank = a_rank >= b_rank ? 
a_rank : b_rank; + + // A & B must be the same types + if (inputs[0]->matchType(*inputs[1])) + { + printNodeValidationError("Binary operator input types must match"); + return 1; + } + + // Result's geometry must match, but the type may be wider + if (outputs[0]->getRank() != max_input_rank) + { + printNodeValidationError("Binary operator input and output genometry must match"); + return 1; + } + + if (a_rank == max_input_rank) + { + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + } + else + { + a_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[0]); + } + + if (b_rank == max_input_rank) + { + b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + } + else + { + b_rank0 = dynamic_cast<TosaReference::TensorTemplate<ETensor0<InEigenType>>*>(inputs[1]); + } + + result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + // either a or b can be rank0 + // a_rank0 and b_rank0 can't be valid at the same time. + // if a and be are both rank0, they should be evaulated as 'a' and 'b', instead of 'a_rank0' and 'b_rank0' + ASSERT_MEM((a || a_rank0) && (b || b_rank0) && !(a_rank0 && b_rank0) && result); + + return 0; +} + +template <int Rank, DType InDtype, DType OutDtype> +int BinaryNodeBase<Rank, InDtype, OutDtype>::broadcast() +{ + auto output_shape = result->getTensor().dimensions(); + + std::vector<int> a_shape, b_shape; + + if (a_rank == max_input_rank) + { + a_shape = a->getShape(); + } + else + { + a_shape.assign(max_input_rank, 1); + } + + if (b_rank == max_input_rank) + { + b_shape = b->getShape(); + } + else + { + b_shape.assign(max_input_rank, 1); + } + + for (int i = 0; i < max_input_rank; i++) + { + if (a_shape[i] != output_shape[i] && a_shape[i] == 1) + { + bcast_a[i] = output_shape[i]; + } + else + { + bcast_a[i] = 1; + } + if (b_shape[i] != output_shape[i] && b_shape[i] == 1) + { + bcast_b[i] = output_shape[i]; + } + else + { + bcast_b[i] = 1; + } + } + + return 0; +} + 
+template <int Rank, DType InDtype, DType OutDtype> +int BinaryNode<Rank, InDtype, OutDtype>::eval() +{ + this->broadcast(); + + Eigen::array<int, Rank> reshaper; + reshaper.fill(1); + TIn ia, ib; + + if (this->a_rank == this->max_input_rank) + { + ia = this->a->getTensor().broadcast(this->bcast_a); + } + else + { + ia = this->a_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_a); + } + + if (this->b_rank == this->max_input_rank) + { + ib = this->b->getTensor().broadcast(this->bcast_b); + } + else + { + ib = this->b_rank0->getTensor().reshape(reshaper).broadcast(this->bcast_b); + } + + this->result->getTensor() = ia.binaryExpr(ib, this->fcn); + + return GraphNode::eval(); +} + +// still need to partial specialize this, or Eigen will throw static assertion +template <DType InDtype, DType OutDtype> +int BinaryNode<0, InDtype, OutDtype>::eval() +{ + this->result->getTensor() = this->a->getTensor().binaryExpr(this->b->getTensor(), this->fcn); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpAdd<Rank, Dtype>::register_fcn() +{ + switch (InDtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a + b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpArithmeticRightShift<Rank, Dtype>::register_fcn() +{ + int32_t num_bits = 0; + switch (Dtype) + { + case DType_INT8: + num_bits = 8; + break; + case DType_INT16: + num_bits = 16; + break; + case DType_INT32: + num_bits = 32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { + uint32_t sign = a & (1 << (num_bits - 1)); + uint32_t ones_mask = ONES_MASK(b) << (num_bits - b); + if (sign) + return ones_mask | (a >> b); + else + return (~ones_mask) & (a >> b); + }; + + return 0; +} + +template <int Rank, 
DType Dtype> +int OpBitwiseAnd<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a & b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpBitwiseOr<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a | b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpBitwiseXor<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalAnd<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a && b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalLeftShift<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_INT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a << b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalRightShift<Rank, Dtype>::register_fcn() +{ + int32_t num_bits = 0; + switch (Dtype) + { + case DType_INT8: + num_bits = 8; + break; + case DType_INT16: + num_bits = 16; + break; + case DType_INT32: + num_bits = 
32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](InEigenType a, InEigenType b) -> OutEigenType { + uint32_t mask = ONES_MASK(num_bits) >> b; + return (a >> b) & mask; + }; + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalOr<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a || b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalXor<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a ^ b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpMaximum<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a > b ? a : b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpMinimum<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a < b ? 
a : b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType InDtype, DType OutDtype> +int OpMul<Rank, InDtype, OutDtype>::register_fcn() +{ + switch (InDtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a * b; }; + break; + case DType_INT8: + case DType_INT16: + this->fcn = [this](InEigenType lhs, InEigenType rhs) -> OutEigenType { + OutEigenType raw_output = (OutEigenType)lhs * (OutEigenType)rhs; + + OutEigenType clamped_output = std::min<OutEigenType>(QMax, std::max<OutEigenType>(raw_output, QMin)); + + return clamped_output; + }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpPow<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return powf(a, b); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpSub<Rank, Dtype>::register_fcn() +{ + switch (InDtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a, InEigenType b) -> OutEigenType { return a - b; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[InDtype]); + } + + return 0; +} + +template <int Rank> +OpTable<Rank>::OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_TABLE, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(0, 6); +} + +template <int Rank> +OpTable<Rank>::~OpTable() +{} + +template <int Rank> +int OpTable<Rank>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[1]->getRank() != 1 || 
inputs[1]->getElementCount() != 513 || inputs[1]->getDtype() != DType_INT16) + { + FATAL_ERROR_NODE("OpTable: must have INT16 table with 513 entries"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + table = dynamic_cast<TosaReference::TensorTemplate<TTable>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && table && out); + + return 0; +} + +template <int Rank> +int OpTable<Rank>::eval() +{ + this->out->getTensor() = this->in->getTensor().unaryExpr([this](InEigenType in) -> OutEigenType { + // 1. make sure input is int16 range + int32_t input_truncated = std::min<int32_t>(std::max<int32_t>(in, QInMin), QInMax); + + // 2. calculate index and interpolation fraction + int32_t index = (input_truncated >> 7) + (1 << (IntegerBits - 1)); + index = std::min<int32_t>(std::max<int32_t>(index, 0), NumTableEntries - 1); // 9-bit index + int32_t frac = (input_truncated)&0x7F; // 7-bit fraction + + // 3. 
interpolate, generate 16.7 (23-bit) output + int32_t base = this->table->getTensor()(index); + int32_t next = this->table->getTensor()(index + 1); + int32_t value = (base << 7) + (next - base) * frac; + + return value; + }); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + +DEF_INSTANTIATE_ONE_RANK_0_6(OpTable); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, FLOAT, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(BinaryNode, INT32, BOOL); diff --git a/reference_model/src/ops/ewise_binary.h b/reference_model/src/ops/ewise_binary.h new file mode 100644 index 0000000..00fb3d9 --- /dev/null +++ b/reference_model/src/ops/ewise_binary.h @@ -0,0 +1,195 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#ifndef OPS_EWISE_BINARY_H
#define OPS_EWISE_BINARY_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

// class BinaryNodeBase: hold common functions of all the binary nodes
//      when an binary op is created, the virtual OpXXX::register_fcn() will be called
//      and 'fcn' will be register with lambda function which has two inputs
// class BinaryNode: the level of indirection to partially specialize template for rank 0
//      eval() from toplevel called should call the .binaryExpr(dims, fcn) here
//      this needs to be partially specialize or
//      compiler will statically fail when trying to broadcast rank0 tensor
// class OpXXX: implement per-element lambda function based on different data type
//      unlike BinaryNode, this doesn't need to be partially specialized

// Eigen::Tensor does support some binary element-wise natively (e.g. CWiseMax, or '+', etc.)
// which might be faster since it could be implemented with SIMD instructions
// the way of registering lambda + .binaryExpr() might sacrifice performance here
// but it can avoid partially specialization for combination of {rankN, rank0} x {FLOAT/INT32, QU8, ...}
// needs to revisit if performance becomes a bottleneck here
template <int Rank, DType InDtype, DType OutDtype>
class BinaryNodeBase : public GraphNode
{
public:
    BinaryNodeBase(const Op& nodeType, TosaQuantInfoBase* qinfo_, const uint64_t id_);
    virtual ~BinaryNodeBase();

    // Validates operand count/rank/type and binds the input tensors to either
    // the full-rank slot (a/b) or the rank-0 slot (a_rank0/b_rank0).
    virtual int checkTensorAttributes() final;
    virtual int eval()         = 0;
    virtual int register_fcn() = 0;

    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;

protected:
    // Fills bcast_a/bcast_b with per-dimension replication factors.
    int broadcast();

protected:
    // Per-element kernel installed by the concrete op's register_fcn().
    std::function<OutEigenType(InEigenType, InEigenType)> fcn;
    Eigen::array<int, Rank> bcast_a;
    Eigen::array<int, Rank> bcast_b;
    // Exactly one of {a, a_rank0} and one of {b, b_rank0} is non-null after
    // checkTensorAttributes(); the rank-0 slots hold scalar inputs.
    TosaReference::TensorTemplate<TIn>* a;
    TosaReference::TensorTemplate<TIn>* b;
    TosaReference::TensorTemplate<ETensor0<InEigenType>>* a_rank0;
    TosaReference::TensorTemplate<ETensor0<InEigenType>>* b_rank0;
    TosaReference::TensorTemplate<TOut>* result;
    int a_rank;
    int b_rank;
    int max_input_rank;
};

// primary class
template <int Rank, DType InDtype, DType OutDtype>
class BinaryNode : public BinaryNodeBase<Rank, InDtype, OutDtype>
{
public:
    BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
        : BinaryNodeBase<Rank, InDtype, OutDtype>(op_, qinfo_, id_)
    {}
    virtual ~BinaryNode()
    {}

    virtual int eval();

    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;
};

// partial specialization for rank 0
template <DType InDtype, DType OutDtype>
class BinaryNode<0, InDtype, OutDtype> : public BinaryNodeBase<0, InDtype, OutDtype>
{
public:
    BinaryNode(const Op& op_, TosaQuantInfoBase* qinfo_, const uint64_t id_)
        : BinaryNodeBase<0, InDtype, OutDtype>(op_, qinfo_, id_)
    {}
    virtual ~BinaryNode()
    {}

    virtual int eval();
};

// Declares a concrete binary op whose input and output DTypes are the same.
#define DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Opname, OPNAME)                                     \
    template <int Rank, DType Dtype>                                                        \
    class Op##Opname : public BinaryNode<Rank, Dtype, Dtype>                                \
    {                                                                                       \
    public:                                                                                 \
        Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)  \
            : BinaryNode<Rank, Dtype, Dtype>(Op_##OPNAME, qinfo_, id_)                      \
        {                                                                                   \
            register_fcn();                                                                 \
        }                                                                                   \
        static constexpr DType InDtype  = Dtype;                                            \
        static constexpr DType OutDtype = Dtype;                                            \
        using InEigenType               = typename GetEigenType<InDtype>::type;             \
        using OutEigenType              = typename GetEigenType<OutDtype>::type;            \
        virtual int register_fcn();                                                         \
    };

// Declares a concrete binary op with distinct input/output DTypes (e.g. MUL
// widening INT8/INT16 to INT32); QMin/QMax give the output's saturation range.
#define DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Opname, OPNAME)                                     \
    template <int Rank, DType InDtype, DType OutDtype>                                      \
    class Op##Opname : public BinaryNode<Rank, InDtype, OutDtype>                           \
    {                                                                                       \
    public:                                                                                 \
        Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)  \
            : BinaryNode<Rank, InDtype, OutDtype>(Op_##OPNAME, qinfo_, id_)                 \
        {                                                                                   \
            register_fcn();                                                                 \
        }                                                                                   \
        static constexpr int32_t QMin = GetQMin<OutDtype>::value;                           \
        static constexpr int32_t QMax = GetQMax<OutDtype>::value;                           \
        using InEigenType             = typename GetEigenType<InDtype>::type;               \
        using OutEigenType            = typename GetEigenType<OutDtype>::type;              \
        virtual int register_fcn();                                                         \
    };

DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Add, ADD)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(ArithmeticRightShift, ARITHMETIC_RIGHT_SHIFT)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseAnd, BITWISE_AND)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseOr, BITWISE_OR)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(BitwiseXor, BITWISE_XOR)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalAnd, LOGICAL_AND)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalLeftShift, LOGICAL_LEFT_SHIFT)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalRightShift, LOGICAL_RIGHT_SHIFT)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalOr, LOGICAL_OR)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(LogicalXor, LOGICAL_XOR)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Maximum, MAXIMUM)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Minimum, MINIMUM)
DEF_TEMPLATE_BINARY_OP_TWO_TYPE(Mul, MUL)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Pow, POW)
DEF_TEMPLATE_BINARY_OP_ONE_TYPE(Sub, SUB)

#undef DEF_TEMPLATE_BINARY_OP_ONE_TYPE
#undef DEF_TEMPLATE_BINARY_OP_TWO_TYPE

// OpTable: 513-entry INT16 lookup with linear interpolation; input is treated
// as signed 9.7 fixed point, output is 16.7 fixed point in INT32.
template <int Rank>
class OpTable : public GraphNode
{
public:
    OpTable(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpTable();

    virtual int checkTensorAttributes();
    virtual int eval();

    static constexpr DType InDtype    = DType_INT16;
    static constexpr DType TableDtype = DType_INT16;
    static constexpr DType OutDtype   = DType_INT32;
    using InEigenType                 = typename GetEigenType<InDtype>::type;
    using TableEigenType              = typename GetEigenType<TableDtype>::type;
    using OutEigenType                = typename GetEigenType<OutDtype>::type;
    using TIn                         = Eigen::Tensor<InEigenType, Rank>;
    using TTable                      = Eigen::Tensor<TableEigenType, 1>;
    using TOut                        = Eigen::Tensor<OutEigenType, Rank>;
    static constexpr int32_t IntegerBits     = 9;
    static constexpr int32_t FractionBits    = 7;
    static constexpr int32_t NumTableEntries = (1 << IntegerBits);
    static constexpr int32_t QInMin          = GetQMin<InDtype>::value;
    static constexpr int32_t QInMax          = GetQMax<InDtype>::value;
    static constexpr int32_t QOutMin         = GetQMin<OutDtype>::value;
    static constexpr int32_t QOutMax         = GetQMax<OutDtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* in;
    TosaReference::TensorTemplate<TTable>* table;
    TosaReference::TensorTemplate<TOut>* out;
};

};    // namespace TosaReference

#endif
diff --git a/reference_model/src/ops/ewise_ternary.cc b/reference_model/src/ops/ewise_ternary.cc new file mode 100644 index 0000000..eded0d7 --- /dev/null +++ b/reference_model/src/ops/ewise_ternary.cc @@ -0,0 +1,115 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "ewise_ternary.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +OpSelectBase<Rank, Dtype>::OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_SELECT, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(0, 6); +} + +template <int Rank, DType Dtype> +OpSelectBase<Rank, Dtype>::~OpSelectBase() +{} + +template <int Rank, DType Dtype> +int OpSelectBase<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(inputs[2]) || + validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRank(*outputs[0]) || inputs[1]->matchRankType(*outputs[0]) || + inputs[2]->matchRankType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output rank and type"); + return 1; + } + + cond = dynamic_cast<TosaReference::TensorTemplate<TCond>*>(inputs[0]); + then_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + else_val = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[2]); + out = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(outputs[0]); + + return 0; +} + +template <int Rank, DType Dtype> +int OpSelectBase<Rank, Dtype>::eval() +{ + FATAL_ERROR_NODE("shouldn't be called"); +} + +template <int Rank, DType Dtype> +int OpSelect<Rank, Dtype>::broadcast() +{ + std::vector<int> cond_shape = this->cond->getShape(); + std::vector<int> then_shape = this->then_val->getShape(); + std::vector<int> else_shape = this->else_val->getShape(); + std::vector<int> out_shape = this->out->getShape(); + + for (int i = 0; i < Rank; i++) + { + this->bcast_cond[i] = (cond_shape[i] == 1) ? std::max(then_shape[i], else_shape[i]) : 1; + this->bcast_then[i] = (then_shape[i] == 1) ? 
std::max(cond_shape[i], else_shape[i]) : 1; + this->bcast_else[i] = (else_shape[i] == 1) ? std::max(then_shape[i], cond_shape[i]) : 1; + ASSERT_MSG_NODE((this->bcast_cond[i] * cond_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_then[i] * then_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + ASSERT_MSG_NODE((this->bcast_else[i] * else_shape[i]) == out_shape[i], "SELECT broadcast invariant failed"); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpSelect<Rank, Dtype>::eval() +{ + this->broadcast(); + this->out->getTensor() = this->cond->getTensor() + .broadcast(this->bcast_cond) + .select(this->then_val->getTensor().broadcast(this->bcast_then), + this->else_val->getTensor().broadcast(this->bcast_else)); + + return GraphNode::eval(); +} + +template <DType Dtype> +int OpSelect<0, Dtype>::eval() +{ + this->out->getTensor() = this->cond->getTensor().select(this->then_val->getTensor(), this->else_val->getTensor()); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); diff --git a/reference_model/src/ops/ewise_ternary.h b/reference_model/src/ops/ewise_ternary.h new file mode 100644 index 0000000..b354247 --- /dev/null +++ b/reference_model/src/ops/ewise_ternary.h @@ -0,0 +1,83 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef OPS_TERNARY_H
#define OPS_TERNARY_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

// The Ternary Select op has the following operands:
//   1. Cond:     rank N, type=bool
//   2. Then_val: Rank N, type=<V>
//   3. Else_val: Rank N, type=<V>
//   4. Result:   Rank N, type=<V>
// Cond, Then_val, Else_val need to be mutually-broadcastable
template <int Rank, DType Dtype>
class OpSelectBase : public GraphNode
{
public:
    OpSelectBase(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpSelectBase();

    virtual int checkTensorAttributes();
    virtual int eval();

    using CondEigenType = typename GetEigenType<DType_BOOL>::type;
    using InEigenType   = typename GetEigenType<Dtype>::type;
    using TCond         = Eigen::Tensor<CondEigenType, Rank>;
    using TIn           = Eigen::Tensor<InEigenType, Rank>;

protected:
    TosaReference::TensorTemplate<TCond>* cond;
    // Per-dimension broadcast factors filled by OpSelect::broadcast().
    Eigen::array<int, Rank> bcast_cond;
    Eigen::array<int, Rank> bcast_then;
    Eigen::array<int, Rank> bcast_else;
    TosaReference::TensorTemplate<TIn>* then_val;
    TosaReference::TensorTemplate<TIn>* else_val;
    TosaReference::TensorTemplate<TIn>* out;
};

// primary class
template <int Rank, DType Dtype>
class OpSelect : public OpSelectBase<Rank, Dtype>
{
public:
    OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
        : OpSelectBase<Rank, Dtype>(attribute_, qinfo_, id_)
    {}
    virtual int eval();
    int broadcast();

    using InEigenType = typename OpSelectBase<Rank, Dtype>::InEigenType;
};

// partial specialization for rank 0 (scalars need no broadcast)
template <DType Dtype>
class OpSelect<0, Dtype> : public OpSelectBase<0, Dtype>
{
public:
    OpSelect(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_)
        : OpSelectBase<0, Dtype>(attribute_, qinfo_, id_)
    {}
    virtual int eval();
};
};    // namespace TosaReference

#endif
diff --git a/reference_model/src/ops/ewise_unary.cc b/reference_model/src/ops/ewise_unary.cc new file mode 100644 index 0000000..d7bddc0 --- /dev/null +++ b/reference_model/src/ops/ewise_unary.cc @@ -0,0 +1,302 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "ewise_unary.h" +#include "quant_util.h" +#include "template_types.h" +#include <cmath> + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +UnaryNode<Rank, Dtype>::UnaryNode(const Op& op_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + fcn = [](InEigenType a) -> OutEigenType { return OutEigenType(); }; +} + +template <int Rank, DType Dtype> +UnaryNode<Rank, Dtype>::~UnaryNode() +{} + +template <int Rank, DType Dtype> +int UnaryNode<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchRankSize(*outputs[0])) + { + printNodeValidationError("UnaryNode: input and output rank must match"); + return 1; + } + + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + result = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(a && result); + + return 0; +} + +template <int Rank, DType Dtype> +int UnaryNode<Rank, Dtype>::eval() +{ + this->result->getTensor() = this->a->getTensor().unaryExpr(this->fcn); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpAbs<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return a > (InEigenType)0 ? 
a : (-a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpBitwiseNot<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_AINT8: + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { return ~a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpCeil<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return ceilf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpClz<Rank, Dtype>::register_fcn() +{ + int32_t num_bits; + switch (Dtype) + { + case DType_INT32: + num_bits = 32; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + this->fcn = [num_bits](int32_t a) -> int32_t { + int32_t leading_zeros = 0; + for (int bit = num_bits - 1; bit >= 0; bit--) + { + if (((a >> bit) & 0x1) == 0) + { + leading_zeros++; + } + else + { + break; + } + } + return leading_zeros; + }; + + return 0; +} + +template <int Rank, DType Dtype> +int OpExp<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return expf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpFloor<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return floorf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLog<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case 
DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return logf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpLogicalNot<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_BOOL: + this->fcn = [](InEigenType a) -> OutEigenType { return !a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpNegate<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_INT16: + case DType_INT32: + this->fcn = [](InEigenType a) -> OutEigenType { + InEigenType result = -(a); + return result; + }; + break; + case DType_AINT8: + ASSERT(this->qinfo); + this->fcn = [this](InEigenType a) -> OutEigenType { + InEigenType result = -(a - this->qinfo->input_zp()) + this->qinfo->output_zp(); + return result; + }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpReciprocal<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / a; }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +template <int Rank, DType Dtype> +int OpRsqrt<Rank, Dtype>::register_fcn() +{ + switch (Dtype) + { + case DType_FLOAT: + this->fcn = [](InEigenType a) -> OutEigenType { return 1.0 / sqrtf(a); }; + break; + default: + FATAL_ERROR_NODE("unsupported DType %s", EnumNamesDType()[Dtype]); + } + + return 0; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); diff --git a/reference_model/src/ops/ewise_unary.h b/reference_model/src/ops/ewise_unary.h new file mode 100644 index 0000000..0db3cfb --- /dev/null +++ b/reference_model/src/ops/ewise_unary.h @@ -0,0 +1,102 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OPS_EWISE_UNARY_H +#define OPS_EWISE_UNARY_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ +template <int Rank, DType Dtype> +class UnaryNode : public GraphNode +{ +public: + UnaryNode(const Op& nodeType, const uint64_t id_); + virtual ~UnaryNode(); + + virtual int checkTensorAttributes() final; + virtual int eval() final; + virtual int register_fcn() = 0; + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + std::function<OutEigenType(InEigenType)> fcn; + TosaReference::TensorTemplate<TIn>* a; + TosaReference::TensorTemplate<TOut>* result; +}; + +#define DEF_TEMPLATE_UNARY_OP(Opname, OPNAME) \ + template <int Rank, DType Dtype> \ + class Op##Opname : public UnaryNode<Rank, Dtype> \ + { \ + public: \ + Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \ + { \ + register_fcn(); \ + } \ + static constexpr int32_t QMin = GetQMin<Dtype>::value; \ + static constexpr int32_t QMax = GetQMax<Dtype>::value; \ + using InEigenType = typename GetEigenType<Dtype>::type; \ + using OutEigenType = typename GetEigenType<Dtype>::type; \ + virtual int register_fcn(); \ + }; + +#define DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Opname, OPNAME) \ + template <int Rank, DType Dtype> \ + class Op##Opname : public UnaryNode<Rank, Dtype> \ + { \ + public: \ + Op##Opname(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) \ + : UnaryNode<Rank, Dtype>(Op_##OPNAME, id_) \ + { \ + INIT_QINFO(Unary); \ + register_fcn(); \ + } \ + static constexpr int32_t QMin = GetQMin<Dtype>::value; \ + static constexpr int32_t QMax = GetQMax<Dtype>::value; \ + using InEigenType = typename GetEigenType<Dtype>::type; \ + using OutEigenType = typename GetEigenType<Dtype>::type; \ + virtual int 
register_fcn(); \ + \ + protected: \ + TosaUnaryQuantInfo* qinfo; \ + }; + +DEF_TEMPLATE_UNARY_OP(Abs, ABS) +DEF_TEMPLATE_UNARY_OP(BitwiseNot, BITWISE_NOT) +DEF_TEMPLATE_UNARY_OP(Ceil, CEIL) +DEF_TEMPLATE_UNARY_OP(Clz, CLZ) +DEF_TEMPLATE_UNARY_OP(Exp, EXP) +DEF_TEMPLATE_UNARY_OP(Floor, FLOOR) +DEF_TEMPLATE_UNARY_OP(Log, LOG) +DEF_TEMPLATE_UNARY_OP(LogicalNot, LOGICAL_NOT) +DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO(Negate, NEGATE) +DEF_TEMPLATE_UNARY_OP(Reciprocal, RECIPROCAL) +DEF_TEMPLATE_UNARY_OP(Rsqrt, RSQRT) + +#undef DEF_TEMPLATE_UNARY_OP +#undef DEF_TEMPLATE_UNARY_OP_WITH_QUANT_INFO + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/image.cc b/reference_model/src/ops/image.cc new file mode 100644 index 0000000..d3352ce --- /dev/null +++ b/reference_model/src/ops/image.cc @@ -0,0 +1,169 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "image.h" +#include "arith_util.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <DType InDtype, DType OutDtype> +OpResize<InDtype, OutDtype>::OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_RESIZE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4, 4); + + INIT_ATTRIBUTE(Resize); +} + +template <DType InDtype, DType OutDtype> +OpResize<InDtype, OutDtype>::~OpResize() +{ + if (attribute) + delete attribute; +} + +template <DType InDtype, DType OutDtype> +int OpResize<InDtype, OutDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + return 1; + + output_size = this->attribute->output_size(); + stride = this->attribute->stride(); + offset = this->attribute->offset(); + shift = this->attribute->shift(); + mode = this->attribute->mode(); + + int output_height = outputs[0]->getShape()[1]; + int output_width = outputs[0]->getShape()[2]; + + if (this->mode == ResizeMode_BILINEAR) + { + if (OutDtype != DType_INT32 && OutDtype != DType_INT48) + { + printNodeValidationError("OpResize: invalid data type for BILINEAR"); + return 1; + } + } + else + { + if (OutDtype != DType_INT8 && OutDtype != DType_INT16) + { + printNodeValidationError("OpResize: invalid data type for NEAREST"); + return 1; + } + } + + if (output_size[0] != output_height || output_size[1] != output_width) + { + printNodeValidationError("OpResize: attribute output_size doesn't match output [height, width]"); + return 1; + } + + if (shift < 1 || shift > 11) + { + printNodeValidationError("OpResize: attribute shift should be within [1, 11]"); + return 1; + } + + if (stride[0] <= 0 || stride[1] <= 0) + { + printNodeValidationError("OpResize: invalid attribute stride"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = 
dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && out); + + return 0; +} + +template <DType InDtype, DType OutDtype> +int OpResize<InDtype, OutDtype>::eval() +{ + int in_batch = in->getShape()[0]; + int in_height = in->getShape()[1]; + int in_width = in->getShape()[2]; + int in_channels = in->getShape()[3]; + + int out_batch = out->getShape()[0]; + int out_height = out->getShape()[1]; + int out_width = out->getShape()[2]; + int out_channels = out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpResize: output tensor batch mismatch"); + ASSERT_MSG_NODE(in_channels == out_channels, "OpResize: output tensor channel mismatch"); + + for (int b = 0; b < out_batch; b++) + for (int c = 0; c < out_channels; c++) + for (int oy = 0; oy < out_height; oy++) + for (int ox = 0; ox < out_width; ox++) + { + int y = oy * stride[0] + offset[0]; + int x = ox * stride[1] + offset[1]; + + int iy = y >> shift; + int dy = y - (iy << shift); + int ix = x >> shift; + int dx = x - (ix << shift); + + int iy0 = MAX(iy, 0); + int iy1 = MIN(iy + 1, in_height - 1); + int ix0 = MAX(ix, 0); + int ix1 = MIN(ix + 1, in_width - 1); + + ASSERT_MSG(iy0 <= iy1 && ix0 <= ix1, "OpResize: invalid index (iy0, iy1, ix0, ix1)=(%d,%d,%d,%d)", + iy0, iy1, ix0, ix1); + + InEigenType v00 = in->getTensor()(b, iy0, ix0, c); + InEigenType v01 = in->getTensor()(b, iy0, ix1, c); + InEigenType v10 = in->getTensor()(b, iy1, ix0, c); + InEigenType v11 = in->getTensor()(b, iy1, ix1, c); + + OutEigenType acc; + if (mode == ResizeMode_BILINEAR) + { + acc = (OutEigenType)v00 * ((1 << shift) - dy) * ((1 << shift) - dx); + acc = acc + (OutEigenType)v01 * ((1 << shift) - dy) * dx; + acc = acc + (OutEigenType)v10 * dy * ((1 << shift) - dx); + acc = acc + (OutEigenType)v11 * dy * dx; + } + else + { + iy = (dy >> (shift - 1)) != 0 ? iy1 : iy0; + ix = (dx >> (shift - 1)) != 0 ? 
ix1 : ix0; + acc = in->getTensor()(b, iy, ix, c); + } + + out->getTensor()(b, oy, ox, c) = acc; + } + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT32); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT48); +DEF_INSTANTIATE_TWO_TYPE(OpResize, INT16, INT16); diff --git a/reference_model/src/ops/image.h b/reference_model/src/ops/image.h new file mode 100644 index 0000000..9d15d49 --- /dev/null +++ b/reference_model/src/ops/image.h @@ -0,0 +1,53 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#ifndef OPS_IMAGE_H +#define OPS_IMAGE_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <DType InDtype, DType OutDtype> +class OpResize : public GraphNode +{ +public: + OpResize(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_); + virtual ~OpResize(); + virtual int checkTensorAttributes() final; + virtual int eval(); + + using InEigenType = typename GetEigenType<InDtype>::type; + using OutEigenType = typename GetEigenType<OutDtype>::type; + using TIn = Eigen::Tensor<InEigenType, 4>; + using TOut = Eigen::Tensor<OutEigenType, 4>; + +protected: + TosaResizeAttribute* attribute; + std::vector<int32_t> output_size; + std::vector<int32_t> stride; + std::vector<int32_t> offset; + int32_t shift; + ResizeMode mode; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/op_factory.cc b/reference_model/src/ops/op_factory.cc new file mode 100644 index 0000000..bad0c40 --- /dev/null +++ b/reference_model/src/ops/op_factory.cc @@ -0,0 +1,432 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "op_factory.h" +#include "activation_funcs.h" +#include "comparison.h" +#include "control_flow.h" +#include "custom.h" +#include "data_layout.h" +#include "data_nodes.h" +#include "ewise_binary.h" +#include "ewise_ternary.h" +#include "ewise_unary.h" +#include "image.h" +#include "reduction.h" +#include "scatter_gather.h" +#include "tensor_ops.h" +#include "type_conversion.h" + +using namespace TosaReference; +using namespace tosa; + +GraphNode* OpFactory::newOp(TosaSerializationHandler* tsh, + Op opType, + TosaAttributeBase* attribute, + TosaQuantInfoBase* qinfo, + uint64_t id, + DType inputDType, + int inputRank, + DType outputDType, + int outputRank, + DType weightDType, + int weightRank) +{ + switch (opType) + { + // tensor_ops + case Op_ARGMAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + break; + case Op_AVG_POOL2D: + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, FLOAT); + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, AINT8); + DEF_FACTORY_ONE_TYPE(OpAvgPool2d, INT16); + break; + case Op_CONV2D: + DEF_FACTORY_TWO_TYPE(OpConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpConv2d, INT16, INT8); + break; + case Op_DEPTHWISE_CONV2D: + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + break; + case Op_FULLY_CONNECTED: + DEF_FACTORY_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpFullyConnected, INT16, 
INT8); + break; + case Op_MATMUL: + DEF_FACTORY_ONE_TYPE(OpMatMul, FLOAT); + DEF_FACTORY_ONE_TYPE(OpMatMul, AINT8); + DEF_FACTORY_ONE_TYPE(OpMatMul, INT16); + break; + case Op_MAX_POOL2D: + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, FLOAT); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, AINT8); + DEF_FACTORY_ONE_TYPE(OpMaxPool2d, INT16); + break; + case Op_TRANSPOSE_CONV2D: + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT4); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, INT8); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8); + DEF_FACTORY_TWO_TYPE(OpTransposeConv2d, INT16, INT8); + break; + + // activation_funcs + case Op_CLAMP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClamp, INT16); + break; + case Op_RELUN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReluN, INT32); + break; + case Op_SIGMOID: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSigmoid, FLOAT); + break; + case Op_TANH: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTanh, FLOAT); + break; + + // ewise_binary + case Op_ADD: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAdd, INT32); + break; + case Op_ARITHMETIC_RIGHT_SHIFT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpArithmeticRightShift, INT32); + break; + case Op_BITWISE_AND: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseAnd, INT32); + break; + case Op_BITWISE_OR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT16); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseOr, INT32); + break; + case Op_BITWISE_XOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseXor, INT32); + break; + case Op_LOGICAL_AND: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalAnd, BOOL); + break; + case Op_LOGICAL_LEFT_SHIFT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalLeftShift, INT32); + break; + case Op_LOGICAL_RIGHT_SHIFT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalRightShift, INT32); + break; + case Op_LOGICAL_OR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalOr, BOOL); + break; + case Op_LOGICAL_XOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalXor, BOOL); // was OpLogicalOr (copy-paste from the OR case): LOGICAL_XOR nodes dispatched to the OR implementation. Assumes OpLogicalXor is declared in ewise_binary.h alongside OpLogicalOr — confirm. + break; + case Op_MAXIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMaximum, INT32); + break; + case Op_MINIMUM: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpMinimum, INT32); + break; + case Op_MUL: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, FLOAT, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpMul, INT32, INT32); + break; + case Op_POW: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpPow, FLOAT); + break; + case Op_SUB: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSub, INT32); + break; + case Op_TABLE: + DEF_FACTORY_ONE_RANK_0_6(OpTable); + break; + + // ewise_unary + case Op_ABS: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, FLOAT); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpAbs, INT32); + break; + case Op_BITWISE_NOT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpBitwiseNot, INT32); + break; + case Op_CEIL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpCeil, FLOAT); + break; + case Op_CLZ: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpClz, INT32); + break; + case Op_EXP: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpExp, FLOAT); + break; + case Op_FLOOR: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpFloor, FLOAT); + break; + case Op_LOG: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLog, FLOAT); + break; + case Op_LOGICAL_NOT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpLogicalNot, BOOL); + break; + case Op_NEGATE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpNegate, INT32); + break; + case Op_RECIPROCAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpReciprocal, FLOAT); + break; + case Op_RSQRT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpRsqrt, FLOAT); + break; + + // ewise_ternary + case Op_SELECT: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSelect, BOOL); + break; + + // comparison + case Op_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpEqual, INT32); + break; + case Op_GREATER: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreater, INT32); + break; + case Op_GREATER_EQUAL: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, FLOAT); + 
DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpGreaterEqual, INT32); + break; + + // reduction + case Op_REDUCE_ALL: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); + break; + case Op_REDUCE_ANY: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); + break; + case Op_REDUCE_MAX: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + break; + case Op_REDUCE_MIN: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + break; + case Op_REDUCE_PRODUCT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); + break; + case Op_REDUCE_SUM: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32); + break; + + // data layout + case Op_CONCAT: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpConcat, BOOL); + break; + case Op_PAD: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpPad, BOOL); + break; + case Op_RESHAPE: + DEF_FACTORY_RESHAPE(OpReshape, FLOAT); + DEF_FACTORY_RESHAPE(OpReshape, AINT8); + DEF_FACTORY_RESHAPE(OpReshape, INT8); + DEF_FACTORY_RESHAPE(OpReshape, INT16); 
+ DEF_FACTORY_RESHAPE(OpReshape, INT32); + DEF_FACTORY_RESHAPE(OpReshape, BOOL); + break; + case Op_REVERSE: + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, FLOAT); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, AINT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT8); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT16); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, INT32); + DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OpReverse, BOOL); + break; + case Op_SLICE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpSlice, BOOL); + break; + case Op_TILE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTile, BOOL); + break; + case Op_TRANSPOSE: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpTranspose, INT32); + break; + + // scatter_gather + case Op_GATHER: + { + // output.rank = input.rank - 1 + index.rank + int32_t index_rank = outputRank - inputRank + 1; + DEF_FACTORY_GATHER(OpGather, AINT8); + DEF_FACTORY_GATHER(OpGather, INT16); + DEF_FACTORY_GATHER(OpGather, INT32); + } + break; + + // image + case Op_RESIZE: + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT32); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT8, INT8); 
+ DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT48); + DEF_FACTORY_TWO_TYPE_RESIZE(OpResize, INT16, INT16); + break; + + // data_nodes + case Op_CONST: + return new OpConst(id); + case Op_PLACEHOLDER: + return new OpPlaceholder(id); + case Op_IDENTITY: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentity, BOOL); + break; + case Op_IDENTITYN: + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OpIdentityN, BOOL); + break; + + // type_conversion + case Op_CAST: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); + 
DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); + break; + case Op_RESCALE: + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8); + DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8); + break; + + // custom + case Op_CUSTOM: + return new OpCustom(id); + + // control_flow + case Op_COND_IF: + return new OpCondIf(tsh, attribute, id); + case Op_WHILE_LOOP: + return new OpWhileLoop(tsh, attribute, id); + + // Ops not recognized + default: + goto done; + + } // End of switch(opType) + +done: + return nullptr; +} diff --git a/reference_model/src/ops/op_factory.h b/reference_model/src/ops/op_factory.h new file mode 100644 index 0000000..cde6841 --- /dev/null +++ b/reference_model/src/ops/op_factory.h @@ -0,0 +1,294 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_OP_FACTORY_H +#define OPS_OP_FACTORY_H + +#include "attribute.h" +#include "graph_node.h" +#include "quant_info.h" +#include "template_types.h" +#include "tosa_serialization_handler.h" + +#define DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, RANK, DTYPE) \ + case RANK: \ + return new OP<RANK, DType_##DTYPE>(attribute, qinfo, id); + +#define DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, RANK, DTYPE1, DTYPE2) \ + case RANK: \ + return new OP<RANK, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); + +#define DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, RANK1, RANK2, DTYPE) \ + case RANK2: \ + return new OP<RANK1, RANK2, DType_##DTYPE>(attribute, qinfo, id); + +#define DEF_FACTORY_TWO_RANK_TWO_TYPE(OP, RANK1, RANK2, DTYPE1, DTYPE2) \ + case RANK2: \ + return new OP<RANK1, RANK2, DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); + +#define DEF_FACTORY_ONE_RANK_0_6(OP) \ + switch (inputRank) \ + { \ + case 0: \ + return new OP<0>(attribute, qinfo, id); \ + case 1: \ + return new OP<1>(attribute, qinfo, id); \ + case 2: \ + return new OP<2>(attribute, qinfo, id); \ + case 3: \ + return new OP<3>(attribute, qinfo, id); \ + case 4: \ + return new OP<4>(attribute, qinfo, id); \ + case 5: \ + return new OP<5>(attribute, qinfo, id); \ + case 6: \ + return new OP<6>(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + return new OP<DType_##DTYPE>(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && weightDType == DType_##DTYPE2) \ + { \ + 
return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_TWO_TYPE_RESIZE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + return new OP<DType_##DTYPE1, DType_##DTYPE2>(attribute, qinfo, id); \ + } + +#define DEF_FACTORY_RANK0_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 0, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \ + } \ + } + +#define DEF_FACTORY_RANK1_6_ONE_RANK_ONE_TYPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 1, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 2, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 3, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 4, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 5, DTYPE) \ + DEF_FACTORY_ONE_RANK_ONE_TYPE(OP, 6, DTYPE) \ + } \ + } + +#define DEF_FACTORY_RANK0_6_ONE_RANK_TWO_TYPE(OP, DTYPE1, DTYPE2) \ + if (inputDType == DType_##DTYPE1 && outputDType == DType_##DTYPE2) \ + { \ + switch (inputRank) \ + { \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 0, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 1, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 2, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 3, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 4, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 5, DTYPE1, DTYPE2) \ + DEF_FACTORY_ONE_RANK_TWO_TYPE(OP, 6, DTYPE1, DTYPE2) \ + } \ + } + +#define DEF_FACTORY_RESHAPE(OP, DTYPE) \ + if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + case 0: \ + { \ + switch (outputRank) \ + { \ + 
DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 0, 6, DTYPE) \ + } \ + } \ + case 1: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE) \ + } \ + } \ + case 2: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 6, DTYPE) \ + } \ + } \ + case 3: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 6, DTYPE) \ + } \ + } \ + case 4: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 5, DTYPE) \ + 
DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 6, DTYPE) \ + } \ + } \ + case 5: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 6, DTYPE) \ + } \ + } \ + case 6: \ + { \ + switch (outputRank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 0, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 2, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 3, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 4, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 5, DTYPE) \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 6, DTYPE) \ + } \ + } \ + } \ + } + +#define DEF_FACTORY_GATHER(OP, DTYPE) \ + if (inputDType == DType_##DTYPE && outputDType == DType_##DTYPE) \ + { \ + switch (inputRank) \ + { \ + case 1: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 4, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 5, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 1, 6, DTYPE); \ + } \ + case 2: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 4, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 2, 5, DTYPE); \ + } \ + case 3: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 3, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 3, 4, DTYPE); \ + } \ + case 4: \ + switch 
(index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 2, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 4, 3, DTYPE); \ + } \ + case 5: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 1, DTYPE); \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 5, 2, DTYPE); \ + } \ + case 6: \ + switch (index_rank) \ + { \ + DEF_FACTORY_TWO_RANK_ONE_TYPE(OP, 6, 1, DTYPE); \ + } \ + } \ + } + +namespace TosaReference +{ + +class OpFactory +{ +public: + static GraphNode* newOp(tosa::TosaSerializationHandler* tsh, + tosa::Op opType, + tosa::TosaAttributeBase* attribute, + tosa::TosaQuantInfoBase* qinfo, + uint64_t id, + tosa::DType inputDType, + int inputRank, + tosa::DType outputDType, + int outputRank, + tosa::DType weightDType, + int weightRank); +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/reduction.cc b/reference_model/src/ops/reduction.cc new file mode 100644 index 0000000..a2adfdb --- /dev/null +++ b/reference_model/src/ops/reduction.cc @@ -0,0 +1,139 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "reduction.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +ReduceNode<Rank, Dtype>::ReduceNode(const Op& op_, TosaAttributeBase* attribute_, uint64_t id_) + : GraphNode(op_, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 4); + + INIT_ATTRIBUTE(Axis); +} + +template <int Rank, DType Dtype> +ReduceNode<Rank, Dtype>::~ReduceNode() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int ReduceNode<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (attribute->axis() < 0 || attribute->axis() >= inputs[0]->getRank()) + { + printNodeValidationError("Reduce axis must between [0, input_rank - 1]"); + return 1; + } + + if (inputs[0]->matchRank(*outputs[0])) + { + printNodeValidationError("Input and output tensor ranks must match"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && out); + + dims[0] = this->attribute->axis(); + + return 0; +} + +template <int Rank, DType Dtype> +int OpReduceAll<Rank, Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().all(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpReduceAny<Rank, Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().any(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpReduceMax<Rank, Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().maximum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpReduceMin<Rank, 
Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().minimum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpReduceProduct<Rank, Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().prod(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +template <int Rank, DType Dtype> +int OpReduceSum<Rank, Dtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().sum(this->dims).reshape(this->out->getTensor().dimensions()); + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAll, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceAny, BOOL); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMax, INT32); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT16); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceMin, INT32); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceProduct, FLOAT); + +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpReduceSum, INT32); diff --git a/reference_model/src/ops/reduction.h b/reference_model/src/ops/reduction.h new file mode 100644 index 0000000..cf75812 --- /dev/null +++ b/reference_model/src/ops/reduction.h @@ -0,0 +1,109 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef OPS_REDUCTION_H +#define OPS_REDUCTION_H + +#include "graph_node.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <int Rank, DType Dtype> +class ReduceNode : public GraphNode +{ +public: + ReduceNode(const Op& nodeType, TosaAttributeBase* attribute_, const uint64_t id_); + virtual ~ReduceNode(); + virtual int checkTensorAttributes(); + virtual int eval() = 0; + + using InEigenType = typename GetEigenType<Dtype>::type; + using OutEigenType = typename GetEigenType<Dtype>::type; + using TIn = Eigen::Tensor<InEigenType, Rank>; + using TOut = Eigen::Tensor<OutEigenType, Rank>; + +protected: + Eigen::array<int, 1> dims; + TosaReference::TensorTemplate<TIn>* in; + TosaReference::TensorTemplate<TOut>* out; + TosaAxisAttribute* attribute; +}; + +template <int Rank, DType Dtype> +class OpReduceAll : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceAll(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_) + {} + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpReduceAny : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceAny(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_ALL, attribute_, id_) + {} + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpReduceMax : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_MAX, attribute_, 
id_) + {} + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpReduceMin : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceMin(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_MIN, attribute_, id_) + {} + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpReduceProduct : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceProduct(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_PRODUCT, attribute_, id_) + {} + virtual int eval(); +}; + +template <int Rank, DType Dtype> +class OpReduceSum : public ReduceNode<Rank, Dtype> +{ +public: + OpReduceSum(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : ReduceNode<Rank, Dtype>(Op_REDUCE_SUM, attribute_, id_) + {} + virtual int eval(); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/scatter_gather.cc b/reference_model/src/ops/scatter_gather.cc new file mode 100644 index 0000000..c54204a --- /dev/null +++ b/reference_model/src/ops/scatter_gather.cc @@ -0,0 +1,120 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "scatter_gather.h" +#include "quant_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int InRank, int IndexRank, DType Dtype> +OpGather<InRank, IndexRank, Dtype>::OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_GATHER, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(1, 6); + + INIT_ATTRIBUTE(Axis); +} + +template <int InRank, int IndexRank, DType Dtype> +OpGather<InRank, IndexRank, Dtype>::~OpGather() +{ + if (attribute) + delete attribute; +} + +template <int InRank, int IndexRank, DType Dtype> +int OpGather<InRank, IndexRank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same types + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("Failure to match input and output type"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + index = dynamic_cast<TosaReference::TensorTemplate<TIndex>*>(inputs[1]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && index && out); + + return 0; +} + +template <int InRank, int IndexRank, DType Dtype> +int OpGather<InRank, IndexRank, Dtype>::eval() +{ + int axis = attribute->axis(); + + // calculate size left and right to axis + int left_size = 1; + for (int i = 0; i < axis; ++i) + { + left_size *= in->getShape()[i]; + } + + int right_size = 1; + for (int i = axis + 1; i < in->getRank(); ++i) + { + right_size *= in->getShape()[i]; + } + + InEigenType* input_data = in->getTensor().data(); + int32_t* index_data = index->getTensor().data(); + OutEigenType* output_data = out->getTensor().data(); + + int32_t axis_size = in->getShape()[axis]; + int32_t index_count = index->getElementCount(); + + // sanity check if index is 
valid + // need to check until this point since index is not known until runtime + for (size_t i = 0; i < index->getElementCount(); i++) + { + if (index_data[i] >= axis_size) + { + FATAL_ERROR_NODE("OpGather: index[%lu]=%i can't exceed axis_size=%i", i, index_data[i], axis_size); + } + } + + // Eigen stores tensor in column-major + // so we iterate through dimension right to axis and the index array + // do memory copy with size of left size each time + for (int right = 0; right < right_size; ++right) + { + for (int i = 0; i < index_count; ++i) + { + std::memcpy(output_data + (right * index_count + i) * left_size, + input_data + (right * axis_size + index_data[i]) * left_size, sizeof(InEigenType) * left_size); + } + } + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_GATHER(OpGather, AINT8); +DEF_INSTANTIATE_GATHER(OpGather, INT16); +DEF_INSTANTIATE_GATHER(OpGather, INT32); diff --git a/reference_model/src/ops/scatter_gather.h b/reference_model/src/ops/scatter_gather.h new file mode 100644 index 0000000..d9b1263 --- /dev/null +++ b/reference_model/src/ops/scatter_gather.h @@ -0,0 +1,54 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#ifndef OPS_SCATTER_GATHER_H
#define OPS_SCATTER_GATHER_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{

// input and index can have different rank
// and infer OutRank statically
template <int InRank, int IndexRank, DType Dtype>
class OpGather : public GraphNode
{
public:
    OpGather(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpGather();

    virtual int checkTensorAttributes();
    virtual int eval();

    // the gathered axis is replaced by the index tensor's dimensions
    static constexpr int OutRank = InRank - 1 + IndexRank;
    using InEigenType  = typename GetEigenType<Dtype>::type;
    using OutEigenType = typename GetEigenType<Dtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, InRank>;
    using TIndex       = Eigen::Tensor<int32_t, IndexRank>;
    using TOut         = Eigen::Tensor<OutEigenType, OutRank>;

protected:
    TosaAxisAttribute* attribute;                  // axis to gather along
    TosaReference::TensorTemplate<TIn>* in;        // values tensor
    TosaReference::TensorTemplate<TIndex>* index;  // int32 indices into `axis`
    TosaReference::TensorTemplate<TOut>* out;      // gathered result
};

};    // namespace TosaReference

#endif
diff --git a/reference_model/src/ops/template_types.h b/reference_model/src/ops/template_types.h
new file mode 100644
index 0000000..1859e03
--- /dev/null
+++ b/reference_model/src/ops/template_types.h
@@ -0,0 +1,277 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#ifndef OP_TEMPLATE_TYPES_H +#define OP_TEMPLATE_TYPES_H + +#include "tosa_generated.h" +#include <Eigen/CXX11/Tensor> + +using namespace tosa; + +namespace TosaReference +{ +// Shorter aliase templates for common Eigen::Tensor types +template <typename T> +using ETensor0 = Eigen::Tensor<T, 0>; +template <typename T> +using ETensor1 = Eigen::Tensor<T, 1>; +template <typename T> +using ETensor2 = Eigen::Tensor<T, 2>; +template <typename T> +using ETensor3 = Eigen::Tensor<T, 3>; +template <typename T> +using ETensor4 = Eigen::Tensor<T, 4>; +template <typename T> +using ETensor5 = Eigen::Tensor<T, 5>; +template <typename T> +using ETensor6 = Eigen::Tensor<T, 6>; + +// Forward declaration +template <class T> +class TensorTemplate; + +// Shortcut to hide the TensorTemplate class. +// For example, declare Tensor1<float> to get a TensorTemplate +// with an Eigen::Tensor<float, 1> +template <typename T> +using Tensor0 = TensorTemplate<ETensor0<T>>; +template <typename T> +using Tensor1 = TensorTemplate<ETensor1<T>>; +template <typename T> +using Tensor2 = TensorTemplate<ETensor2<T>>; +template <typename T> +using Tensor3 = TensorTemplate<ETensor3<T>>; +template <typename T> +using Tensor4 = TensorTemplate<ETensor4<T>>; +template <typename T> +using Tensor5 = TensorTemplate<ETensor5<T>>; +template <typename T> +using Tensor6 = TensorTemplate<ETensor6<T>>; + +template <DType type> +struct GetEigenType; +template <> +struct GetEigenType<DType_FLOAT> +{ + using type = float; +}; +template <> +struct GetEigenType<DType_INT32> +{ + using type = int32_t; +}; +template <> +struct GetEigenType<DType_INT48> +{ + using type = int64_t; +}; +template <> +struct GetEigenType<DType_BOOL> +{ + using type = bool; +}; +template <> +struct GetEigenType<DType_AINT8> +{ + using type = int32_t; +}; +template <> +struct GetEigenType<DType_UINT8> +{ + using type = int32_t; +}; +template <> +struct GetEigenType<DType_INT4> +{ + using type = int32_t; +}; +template <> +struct 
GetEigenType<DType_INT8> +{ + using type = int32_t; +}; +template <> +struct GetEigenType<DType_INT16> +{ + using type = int32_t; +}; + +// Meta function to get number of bits +template <DType T> +struct GetNumBits +{ + static constexpr int32_t value = 0; +}; +template <> +struct GetNumBits<DType_BOOL> +{ + static constexpr int32_t value = 1; +}; +template <> +struct GetNumBits<DType_AINT8> +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits<DType_UINT8> +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits<DType_INT4> +{ + static constexpr int32_t value = 4; +}; +template <> +struct GetNumBits<DType_INT8> +{ + static constexpr int32_t value = 8; +}; +template <> +struct GetNumBits<DType_INT16> +{ + static constexpr int32_t value = 16; +}; +template <> +struct GetNumBits<DType_INT32> +{ + static constexpr int32_t value = 32; +}; +template <> +struct GetNumBits<DType_INT48> +{ + static constexpr int32_t value = 48; +}; + +// Meta function to get quantized min/max in compile time +template <DType T> +struct GetQMin +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMin<DType_AINT8> +{ + static constexpr int64_t value = -128L; +}; +template <> +struct GetQMin<DType_UINT8> +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMin<DType_INT4> +{ + static constexpr int64_t value = -8L; +}; +template <> +struct GetQMin<DType_INT8> +{ + static constexpr int64_t value = -128L; +}; +template <> +struct GetQMin<DType_INT16> +{ + static constexpr int64_t value = -32768L; +}; +template <> +struct GetQMin<DType_INT32> +{ + static constexpr int64_t value = -(1L << 31); +}; +template <> +struct GetQMin<DType_INT48> +{ + static constexpr int64_t value = -(1L << 47); +}; + +template <DType T> +struct GetQMax +{ + static constexpr int64_t value = 0L; +}; +template <> +struct GetQMax<DType_AINT8> +{ + static constexpr int64_t value = 127L; +}; +template <> +struct GetQMax<DType_UINT8> +{ + static 
constexpr int64_t value = 255L; +}; +template <> +struct GetQMax<DType_INT4> +{ + static constexpr int64_t value = 7L; +}; +template <> +struct GetQMax<DType_INT8> +{ + static constexpr int64_t value = 127L; +}; +template <> +struct GetQMax<DType_INT16> +{ + static constexpr int64_t value = 32767L; +}; +template <> +struct GetQMax<DType_INT32> +{ + static constexpr int64_t value = (1L << 31) - 1; +}; +template <> +struct GetQMax<DType_INT48> +{ + static constexpr int64_t value = (1L << 47) - 1; +}; + +template <DType TIn1, DType TIn2> +struct GetAccDType; +template <> +struct GetAccDType<DType_AINT8, DType_AINT8> +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType<DType_AINT8, DType_INT4> +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType<DType_AINT8, DType_INT8> +{ + static constexpr DType value = DType_INT32; +}; +template <> +struct GetAccDType<DType_INT16, DType_INT8> +{ + static constexpr DType value = DType_INT48; +}; +template <> +struct GetAccDType<DType_INT16, DType_INT16> +{ + static constexpr DType value = DType_INT48; +}; +template <> +struct GetAccDType<DType_FLOAT, DType_FLOAT> +{ + static constexpr DType value = DType_FLOAT; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/tensor_ops.cc b/reference_model/src/ops/tensor_ops.cc new file mode 100644 index 0000000..a735334 --- /dev/null +++ b/reference_model/src/ops/tensor_ops.cc @@ -0,0 +1,1229 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tensor_ops.h" +#include "quant_util.h" +#include "template_types.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType Dtype> +OpArgMax<Rank, Dtype>::OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_ARGMAX, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + + INIT_ATTRIBUTE(Axis); +} + +template <int Rank, DType Dtype> +OpArgMax<Rank, Dtype>::~OpArgMax() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType Dtype> +int OpArgMax<Rank, Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + output = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + return 0; +} + +template <int Rank, DType Dtype> +int OpArgMax<Rank, Dtype>::eval() +{ + Eigen::Tensor<DenseIndex, Rank - 1> index = this->input->getTensor().argmax(attribute->axis()); + + this->output->getTensor() = index.unaryExpr([](DenseIndex in) -> OutEigenType { return (OutEigenType)in; }); + + return GraphNode::eval(); +} + +template <DType Dtype> +OpAvgPool2d<Dtype>::OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_AVG_POOL2D, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Pool2d); + INIT_QINFO(Unary); +} + +template <DType Dtype> +OpAvgPool2d<Dtype>::~OpAvgPool2d() +{ + if (attribute) + delete attribute; +} + +template <DType Dtype> +int OpAvgPool2d<Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0])) 
+ { + printNodeValidationError("OpAvgPool2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (!in->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpAvgPool2d: unsupported tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->kernel().size() != 2) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute kernel"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpAvgPool2d: illegal size for attribute stride"); + return 1; + } + + return 0; +} + +template <DType Dtype> +ETensor1<int32_t> OpAvgPool2d<Dtype>::calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride) +{ + ETensor1<int32_t> result(out_size); + + int32_t total_pad = (out_size - 1) * stride + kernel_size - in_size; + total_pad = total_pad < 0 ? 
0 : total_pad; + + int32_t pad_left = total_pad >> 1; + int32_t pad_right = total_pad - pad_left; + + result.setConstant(kernel_size); + + // the index left to 'left_index' and index right to 'right_index' indicates + // the input window of this output covers a pad bit + int32_t left_index = pad_left / stride; + int32_t right_index = pad_right / stride; + + // not handle ultra small activation yet + ASSERT_MSG_NODE((out_size - 1 - right_index) >= left_index, "AvgPool2d: Small activations not supported yet"); + + // minus the number of pad bit this index cover + while (left_index >= 0) + { + result(left_index) -= (pad_left - left_index * stride); + left_index--; + } + + while (right_index >= 0) + { + result(out_size - 1 - right_index) -= (pad_right - right_index * stride); + right_index--; + } + + return result; +} + +// assuming input and output tensor have same scales like tflite reference +// so no need to scale input and output +template <DType Dtype> +int OpAvgPool2d<Dtype>::eval() +{ + int in_batch = this->in->getShape()[0]; + int in_height = this->in->getShape()[1]; + int in_width = this->in->getShape()[2]; + int in_channels = this->in->getShape()[3]; + + int out_batch = this->out->getShape()[0]; + int out_height = this->out->getShape()[1]; + int out_width = this->out->getShape()[2]; + int out_channels = this->out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpAvgPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int kernel_h = this->attribute->kernel()[0]; + int kernel_w = this->attribute->kernel()[1]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + + DEBUG_INFO(OP, + "perform AvgPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " + 
"stride=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h, + kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right); + + Eigen::array<Eigen::Index, 2> im2col_input_dims; + im2col_input_dims[0] = kernel_h * kernel_w; + im2col_input_dims[1] = out_batch * out_height * out_width * out_channels; + + Eigen::array<Eigen::Index, 4> col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array<std::pair<int32_t, int32_t>, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + ETensor4<InEigenType> input_val = this->in->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + } + + ETensor4<InEigenType> input_padded = input_val.pad(padding); + + // assuming input and output have same scales + // so input and output scaling is not required + // TODO: check if this assumption TOSA made + + // extract_image_patches() output [N, KH, KW, H * W, C] + // transpose to [KH, KW, N, H * W, C] + // reshape to [KH * KW, N * H * W * C] + ETensor2<InEigenType> input_extract_patches = + input_padded.extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID) + .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 }) + .reshape(im2col_input_dims); + + // 1D result with [N * H * W * C] + ETensor1<AccEigenType> out_1d(this->out->getElementCount()); + out_1d.setZero(); + + // sum pool + for (size_t i = 0; i < this->out->getElementCount(); i++) + { + for (int32_t j = 0; j < kernel_h * kernel_w; j++) + { + out_1d(i) += (AccEigenType)input_extract_patches(j, i); + } + } + + // reshape result to [N, H, W, C] and divide with div_map 
+ ETensor4<AccEigenType> sum = out_1d.reshape(col2im_output_dims); + + // calculate 1d height/width div_map (number of elements this pooling window covers) + // and outer product to get 2d div_map, then reshape/broadcast to [N, H, W, C] + ETensor1<int32_t> div_map_h = calculate_div_map_1d(in_height, out_height, kernel_h, stride_h); + ETensor1<int32_t> div_map_w = calculate_div_map_1d(in_width, out_width, kernel_w, stride_w); + Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) }; + Eigen::array<Eigen::Index, 4> bcast{ out_batch, 1, 1, out_channels }; + + ETensor4<int32_t> div_map = + div_map_h.reshape(Eigen::array<Eigen::Index, 2>{ out_height, 1 }) + .contract(div_map_w.reshape(Eigen::array<Eigen::Index, 2>{ 1, out_width }), contract_dims) + .reshape(Eigen::array<Eigen::Index, 4>{ 1, out_height, out_width, 1 }) + .broadcast(bcast); + + if (Dtype != DType_FLOAT) + { + this->out->getTensor() = sum.binaryExpr(div_map, [](AccEigenType value, int32_t div) -> OutEigenType { + int32_t multiplier, shift; + TosaReference::QuantUtil<AccDtype>::reciprocal_scale(div, multiplier, shift); + + return (OutEigenType)TosaReference::QuantUtil<AccDtype>::apply_scale(value, multiplier, shift, false); + }); + this->out->getTensor() = this->out->getTensor() + (OutEigenType)(this->qinfo->output_zp()); + this->out->getTensor() = this->out->getTensor().cwiseMax((OutEigenType)QMin); + this->out->getTensor() = this->out->getTensor().cwiseMin((OutEigenType)QMax); + } + else + { + this->out->getTensor() = (sum / div_map.template cast<AccEigenType>()).template cast<OutEigenType>(); + } + + return GraphNode::eval(); +} + +template <DType InDtype, DType WeightDtype> +OpConv2d<InDtype, WeightDtype>::OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Conv2d); + INIT_QINFO(Conv); +} + +template <DType InDtype, DType 
WeightDtype> +OpConv2d<InDtype, WeightDtype>::~OpConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template <DType InDtype, DType WeightDtype> +int OpConv2d<InDtype, WeightDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4 + if (inputs[2]->getRank() != 1) + { + printNodeValidationError("OpConv2d: bias tensor must be rank 1"); + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); + bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_OHWI)) + { + printNodeValidationError("OpConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpConv2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpConv2d: illegal size for attribute dilation"); + return 1; + } + + return 0; +} + +template <DType InDtype, DType WeightDtype> +int OpConv2d<InDtype, WeightDtype>::eval() +{ + int in_batch = this->input->getShape()[0]; + int in_height = this->input->getShape()[1]; + int in_width = this->input->getShape()[2]; + int in_channels = 
this->input->getShape()[3]; + + int f_out_channels = this->weight->getShape()[0]; + int f_height = this->weight->getShape()[1]; + int f_width = this->weight->getShape()[2]; + int f_in_channels = this->weight->getShape()[3]; + + int b_out_channels = this->bias->getShape()[0]; + + int out_batch = this->output->getShape()[0]; + int out_height = this->output->getShape()[1]; + int out_width = this->output->getShape()[2]; + int out_channels = this->output->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ASSERT_MSG_NODE(f_in_channels == in_channels, "OpConv2d: tensor input channel mismatch %d != %d", f_in_channels, + in_channels); + ASSERT_MSG_NODE(f_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", f_out_channels, + out_channels); + ASSERT_MSG_NODE(b_out_channels == out_channels, "OpConv2d: tensor output channel mismatch %d != %d", b_out_channels, + out_channels); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + int dilation_h = this->attribute->dilation()[0]; + int dilation_w = this->attribute->dilation()[1]; + + DEBUG_INFO(OP, + "perform OpConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], " + "stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_out_channels, out_batch, + out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top, + padding_bottom, padding_left, padding_right); + + // GEMM-conv2d, left matrix is input, right matrix is weight + Eigen::array<Eigen::Index, 2> im2col_input_dims; + im2col_input_dims[0] = out_batch * out_height * 
out_width; + im2col_input_dims[1] = f_height * f_width * f_in_channels; + + Eigen::array<Eigen::Index, 2> im2col_weight_dims; + im2col_weight_dims[0] = f_height * f_width * f_in_channels; + im2col_weight_dims[1] = f_out_channels; + + Eigen::array<Eigen::Index, 2> bias_reshaped_dims; + bias_reshaped_dims[0] = 1; + bias_reshaped_dims[1] = b_out_channels; + + Eigen::array<Eigen::Index, 4> weight_zp_bcast_dims; + weight_zp_bcast_dims[0] = f_height; + weight_zp_bcast_dims[1] = f_width; + weight_zp_bcast_dims[2] = f_in_channels; + + Eigen::array<Eigen::Index, 2> bias_bcast_dims; + bias_bcast_dims[0] = out_batch * out_height * out_width; + bias_bcast_dims[1] = 1; + + Eigen::array<Eigen::Index, 4> col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array<Eigen::IndexPair<Eigen::Index>, 1> contract_dims = { Eigen::IndexPair<Eigen::Index>(1, 0) }; + + Eigen::array<std::pair<int32_t, int32_t>, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + ETensor4<InEigenType> input_padded = input_val.pad(padding); + + // extract_image_patches() output [N, KH, KW, H * W, C] + // need to transpose to [N, H * W, KH, KW, C] + ETensor5<InEigenType> input_extract_patches = + input_padded + .extract_image_patches(f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID) + .shuffle(Eigen::array<Eigen::Index, 5>{ 0, 3, 1, 2, 4 }); + + // reshape input to [N * H * W, KH * KW * C] + ETensor2<InEigenType> im2col_input = 
input_extract_patches.reshape(im2col_input_dims); + + // transpose and reshape weight from [OC, H, W, IC] to [H * W * IC, OC] + ETensor2<WeightEigenType> im2col_weight = + weight_val.shuffle(Eigen::array<Eigen::Index, 4>({ 1, 2, 3, 0 })).reshape(im2col_weight_dims); + + // don't need to apply bias_multiplier ( * bias_scale and >> bias_shift) since tflite already scale it + // and reshaped from [C] to [1, C], and broadcast to [N * H * W, C] + ETensor2<AccEigenType> bias_2d = this->bias->getTensor().reshape(bias_reshaped_dims).broadcast(bias_bcast_dims); + + // output matrix is [N * H * W, C] + ETensor2<AccEigenType> contracted_result = + im2col_input.template cast<AccEigenType>().contract(im2col_weight.template cast<AccEigenType>(), contract_dims); + + // adding bias + ETensor2<AccEigenType> biased_output = contracted_result + bias_2d.template cast<AccEigenType>(); + + // reshape back to [N, H, W, C] + this->output->getTensor() = biased_output.reshape(col2im_output_dims); + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template <DType InDtype, DType WeightDtype> +OpDepthwiseConv2d<InDtype, WeightDtype>::OpDepthwiseConv2d(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_DEPTHWISE_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Conv2d); + INIT_QINFO(Conv); +} + +template <DType InDtype, DType WeightDtype> +OpDepthwiseConv2d<InDtype, WeightDtype>::~OpDepthwiseConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template <DType InDtype, DType WeightDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || 
validateRequiredRank(outputs[0])) + { + return 1; + } + + // 'bias' checked separatedly since it doens't make sense to make required rank ranging from 1 to 4 + if (inputs[2]->getRank() != 1) + { + printNodeValidationError("OpDepthwiseConv2d: bias tensor must be rank 1"); + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpDepthwiseConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); + bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpDepthwiseConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_HWIM)) + { + printNodeValidationError("OpDepthwiseConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpDepthwiseConv2d: illegal size for attribute dilation"); + return 1; + } + + return 0; +} + +template <DType InDtype, DType WeightDtype> +int OpDepthwiseConv2d<InDtype, WeightDtype>::eval() +{ + int in_batch = this->input->getShape()[0]; + int in_height = this->input->getShape()[1]; + int in_width = this->input->getShape()[2]; + int in_channels = this->input->getShape()[3]; + + int f_height = this->weight->getShape()[0]; + int f_width = this->weight->getShape()[1]; + int f_in_channels = this->weight->getShape()[2]; + int f_multiplier = this->weight->getShape()[3]; + + int b_out_channels = this->bias->getShape()[0]; + + int out_batch 
= this->output->getShape()[0]; + int out_height = this->output->getShape()[1]; + int out_width = this->output->getShape()[2]; + int out_channels = this->output->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpDepthwiseConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ASSERT_MSG_NODE(f_in_channels == in_channels, "OpDepthwiseConv2d: tensor input channel mismatch %d != %d", + f_in_channels, in_channels); + ASSERT_MSG_NODE(in_channels * f_multiplier == out_channels, + "OpDepthwiseConv2d: tensor output channel mismatch %d != %d", in_channels * f_multiplier, + out_channels); + ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", + b_out_channels, out_channels); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + int dilation_h = this->attribute->dilation()[0]; + int dilation_w = this->attribute->dilation()[1]; + + DEBUG_INFO(OP, + "perform OpDepthwiseConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, f_height, f_width, f_in_channels, f_multiplier, out_batch, + out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top, + padding_bottom, padding_left, padding_right); + + Eigen::array<std::pair<int32_t, int32_t>, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor(); + if (this->qinfo) + { + input_val = input_val - 
(InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + ETensor4<InEigenType> input_padded = input_val.pad(padding); + + // GEMM doesn't fit well with DepthwiseConv2d + // 1. use extract_image_patches() to handle stride/dilation/padding + // 2. perform direct convolution + + // 1. extract_image_patches() output [N, KH, KW, OH * OW, IC] + ETensor5<InEigenType> input_extract_patches = input_padded.extract_image_patches( + f_height, f_width, stride_h, stride_w, dilation_h, dilation_w, Eigen::PADDING_VALID); + + Eigen::array<Eigen::Index, 4> reshape_dim; + reshape_dim.fill(1); + reshape_dim[3] = b_out_channels; + + Eigen::array<Eigen::Index, 4> bcast; + bcast[0] = out_batch; + bcast[1] = out_height; + bcast[2] = out_width; + bcast[3] = 1; + + // initialize with bias + this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast); + + // 2. direct depthwise convolution + for (int ob = 0; ob < out_batch; ob++) + { + for (int oh = 0; oh < out_height; oh++) + { + for (int ow = 0; ow < out_width; ow++) + { + for (int ic = 0; ic < in_channels; ic++) + { + for (int cm = 0; cm < f_multiplier; cm++) + { + for (int fh = 0; fh < f_height; fh++) + { + for (int fw = 0; fw < f_width; fw++) + { + this->output->getTensor()(ob, oh, ow, ic * f_multiplier + cm) += + ((AccEigenType)input_extract_patches(ob, fh, fw, ow * out_height + oh, ic) * + (AccEigenType)weight_val(fh, fw, ic, cm)); + } + } + } + } + } + } + } + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template <DType InDtype, DType WeightDtype> +OpFullyConnected<InDtype, WeightDtype>::OpFullyConnected(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_FULLY_CONNECTED, id_) +{ + 
setRequiredOperands(3, 1); + setRequiredRank(2); + + INIT_QINFO(Conv); +} + +template <DType InDtype, DType WeightDtype> +OpFullyConnected<InDtype, WeightDtype>::~OpFullyConnected() +{ + if (qinfo) + delete qinfo; +} + +template <DType InDtype, DType WeightDtype> +int OpFullyConnected<InDtype, WeightDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); + bias = dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); + + if (input->getShape()[1] != weight->getShape()[1]) + { + printNodeValidationError("OpFullyConnected operator input.shape[1] should match weight.shape[1]"); + return 1; + } + + if (weight->getShape()[0] != bias->getShape()[0]) + { + printNodeValidationError("OpFullyConnected operator bias.shape[0] should match weight.shape[0]"); + return 1; + } + + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + return 0; +} + +template <DType InDtype, DType WeightDtype> +int OpFullyConnected<InDtype, WeightDtype>::eval() +{ + typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; + Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } }; + + Eigen::array<Eigen::Index, 2> weight_shuffle{ 1, 0 }; + + Eigen::array<Eigen::Index, 2> bias_reshape; + bias_reshape[0] = 1; + bias_reshape[1] = this->bias->getShape()[0]; + + Eigen::array<Eigen::Index, 2> bias_bcast; + bias_bcast[0] = this->input->getShape()[0]; + bias_bcast[1] = 1; + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor().shuffle(weight_shuffle); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + this->output->getTensor() = + 
input_val.template cast<AccEigenType>().contract(weight_val.template cast<AccEigenType>(), dims) + + this->bias->getTensor().reshape(bias_reshape).broadcast(bias_bcast); + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + return GraphNode::eval(); +} + +template <DType Dtype> +OpMatMul<Dtype>::OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_MATMUL, id_) +{ + setRequiredOperands(2, 1); + setRequiredRank(2); + + INIT_QINFO(MatMul); +} + +template <DType Dtype> +OpMatMul<Dtype>::~OpMatMul() +{ + if (qinfo) + delete qinfo; +} + +template <DType Dtype> +int OpMatMul<Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + a = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + b = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[1]); + + if (a->getShape()[1] != b->getShape()[0]) + { + printNodeValidationError("OpMatMul operator a.shape[1] should match b.shape[0]"); + return 1; + } + + c = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + return 0; +} + +template <DType Dtype> +int OpMatMul<Dtype>::eval() +{ + typedef Eigen::Tensor<int, 1>::DimensionPair DimPair; + Eigen::array<DimPair, 1> dims{ { DimPair(1, 0) } }; + + TIn a_val = this->a->getTensor(); + TIn b_val = this->b->getTensor(); + if (this->qinfo) + { + a_val = a_val - (InEigenType)this->qinfo->a_zp(); + b_val = b_val - (InEigenType)this->qinfo->b_zp(); + } + + this->c->getTensor() = a_val.template cast<AccEigenType>().contract(b_val.template cast<AccEigenType>(), dims); + + if (AccDtype == DType_INT48) + { + this->c->getTensor() = this->c->getTensor().cwiseMax((AccEigenType)AccQMin); + 
this->c->getTensor() = this->c->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +template <DType Dtype> +OpMaxPool2d<Dtype>::OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_MAX_POOL2D, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(Pool2d); +} + +template <DType Dtype> +OpMaxPool2d<Dtype>::~OpMaxPool2d() +{ + if (attribute) + delete attribute; +} + +template <DType Dtype> +int OpMaxPool2d<Dtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[0]->matchType(*outputs[0])) + { + printNodeValidationError("OpMaxPool2d: input and output tensor type mismatch"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + if (!in->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpMaxPool2d: unsupported tensor format"); + return 1; + } + + if (attribute->padding().size() != 4) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute padding"); + return 1; + } + + if (attribute->kernel().size() != 2) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute kernel"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpMaxPool2d: illegal size for attribute stride"); + return 1; + } + + return 0; +} + +template <DType Dtype> +int OpMaxPool2d<Dtype>::eval() +{ + int in_batch = this->in->getShape()[0]; + int in_height = this->in->getShape()[1]; + int in_width = this->in->getShape()[2]; + int in_channels = this->in->getShape()[3]; + + int out_batch = this->out->getShape()[0]; + int out_height = this->out->getShape()[1]; + int out_width = this->out->getShape()[2]; + int out_channels = this->out->getShape()[3]; + + ASSERT_MSG_NODE(in_batch == out_batch, 
"OpMaxPool2d: tensor batch mismatch %d != %d", in_batch, out_batch); + + int padding_top = this->attribute->padding()[0]; + int padding_bottom = this->attribute->padding()[1]; + int padding_left = this->attribute->padding()[2]; + int padding_right = this->attribute->padding()[3]; + int kernel_h = this->attribute->kernel()[0]; + int kernel_w = this->attribute->kernel()[1]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + + DEBUG_INFO(OP, + "perform MaxPool2d, input.shape=[%d,%d,%d,%d], output.shape=[%d,%d,%d,%d], kernel=[%d,%d], " + "stride=[%d,%d], padding=[%d,%d,%d,%d]", + in_batch, in_height, in_width, in_channels, out_batch, out_height, out_width, out_channels, kernel_h, + kernel_w, stride_h, stride_w, padding_top, padding_bottom, padding_left, padding_right); + + Eigen::array<Eigen::Index, 2> im2col_input_dims; + im2col_input_dims[0] = kernel_h * kernel_w; + im2col_input_dims[1] = out_batch * out_height * out_width * out_channels; + + Eigen::array<Eigen::Index, 4> col2im_output_dims; + col2im_output_dims[0] = out_batch; + col2im_output_dims[1] = out_height; + col2im_output_dims[2] = out_width; + col2im_output_dims[3] = out_channels; + + Eigen::array<std::pair<int32_t, int32_t>, 4> padding; + padding[0] = std::make_pair(0, 0); + padding[1] = std::make_pair(padding_top, padding_bottom); + padding[2] = std::make_pair(padding_left, padding_right); + padding[3] = std::make_pair(0, 0); + + ETensor4<InEigenType> input_padded = this->in->getTensor().pad(padding, std::numeric_limits<InEigenType>::lowest()); + + // extract_image_patches() output [N, KH, KW, H * W, C] + // transpose to [KH, KW, N, H * W, C] + // reshape to [KH * KW, N * H * W * C] + // + // Set the padding value to be the most negative value that can be + // represented by the datatype to ensure that any padding values will be equal + // to or smaller than the actual maximum in the KH x KW patch. 
+ ETensor2<InEigenType> input_extract_patches = + input_padded + .extract_image_patches(kernel_h, kernel_w, stride_h, stride_w, 1, 1, Eigen::PADDING_VALID, + std::numeric_limits<InEigenType>::lowest()) + .shuffle(Eigen::array<Eigen::Index, 5>{ 1, 2, 0, 3, 4 }) + .reshape(im2col_input_dims); + + // Get the maximum of the KHxHW patches along axis 0 + Eigen::Tensor<DenseIndex, 1> tensor_argmax = input_extract_patches.argmax(0); + + // 1D result with [N * H * W * C] + ETensor1<OutEigenType> out_1d(this->out->getElementCount()); + + // index input_patches with argmax array should give the result + for (size_t i = 0; i < this->out->getElementCount(); i++) + { + out_1d(i) = (OutEigenType)input_extract_patches(tensor_argmax(i), i); + } + + // reshape result to [N, H, W, C] + this->out->getTensor() = out_1d.reshape(col2im_output_dims); + + return GraphNode::eval(); +} + +template <DType InDtype, DType OutDtype> +OpTransposeConv2d<InDtype, OutDtype>::OpTransposeConv2d(TosaAttributeBase* attribute_, + TosaQuantInfoBase* qinfo_, + uint64_t id_) + : GraphNode(Op_TRANSPOSE_CONV2D, id_) +{ + setRequiredOperands(3, 1); + setRequiredRank(4); + + INIT_ATTRIBUTE(TransposeConv2d); + INIT_QINFO(Conv); +} + +template <DType InDtype, DType OutDtype> +OpTransposeConv2d<InDtype, OutDtype>::~OpTransposeConv2d() +{ + if (attribute) + delete attribute; + if (qinfo) + delete qinfo; +} + +template <DType InDtype, DType OutDtype> +int OpTransposeConv2d<InDtype, OutDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(inputs[1]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + if (inputs[1]->getIsConst() == 0) + { + printNodeValidationError("OpTransposeConv2d: weight tensor is not const typed"); + } + + input = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + weight = dynamic_cast<TosaReference::TensorTemplate<TWeight>*>(inputs[1]); + bias = 
dynamic_cast<TosaReference::TensorTemplate<TBias>*>(inputs[2]); + output = dynamic_cast<TosaReference::TensorTemplate<TAcc>*>(outputs[0]); + + if (!input->hasFormat(Format_NHWC)) + { + printNodeValidationError("OpTransposeConv2d: unsupported input tensor format"); + return 1; + } + + if (!weight->hasFormat(Format_OHWI)) + { + printNodeValidationError("OpTransposeConv2d: unsupported weight tensor format"); + return 1; + } + + if (attribute->outpad().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute outpad"); + return 1; + } + + if (attribute->stride().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute stride"); + return 1; + } + + if (attribute->dilation().size() != 2) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute dilation"); + return 1; + } + + if (attribute->output_shape().size() != 4) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape"); + return 1; + } + + for (int d = 0; d < 4; d++) + { + if (attribute->output_shape()[d] != this->output->getShape()[d]) + { + printNodeValidationError("OpTransposeConv2d: illegal size for attribute output_shape"); + return 1; + } + } + + return 0; +} + +template <DType InDtype, DType OutDtype> +int OpTransposeConv2d<InDtype, OutDtype>::eval() +{ + int in_batch = this->input->getShape()[0]; + int in_height = this->input->getShape()[1]; + int in_width = this->input->getShape()[2]; + int in_channels = this->input->getShape()[3]; + + int f_out_channels = this->weight->getShape()[0]; + int f_height = this->weight->getShape()[1]; + int f_width = this->weight->getShape()[2]; + int f_in_channels = this->weight->getShape()[3]; + + int b_out_channels = this->bias->getShape()[0]; + + int out_batch = this->output->getShape()[0]; + int out_height = this->output->getShape()[1]; + int out_width = this->output->getShape()[2]; + int out_channels = this->output->getShape()[3]; + + int padding_top = 
this->attribute->outpad()[0]; + int padding_left = this->attribute->outpad()[1]; + int stride_h = this->attribute->stride()[0]; + int stride_w = this->attribute->stride()[1]; + int dilation_h = this->attribute->dilation()[0]; + int dilation_w = this->attribute->dilation()[1]; + + ASSERT_MSG_NODE(in_batch == out_batch, "OpTransposeConv2d: tensor batch mismatch %d != %d", in_batch, out_batch); + ASSERT_MSG_NODE(f_in_channels == in_channels, "OpTransposeConv2d: tensor input channel mismatch %d != %d", + f_in_channels, in_channels); + ASSERT_MSG_NODE(f_out_channels == out_channels, "OpTransposeConv2d: tensor output channel mismatch %d != %d", + f_out_channels, out_channels); + ASSERT_MSG_NODE(b_out_channels == out_channels, "OpDepthwiseConv2d: tensor b_out_channels mismatch %d != %d", + b_out_channels, out_channels); + + DEBUG_INFO(OP, + "perform OpTransposeConv2d, input.shape=[%d,%d,%d,%d], weight.shape=[%d,%d,%d,%d], " + "output.shape=[%d,%d,%d,%d], stride=[%d,%d], dilation=[%d,%d], padding=[%d,%d]", + in_batch, in_height, in_width, in_channels, f_height, f_width, f_out_channels, f_in_channels, out_batch, + out_height, out_width, out_channels, stride_h, stride_w, dilation_h, dilation_w, padding_top, + padding_left); + + TIn input_val = this->input->getTensor(); + TWeight weight_val = this->weight->getTensor(); + if (this->qinfo) + { + input_val = input_val - (InEigenType)this->qinfo->input_zp(); + weight_val = weight_val - (WeightEigenType)this->qinfo->weight_zp(); + } + + Eigen::array<Eigen::Index, 4> reshape_dim; + reshape_dim.fill(1); + reshape_dim[3] = b_out_channels; + + Eigen::array<Eigen::Index, 4> bcast; + bcast[0] = out_batch; + bcast[1] = out_height; + bcast[2] = out_width; + bcast[3] = 1; + + // initialize with bias + this->output->getTensor() = this->bias->getTensor().reshape(reshape_dim).broadcast(bcast); + + int out_x_origin, out_y_origin; + int out_x, out_y; + + // reference implementation from: 
tensorflow/tensorflow/lite/kernels/internal/reference/reference_ops.h + for (int ob = 0; ob < out_batch; ob++) + { + for (int ih = 0; ih < in_height; ih++) + { + for (int iw = 0; iw < in_width; iw++) + { + out_x_origin = iw * stride_w - padding_left; + out_y_origin = ih * stride_h - padding_top; + for (int ic = 0; ic < in_channels; ic++) + { + for (int fh = 0; fh < f_height; fh++) + { + for (int fw = 0; fw < f_width; fw++) + { + out_x = out_x_origin + fw * dilation_w; + out_y = out_y_origin + fh * dilation_h; + for (int oc = 0; oc < out_channels; oc++) + { + if ((out_x >= 0 && out_x < out_width) && (out_y >= 0 && out_y < out_height)) + { + this->output->getTensor()(ob, out_y, out_x, oc) += + ((AccEigenType)input_val(ob, ih, iw, ic) * + (AccEigenType)weight_val(oc, fh, fw, ic)); + } + } + } + } + } + } + } + } + + if (AccDtype == DType_INT48) + { + this->output->getTensor() = this->output->getTensor().cwiseMax((AccEigenType)AccQMin); + this->output->getTensor() = this->output->getTensor().cwiseMin((AccEigenType)AccQMax); + } + + return GraphNode::eval(); +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, FLOAT); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, AINT8); +DEF_INSTANTIATE_RANK1_6_ONE_RANK_ONE_TYPE(OpArgMax, INT16); + +DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, FLOAT) +DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, AINT8) +DEF_INSTANTIATE_ONE_TYPE(OpAvgPool2d, INT16) + +DEF_INSTANTIATE_TWO_TYPE(OpConv2d, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT4); +DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpConv2d, AINT8, AINT8); +DEF_INSTANTIATE_TWO_TYPE(OpConv2d, INT16, INT8); + +DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT4); +DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, AINT8, AINT8); +DEF_INSTANTIATE_TWO_TYPE(OpDepthwiseConv2d, INT16, INT8); + 
+DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT4); +DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, AINT8, AINT8); +DEF_INSTANTIATE_TWO_TYPE(OpFullyConnected, INT16, INT8); + +DEF_INSTANTIATE_ONE_TYPE(OpMatMul, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpMatMul, INT16); +DEF_INSTANTIATE_ONE_TYPE(OpMatMul, FLOAT); + +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, FLOAT); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, AINT8); +DEF_INSTANTIATE_ONE_TYPE(OpMaxPool2d, INT16); + +DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, FLOAT, FLOAT); +DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT4); +DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, INT8); +DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, AINT8, AINT8); +DEF_INSTANTIATE_TWO_TYPE(OpTransposeConv2d, INT16, INT8); diff --git a/reference_model/src/ops/tensor_ops.h b/reference_model/src/ops/tensor_ops.h new file mode 100644 index 0000000..26ce84b --- /dev/null +++ b/reference_model/src/ops/tensor_ops.h @@ -0,0 +1,253 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#ifndef OPS_TENSOR_OPS_H
#define OPS_TENSOR_OPS_H

#include "graph_node.h"
#include "quant_util.h"

using namespace tosa;

namespace TosaReference
{

// ARGMAX: reduce along the axis given by the attribute, producing INT32
// indices; the output tensor has rank (input rank - 1).
template <int Rank, DType Dtype>
class OpArgMax : public GraphNode
{
public:
    OpArgMax(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpArgMax();

    virtual int checkTensorAttributes();
    virtual int eval();

    using InEigenType  = typename GetEigenType<Dtype>::type;
    using OutEigenType = typename GetEigenType<DType_INT32>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank - 1>;

protected:
    TosaAxisAttribute* attribute;
    TosaReference::TensorTemplate<TIn>* input;
    TosaReference::TensorTemplate<TOut>* output;
};

// AVG_POOL2D on rank-4 tensors.  Accumulates at the wider AccDtype; QMin/QMax
// give the numeric range of the (quantized) I/O dtype — presumably used by
// eval() to clamp the rescaled average (see tensor_ops.cc).
template <DType Dtype>
class OpAvgPool2d : public GraphNode
{
public:
    OpAvgPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpAvgPool2d();

    virtual int checkTensorAttributes();
    virtual int eval();

    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
    using InEigenType               = typename GetEigenType<Dtype>::type;
    using AccEigenType              = typename GetEigenType<AccDtype>::type;
    using OutEigenType              = typename GetEigenType<Dtype>::type;
    using TIn                       = Eigen::Tensor<InEigenType, 4>;
    using TOut                      = Eigen::Tensor<OutEigenType, 4>;

    static constexpr int64_t QMin = GetQMin<Dtype>::value;
    static constexpr int64_t QMax = GetQMax<Dtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* in;
    TosaReference::TensorTemplate<TOut>* out;
    tosa::TosaPool2dAttribute* attribute;
    tosa::TosaUnaryQuantInfo* qinfo;

protected:
    // return a 1D [N] tensor that describes a how many valid elements covered in the input space
    ETensor1<int32_t> calculate_div_map_1d(int in_size, int out_size, int kernel_size, int stride);
};

// CONV2D: rank-4 input/weight, accumulating in AccDtype derived from the
// input/weight dtype pair; the bias tensor is carried at accumulator
// precision (TBias elements are AccEigenType).
template <DType InDtype, DType WeightDtype>
class OpConv2d : public GraphNode
{
public:
    OpConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpConv2d();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;

    using InEigenType     = typename GetEigenType<InDtype>::type;
    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
    using AccEigenType    = typename GetEigenType<AccDtype>::type;
    using TIn             = Eigen::Tensor<InEigenType, 4>;
    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
    using TBias           = Eigen::Tensor<AccEigenType, 1>;
    using TAcc            = Eigen::Tensor<AccEigenType, 4>;

    // Saturation bounds of the accumulator dtype.
    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* input;
    TosaReference::TensorTemplate<TWeight>* weight;
    TosaReference::TensorTemplate<TBias>* bias;
    TosaReference::TensorTemplate<TAcc>* output;
    tosa::TosaConv2dAttribute* attribute;
    tosa::TosaConvQuantInfo* qinfo;
};

// DEPTHWISE_CONV2D: same tensor/typedef layout as OpConv2d and shares the
// same TosaConv2dAttribute / TosaConvQuantInfo types.
template <DType InDtype, DType WeightDtype>
class OpDepthwiseConv2d : public GraphNode
{
public:
    OpDepthwiseConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpDepthwiseConv2d();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;

    using InEigenType     = typename GetEigenType<InDtype>::type;
    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
    using AccEigenType    = typename GetEigenType<AccDtype>::type;
    using TIn             = Eigen::Tensor<InEigenType, 4>;
    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
    using TBias           = Eigen::Tensor<AccEigenType, 1>;
    using TAcc            = Eigen::Tensor<AccEigenType, 4>;

    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* input;
    TosaReference::TensorTemplate<TWeight>* weight;
    TosaReference::TensorTemplate<TBias>* bias;
    TosaReference::TensorTemplate<TAcc>* output;
    tosa::TosaConv2dAttribute* attribute;
    tosa::TosaConvQuantInfo* qinfo;
};

// FULLY_CONNECTED: rank-2 input/weight/output; has no attribute member —
// only conv-style quantization info.
template <DType InDtype, DType WeightDtype>
class OpFullyConnected : public GraphNode
{
public:
    OpFullyConnected(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpFullyConnected();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;
    using InEigenType               = typename GetEigenType<InDtype>::type;
    using WeightEigenType           = typename GetEigenType<WeightDtype>::type;
    using AccEigenType              = typename GetEigenType<AccDtype>::type;
    using TIn                       = Eigen::Tensor<InEigenType, 2>;
    using TWeight                   = Eigen::Tensor<WeightEigenType, 2>;
    using TBias                     = Eigen::Tensor<AccEigenType, 1>;
    using TAcc                      = Eigen::Tensor<AccEigenType, 2>;

    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* input;
    TosaReference::TensorTemplate<TWeight>* weight;
    TosaReference::TensorTemplate<TBias>* bias;
    TosaReference::TensorTemplate<TAcc>* output;
    tosa::TosaConvQuantInfo* qinfo;
};

// MATMUL: c = a x b on rank-2 tensors, both operands share the same dtype.
template <DType Dtype>
class OpMatMul : public GraphNode
{
public:
    OpMatMul(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpMatMul();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    static constexpr DType AccDtype = GetAccDType<Dtype, Dtype>::value;
    using InEigenType               = typename GetEigenType<Dtype>::type;
    using AccEigenType              = typename GetEigenType<AccDtype>::type;
    using TIn                       = Eigen::Tensor<InEigenType, 2>;
    using TAcc                      = Eigen::Tensor<AccEigenType, 2>;
    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;

protected:
    // Operands a, b and result c.
    TosaReference::TensorTemplate<TIn>* a;
    TosaReference::TensorTemplate<TIn>* b;
    TosaReference::TensorTemplate<TAcc>* c;
    tosa::TosaMatMulQuantInfo* qinfo;
};

// MAX_POOL2D on rank-4 tensors: pool attribute only, no quantization-info
// member is declared.
template <DType Dtype>
class OpMaxPool2d : public GraphNode
{
public:
    OpMaxPool2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpMaxPool2d();

    virtual int checkTensorAttributes();
    virtual int eval();

    using InEigenType  = typename GetEigenType<Dtype>::type;
    using OutEigenType = typename GetEigenType<Dtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, 4>;
    using TOut         = Eigen::Tensor<OutEigenType, 4>;

protected:
    TosaReference::TensorTemplate<TIn>* in;
    TosaReference::TensorTemplate<TOut>* out;
    tosa::TosaPool2dAttribute* attribute;
};

// TRANSPOSE_CONV2D: same tensor/typedef layout as OpConv2d but with its own
// attribute type (outpad/stride/dilation/output_shape).
template <DType InDtype, DType WeightDtype>
class OpTransposeConv2d : public GraphNode
{
public:
    OpTransposeConv2d(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpTransposeConv2d();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    static constexpr DType AccDtype = GetAccDType<InDtype, WeightDtype>::value;

    using InEigenType     = typename GetEigenType<InDtype>::type;
    using WeightEigenType = typename GetEigenType<WeightDtype>::type;
    using AccEigenType    = typename GetEigenType<AccDtype>::type;
    using TIn             = Eigen::Tensor<InEigenType, 4>;
    using TWeight         = Eigen::Tensor<WeightEigenType, 4>;
    using TBias           = Eigen::Tensor<AccEigenType, 1>;
    using TAcc            = Eigen::Tensor<AccEigenType, 4>;

    static constexpr int64_t AccQMin = GetQMin<AccDtype>::value;
    static constexpr int64_t AccQMax = GetQMax<AccDtype>::value;

protected:
    TosaReference::TensorTemplate<TIn>* input;
    TosaReference::TensorTemplate<TWeight>* weight;
    TosaReference::TensorTemplate<TBias>* bias;
    TosaReference::TensorTemplate<TAcc>* output;
    TosaTransposeConv2dAttribute* attribute;
TosaConvQuantInfo* qinfo; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/ops/type_conversion.cc b/reference_model/src/ops/type_conversion.cc new file mode 100644 index 0000000..61a19f4 --- /dev/null +++ b/reference_model/src/ops/type_conversion.cc @@ -0,0 +1,299 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "type_conversion.h" +#include "quant_util.h" +#include "template_types.h" +#include <cmath> + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +template <int Rank, DType InDtype, DType OutDtype> +OpRescale<Rank, InDtype, OutDtype>::OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_RESCALE, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); + INIT_ATTRIBUTE(Rescale); +} + +template <int Rank, DType InDtype, DType OutDtype> +OpRescale<Rank, InDtype, OutDtype>::~OpRescale() +{ + if (attribute) + delete attribute; +} + +template <int Rank, DType InDtype, DType OutDtype> +int OpRescale<Rank, InDtype, OutDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same rank and size + if (inputs[0]->matchRankSize(*outputs[0])) + { + printNodeValidationError("OpRescale: input and output rank/size must match"); + return 1; + } + + in 
= dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && out); + + return 0; +} + +template <int Rank, DType InDtype, DType OutDtype> +int OpRescale<Rank, InDtype, OutDtype>::eval() +{ + int32_t input_zp = attribute->input_zp(); + int32_t output_zp = attribute->output_zp(); + std::vector<int32_t> multiplier = attribute->multiplier(); + std::vector<int32_t> shift = attribute->shift(); + //bool scale32 = attribute->scale32(); + bool double_round = attribute->double_round(); + bool per_channel = attribute->per_channel(); + + if (TosaReference::TypeChecker::is_symmetric(InDtype)) + { + if (input_zp != 0) + { + FATAL_ERROR_NODE("input tensor is symmetric type %s but zeropoint is %d instead of 0", + EnumNamesDType()[InDtype], input_zp); + } + } + + if (TosaReference::TypeChecker::is_symmetric(OutDtype)) + { + if (output_zp != 0) + { + FATAL_ERROR_NODE("output tensor is symmetric type %s but zeropoint is %d instead of 0", + EnumNamesDType()[OutDtype], output_zp); + } + } + + // reshape [d0, d1, ..., dn] into [d0 * d1 ..., dn] + Eigen::array<Eigen::Index, 2> shape_2d; + shape_2d[0] = 1; + if (Rank > 0) + { + for (int i = 0; i < Rank - 1; i++) + { + shape_2d[0] *= this->in->getShape()[i]; + } + shape_2d[1] = this->in->getShape()[Rank - 1]; + } + else + { + shape_2d[1] = 1; + } + ETensor2<InEigenType> input_reshaped = this->in->getTensor().reshape(shape_2d); + + ETensor2<OutEigenType> output_2d(shape_2d); + + // TODO: pass scale32 in when 16-bit mode implemented + if (per_channel) + { + ETensor2<InEigenType> curr_channel_slice_prescaled; + ETensor2<OutEigenType> curr_channel_slice_postscaled; + int32_t channel_multiplier, channel_shift; + Eigen::array<Eigen::Index, 2> begin, size; + size = Eigen::array<Eigen::Index, 2>({ shape_2d[0], 1 }); + for (int32_t i = 0; i < shape_2d[1]; i++) + { + begin = Eigen::array<Eigen::Index, 2>({ 0, i }); + curr_channel_slice_prescaled = 
input_reshaped.slice(begin, size); + channel_multiplier = multiplier[i]; + channel_shift = shift[i]; + curr_channel_slice_postscaled = + curr_channel_slice_prescaled.unaryExpr([input_zp, output_zp, channel_multiplier, channel_shift, + double_round](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale( + input_zp_shifted, channel_multiplier, channel_shift, double_round); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max<OutEigenType>(out_val, QMin); + out_val = std::min<OutEigenType>(out_val, QMax); + return out_val; + }); + + for (int32_t j = 0; j < shape_2d[0]; j++) + { + output_2d(j, i) = curr_channel_slice_postscaled(j, 0); + } + } + } + else + { + int32_t tensor_multiplier = multiplier[0]; + int32_t tensor_shift = shift[0]; + output_2d = input_reshaped.unaryExpr( + [input_zp, output_zp, tensor_multiplier, tensor_shift, double_round](InEigenType in_val) -> OutEigenType { + InEigenType input_zp_shifted = in_val - (InEigenType)input_zp; + int32_t scaled = TosaReference::QuantUtil<InDtype>::apply_scale(input_zp_shifted, tensor_multiplier, + tensor_shift, double_round); + OutEigenType out_val = (OutEigenType)(scaled + output_zp); + out_val = std::max<OutEigenType>(out_val, QMin); + out_val = std::min<OutEigenType>(out_val, QMax); + return out_val; + }); + } + + // reshape [d0 * d1 ..., dn] back to [d0, d1, ..., dn] + Eigen::array<Eigen::Index, Rank> output_shape; + for (int i = 0; i < Rank; i++) + { + output_shape[i] = this->out->getShape()[i]; + } + this->out->getTensor() = output_2d.reshape(output_shape); + + return GraphNode::eval(); +} + +template <int Rank, DType InDtype, DType OutDtype> +OpCast<Rank, InDtype, OutDtype>::OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_) + : GraphNode(Op_CAST, id_) +{ + setRequiredOperands(1, 1); + setRequiredRank(0, 6); +} + +template <int Rank, DType 
InDtype, DType OutDtype> +OpCast<Rank, InDtype, OutDtype>::~OpCast() +{} + +template <int Rank, DType InDtype, DType OutDtype> +int OpCast<Rank, InDtype, OutDtype>::checkTensorAttributes() +{ + if (validateRequiredOperands()) + return 1; + + if (validateRequiredRank(inputs[0]) || validateRequiredRank(outputs[0])) + { + return 1; + } + + // output and input must be the same rank and size + if (inputs[0]->matchRankSize(*outputs[0])) + { + printNodeValidationError("OpCast: input and output rank/size must match"); + return 1; + } + + in = dynamic_cast<TosaReference::TensorTemplate<TIn>*>(inputs[0]); + out = dynamic_cast<TosaReference::TensorTemplate<TOut>*>(outputs[0]); + + ASSERT_MEM(in && out); + + return 0; +} + +template <int Rank, DType InDtype, DType OutDtype> +int OpCast<Rank, InDtype, OutDtype>::eval() +{ + this->out->getTensor() = this->in->getTensor().unaryExpr(cast_helper.get_fcn()); + + return GraphNode::eval(); +} + +template <DType InDtype, DType OutDtype> +CastHelper<InDtype, OutDtype>::CastHelper() +{ + fcn = [](InEigenType in) -> OutEigenType { + OutEigenType out = (OutEigenType)in; // implicit sign_extend() if sizeof(out_t) >= sizeof(in_t) + int64_t mask = (1L << OutBits) - 1; + out = out & mask; + return out; + }; +} + +template <DType InDtype> +CastHelper<InDtype, DType_BOOL>::CastHelper() +{ + fcn = [](InEigenType in) -> bool { return (in != 0) ? true : false; }; +} + +template <DType OutDtype> +CastHelper<DType_BOOL, OutDtype>::CastHelper() +{ + fcn = [](bool in) -> OutEigenType { + OutEigenType out = in ? 
(OutEigenType)1 : (OutEigenType)0; + return out; + }; +} + +template <DType InDtype> +CastHelper<InDtype, DType_FLOAT>::CastHelper() +{ + fcn = [](InEigenType in) -> float { + float out = (OutEigenType)in; // default cast to float is round_to_nearest_float() + return out; + }; +} + +template <DType OutDtype> +CastHelper<DType_FLOAT, OutDtype>::CastHelper() +{ + fcn = [](float in) -> OutEigenType { + OutEigenType out = std::round(in); + out = std::max<OutEigenType>(out, OutMin); + out = std::min<OutEigenType>(out, OutMax); + return out; + }; +} + +// template explicit instantiation +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, BOOL, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT8, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT16, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, BOOL); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, INT32, FLOAT); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpCast, FLOAT, INT32); + +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, INT32); 
+DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT16, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT32, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT16); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, INT48, INT32); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, UINT8, AINT8); +DEF_INSTANTIATE_RANK0_6_ONE_RANK_TWO_TYPE(OpRescale, AINT8, UINT8); diff --git a/reference_model/src/ops/type_conversion.h b/reference_model/src/ops/type_conversion.h new file mode 100644 index 0000000..6ec4d6d --- /dev/null +++ b/reference_model/src/ops/type_conversion.h @@ -0,0 +1,162 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 

#ifndef OPS_TYPE_CONVERSION_H
#define OPS_TYPE_CONVERSION_H

#include "graph_node.h"

using namespace tosa;

namespace TosaReference
{
// RESCALE: element-wise requantization from InDtype to OutDtype using the
// multiplier/shift/zero-point data in the attribute; results are clamped to
// [QMin, QMax] of the output dtype (see type_conversion.cc).
template <int Rank, DType InDtype, DType OutDtype>
class OpRescale : public GraphNode
{
public:
    OpRescale(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpRescale();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;

    // Numeric range of the output dtype, used for the final clamp.
    static constexpr int32_t QMin = GetQMin<OutDtype>::value;
    static constexpr int32_t QMax = GetQMax<OutDtype>::value;

protected:
    TosaRescaleAttribute* attribute;
    TosaReference::TensorTemplate<TIn>* in;
    TosaReference::TensorTemplate<TOut>* out;
};

// CastHelper builds the element-wise conversion functor used by OpCast.
// Primary template: integer -> integer conversion (the constructor in
// type_conversion.cc masks the result down to OutBits bits).
template <DType InDtype, DType OutDtype>
class CastHelper
{
public:
    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using FcnType      = std::function<OutEigenType(InEigenType)>;
    static constexpr int32_t OutBits = GetNumBits<OutDtype>::value;
    CastHelper();
    const FcnType& get_fcn() const
    {
        return fcn;
    }

private:
    FcnType fcn;
};

// Specialization: anything -> BOOL (non-zero test; see type_conversion.cc).
template <DType InDtype>
class CastHelper<InDtype, DType_BOOL>
{
public:
    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<DType_BOOL>::type;
    using FcnType      = std::function<OutEigenType(InEigenType)>;
    CastHelper();
    const FcnType& get_fcn() const
    {
        return fcn;
    }

private:
    FcnType fcn;
};

// Specialization: BOOL -> integer (maps true/false to 1/0).
// NOTE(review): OutMin/OutMax are declared here but the bool->int functor in
// type_conversion.cc does not clamp — confirm whether they are needed.
template <DType OutDtype>
class CastHelper<DType_BOOL, OutDtype>
{
public:
    using InEigenType  = typename GetEigenType<DType_BOOL>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using FcnType      = std::function<OutEigenType(InEigenType)>;
    static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
    static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
    CastHelper();
    const FcnType& get_fcn() const
    {
        return fcn;
    }

private:
    FcnType fcn;
};

// Specialization: integer -> FLOAT (round-to-nearest-float conversion).
template <DType InDtype>
class CastHelper<InDtype, DType_FLOAT>
{
public:
    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<DType_FLOAT>::type;
    using FcnType      = std::function<OutEigenType(InEigenType)>;
    CastHelper();
    const FcnType& get_fcn() const
    {
        return fcn;
    }

private:
    FcnType fcn;
};

// Specialization: FLOAT -> integer (round, then clamp to [OutMin, OutMax]).
template <DType OutDtype>
class CastHelper<DType_FLOAT, OutDtype>
{
public:
    using InEigenType  = typename GetEigenType<DType_FLOAT>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using FcnType      = std::function<OutEigenType(InEigenType)>;
    static constexpr int32_t OutMin = GetQMin<OutDtype>::value;
    static constexpr int32_t OutMax = GetQMax<OutDtype>::value;
    CastHelper();
    const FcnType& get_fcn() const
    {
        return fcn;
    }

private:
    FcnType fcn;
};

// CAST: applies the CastHelper functor element-wise over a same-shape
// input/output tensor pair.
template <int Rank, DType InDtype, DType OutDtype>
class OpCast : public GraphNode
{
public:
    OpCast(TosaAttributeBase* attribute_, TosaQuantInfoBase* qinfo_, uint64_t id_);
    virtual ~OpCast();

    virtual int checkTensorAttributes() final;
    virtual int eval() final;

    using InEigenType  = typename GetEigenType<InDtype>::type;
    using OutEigenType = typename GetEigenType<OutDtype>::type;
    using TIn          = Eigen::Tensor<InEigenType, Rank>;
    using TOut         = Eigen::Tensor<OutEigenType, Rank>;

protected:
    CastHelper<InDtype, OutDtype> cast_helper;
    TosaReference::TensorTemplate<TIn>* in;
    TosaReference::TensorTemplate<TOut>* out;
};

};    // namespace TosaReference

#endif
diff --git a/reference_model/src/quant_util.h b/reference_model/src/quant_util.h new file mode 100644 index 0000000..3638b3b --- /dev/null +++ b/reference_model/src/quant_util.h @@ -0,0 +1,103 @@

// Copyright (c) 2020, ARM Limited.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_REFERENCE_QUANT_UTIL_H +#define TOSA_REFERENCE_QUANT_UTIL_H + +#include "arith_util.h" +#include "func_debug.h" +#include "ops/template_types.h" +#include "tosa_generated.h" + +using namespace tosa; + +namespace TosaReference +{ + +template <DType AccDType> +class QuantUtil +{ +public: + using T = typename GetEigenType<AccDType>::type; + + static void reciprocal_scale(int32_t value, + // Output + int32_t& multiplier, + int32_t& shift) + { + ASSERT_MSG(value > 0, "AvgPool2d reciprocal_scale() error: # of elements should be > 1 but is %d", value); + uint32_t value_u32 = (uint32_t)value; + int32_t k = 32 - LEADING_ZEROS_32(value_u32 - 1); // (1<<k)/2 < value <= (1<<k) + int64_t numerator = ((1L << 30) + 1) << k; + multiplier = numerator / value; // (1<<30) <= multiplier < (1<<31) + shift = 30 + k; + } + + static int32_t apply_scale(T value, int32_t multiplier, int32_t shift, bool enabled_adjusted_rounding = true) + { + if (AccDType == DType_FLOAT) + { + return value; + } + ASSERT_MSG(multiplier >= 0, "apply_scale() error: multiplier should >= 0 but is %d", multiplier); + int64_t round = (shift > 0) ? 
(1L << (shift - 1)) : 0; + if (enabled_adjusted_rounding) + { + if (AccDType != DType_INT48) + { + if (shift > 31 && value >= 0) + round += (1L << 30); + if (shift > 31 && value < 0) + round -= (1L << 30); + } + else + { // input data could be int16, which leads to 48 bits accumulator + ASSERT_MSG(multiplier < (1 << 15), "apply_scale() error: multiplier should <= %d in 48 bit mode", + (1 << 15)); + } + } + int64_t result = (int64_t)value * multiplier + round; + result = result >> shift; + ASSERT_MSG(result >= -(1L << 31) && result < (1L << 31), + "apply_scale() error: scaled result exceed int32 numeric range"); + return static_cast<int32_t>(result); + } +}; + +class TypeChecker +{ +public: + static bool is_integer(DType dtype) + { + if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_AINT8 || dtype == DType_UINT8 || + dtype == DType_INT16 || dtype == DType_INT32 || dtype == DType_INT48) + { + return true; + } + return false; + } + static bool is_symmetric(DType dtype) + { + if (dtype == DType_INT4 || dtype == DType_INT8 || dtype == DType_INT16 || dtype == DType_INT32 || + dtype == DType_INT48) + { + return true; + } + return false; + } +}; +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/subgraph_traverser.cc b/reference_model/src/subgraph_traverser.cc new file mode 100644 index 0000000..789bcae --- /dev/null +++ b/reference_model/src/subgraph_traverser.cc @@ -0,0 +1,649 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
// See the License for the specific language governing permissions and
// limitations under the License.

#include "subgraph_traverser.h"

using namespace TosaReference;
using namespace Eigen;
using namespace tosa;

// Construct a traverser over one serialized basic block.  Graph structures
// are built later by initializeGraph()/linkTensorsAndNodes().
SubgraphTraverser::SubgraphTraverser(TosaSerializationBasicBlock* _block, TosaSerializationHandler* _tsh)
{
    block = _block;
    tsh   = _tsh;

    tensors.clear();
    nodes.clear();
    nextNodeList.clear();
}

// Owns all GraphNodes and Tensors created in initializeGraph(); deallocate
// any still-live tensor storage, then delete the objects themselves.
SubgraphTraverser::~SubgraphTraverser()
{
    nextNodeList.clear();

    for (GraphNode* n : nodes)
    {
        delete n;
    }
    nodes.clear();

    for (TosaReference::Tensor* t : tensors)
    {
        if (t->is_allocated())
        {
            t->deallocate();
        }
        delete t;
    }
    tensors.clear();
}

int SubgraphTraverser::getNumInputTensors() const
{
    return inputTensors.size();
}

// No bounds check: caller is expected to stay within getNumInputTensors().
TosaReference::Tensor* SubgraphTraverser::getInputTensor(const unsigned int idx) const
{
    return inputTensors[idx];
}

// Linear search; returns nullptr when no input tensor has this name.
TosaReference::Tensor* SubgraphTraverser::getInputTensorByName(const std::string name) const
{
    for (auto t : inputTensors)
    {
        if (t->getName() == name)
        {
            return t;
        }
    }

    return nullptr;
}

int SubgraphTraverser::getNumOutputTensors() const
{
    return outputTensors.size();
}

// No bounds check: caller is expected to stay within getNumOutputTensors().
TosaReference::Tensor* SubgraphTraverser::getOutputTensor(const unsigned int idx) const
{
    return outputTensors[idx];
}

// Linear search; returns nullptr when no output tensor has this name.
TosaReference::Tensor* SubgraphTraverser::getOutputTensorByName(const std::string name) const
{
    for (auto t : outputTensors)
    {
        if (t->getName() == name)
        {
            return t;
        }
    }

    return nullptr;
}

// Build GraphNodes and Tensors from the serialized block:
//  1. create one GraphNode per serialized operator (choosing the template
//     instantiation from the dtypes/ranks of its input/output tensors),
//  2. create one reference Tensor per serialized tensor (loading .npy data
//     for tensors that carry a file reference),
//  3. record the block's declared input/output tensors.
// Nodes with no inputs (e.g. CONST) are seeded onto the ready list here.
// Returns 0 on success; failures are fatal.
int SubgraphTraverser::initializeGraph()
{
    char tensor_fullname[1000];
    int idx = 0;
    for (auto op : block->GetOperators())
    {
        // translated TosaSerializationOperator to GraphNode
        DType in_dtype = DType_UNKNOWN, out_dtype = DType_UNKNOWN, weight_dtype = DType_UNKNOWN;
        uint32_t in_rank = 0, out_rank = 0, weight_rank = 0;
        for (auto name : op->GetInputTensorNames())
        {

            TosaSerializationTensor* ts = block->GetTensorByName(name);
            ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());

            if (ts->HasUsage(Usage_WEIGHT))
            {
                weight_dtype = ts->GetDtype();
                weight_rank  = ts->GetShape().size();
            }
            else if (ts->HasUsage(Usage_INDEX))
            {
                // do nothing, but this will prevent tensor's dtype/rank being wrongly used as template argument when initializing this op
            }
            else if (ts->HasUsage(Usage_ACTIVATION))
            {
                // Keep the highest-rank activation input as the op's
                // representative input dtype/rank.
                if (ts->GetShape().size() >= in_rank)
                {
                    in_dtype = ts->GetDtype();
                    in_rank  = ts->GetShape().size();
                }
            }
        }

        for (auto name : op->GetOutputTensorNames())
        {

            TosaSerializationTensor* ts = block->GetTensorByName(name);
            ASSERT_MSG(ts, "SubgraphTraverser: fail to get tensor %s from TosaSerializationHandler", name.c_str());

            out_dtype = ts->GetDtype();
            out_rank  = ts->GetShape().size();
        }

        DEBUG_INFO(GT, "Creating operator id_%03u, %8s, %lu input tensors, %lu output tensors", idx,
                   EnumNamesOp()[op->GetOp()], op->GetInputTensorNames().size(), op->GetOutputTensorNames().size());

        GraphNode* cn = OpFactory::newOp(tsh, op->GetOp(), op->GetAttribute(), op->GetQInfo(), idx, in_dtype, in_rank,
                                         out_dtype, out_rank, weight_dtype, weight_rank);
        if (!cn)
        {
            // No template instantiation exists for this op/dtype/rank combo:
            // log the signature we looked for, then abort.
            if (weight_dtype == DType_UNKNOWN && weight_rank == 0)
            {
                fprintf(g_func_debug.func_debug_file,
                        "OpFactory could not allocate op %8s input=(%s rank %d) -> (%s rank %d)",
                        EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[out_dtype],
                        out_rank);
            }
            else
            {
                fprintf(g_func_debug.func_debug_file,
                        "OpFactory could not allocate op %8s input=(%s rank %d), weight=(%s rank %d) -> (%s rank %d)",
                        EnumNamesOp()[op->GetOp()], EnumNamesDType()[in_dtype], in_rank, EnumNamesDType()[weight_dtype],
                        weight_rank, EnumNamesDType()[out_dtype], out_rank);
            }

            for (auto ts : op->GetInputTensors())
            {
                fprintf(g_func_debug.func_debug_file, "Input: %s\n", ts->GetName().c_str());
            }

            for (auto ts : op->GetOutputTensors())
            {
                fprintf(g_func_debug.func_debug_file, "Output: %s\n", ts->GetName().c_str());
            }
            FATAL_ERROR("Unsupported operation type or rank.");
        }

        for (auto name : op->GetInputTensorNames())
        {
            cn->addInputName(name);
        }

        for (auto name : op->GetOutputTensorNames())
        {
            cn->addOutputName(name);
        }

        addNode(cn);

        // if node doesn't have any inputs (i.e. CONST)
        // it should be ready for evaluation
        if (op->GetInputTensorNames().empty() && !cn->getOnNextNodeList())
        {
            addToNextNodeList(cn);
        }

        idx++;
    }

    for (auto ts : block->GetTensors())
    {

        // WEIGHT-usage tensors are treated as constants.
        bool is_const = false;
        if (ts->HasUsage(Usage_WEIGHT))
        {
            is_const = true;
        }

        DEBUG_INFO(GT, "Creating tensor %s", ts->GetName().c_str());
        TosaReference::Tensor* ct =
            TensorFactory::newTensor(ts->GetName(), ts->GetDtype(), ts->GetUsage(), ts->GetFormat(), ts->GetShape(),
                                     is_const, ts->GetShape().size());

        if (ts->GetNpyFilePtr())
        {
            // Tensor has serialized data on disk: allocate storage and load
            // it from <subgraph_dir>/<npy filename>.
            if (ct->allocate())
            {
                FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
            }

            bzero(tensor_fullname, sizeof(tensor_fullname));
            snprintf(tensor_fullname, sizeof(tensor_fullname), "%s/%s", g_func_config.subgraph_dir,
                     ts->GetNpyFilePtr()->c_str());
            if (ct->readFromNpyFile(tensor_fullname))
            {
                FATAL_ERROR("Cannot read input data into graph tensor %s from block %s", ct->getName().c_str(),
                            block->GetName().c_str());
            }
        }

        // update this->tensors
        addTensor(ct);
    }

    DEBUG_INFO(GT, "Enumerating block %s graph inputs", block->GetName().c_str());
    for (auto& input_name : block->GetInputs())
    {
        TosaReference::Tensor* ct = findTensorByName(input_name);
        DEBUG_INFO(GT, "input tensor name=%s", input_name.c_str());
        if (ct)
        {
            ct->setIsSubgraphInput();
            inputTensors.push_back(ct);
        }
        else
        {
            FATAL_ERROR("loadGraphJson: Fail to find input tensor by name %s", input_name.c_str());
        }
    }

    DEBUG_INFO(GT, "Enumerating block %s graph outputs", block->GetName().c_str());
    for (auto& output_name : block->GetOutputs())
    {
        TosaReference::Tensor* ct = findTensorByName(output_name);
        DEBUG_INFO(GT, "output tensor name=%s\n", output_name.c_str());
        if (ct)
        {
            ct->setIsSubgraphOutput();
            outputTensors.push_back(ct);
        }
        else
        {
            FATAL_ERROR("loadGraphJson: Fail to find output tensor by name %s", output_name.c_str());
        }
    }

    if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
    {
        dumpNextNodeList(g_func_debug.func_debug_file);
    }

    return 0;
}

// The graph is fully evaluated when no node remains on the ready list.
int SubgraphTraverser::isFullyEvaluated() const
{
    return nextNodeList.empty();
}

// Pop the next ready node off the front of the list and clear its
// on-list flag.  Precondition: the list is non-empty.
GraphNode* SubgraphTraverser::getNextNode()
{
    GraphNode* nextNode = nextNodeList.front();
    ASSERT_MSG(nextNode, "SubgraphTraverser::getNextNode(): called with empty next node list");
    ASSERT_MSG(nextNode->getOnNextNodeList(),
               "SubgraphTraverser::getNextNode(): internal state error: node is not listed as being on next node list");

    nextNodeList.pop_front();

    nextNode->clearOnNextNodeList();
    return nextNode;
}

// Append a node to the ready list, marking it so it is never queued twice.
int SubgraphTraverser::addToNextNodeList(GraphNode* nextNode)
{
    ASSERT_MSG(nextNode, "SubgraphTraverser::addToNextNodeList(): called with no node");
    ASSERT_MSG(!nextNode->getOnNextNodeList(),
               "SubgraphTraverser::addToNextNodeList(): internal state error: node is already on next node list");

    nextNode->setOnNextNodeList();
    nextNodeList.push_back(nextNode);

    return 0;
}

// Evaluate one ready node: allocate its outputs, run eval(), free inputs
// that nothing further needs, then queue any consumers that became ready.
// Returns 0; evaluation failures are fatal.
int SubgraphTraverser::evaluateNextNode()
{
    if (isFullyEvaluated())
        return 0;

    GraphNode* currNode = getNextNode();

    DEBUG_INFO(GT, "Evaluating node_%03lu, %8s, output tensor=%s", currNode->getID(), EnumNamesOp()[currNode->getOp()],
               currNode->getOutputNames()[0].c_str());

    // Sanity check for never-ending loops
    if (currNode->getEvalCount() >= MAX_EVAL_COUNT && (currNode->getEvalCount() % MAX_EVAL_COUNT) == 0)
    {
        WARNING("Node %lu has been evaluated %d times. Loop suspected.", currNode->getID(), currNode->getEvalCount());
    }

    for (auto ct : currNode->getOutputs())
    {
        if (!ct->is_allocated())
            if (ct->allocate())
            {
                FATAL_ERROR("Fail to allocate Eigen tensor %s", ct->getName().c_str());
            }
    }

    if (currNode->eval())
    {
        FATAL_ERROR("Error evaluating node: %lu\n", currNode->getID());
    }

    // free input tensor if all of its consumers have all of their outputs ready and it's not block's output
    for (auto ct : currNode->getInputs())
    {
        bool in_use = false;
        for (auto cn : ct->getConsumers())
        {
            if (!cn->hasAllOutputsReady())
            {
                in_use = true;
            }
        }
        for (auto name : block->GetOutputs())
        {
            if (name == ct->getName())
            {
                in_use = true;
            }
        }
        if (!in_use)
        {
            ct->deallocate();
        }
    }

    // Search the output tensors of this node to see if
    // there are now new ready nodes available from completing this node
    for (TosaReference::Tensor* tensor : currNode->getOutputs())
    {
        for (GraphNode* node : tensor->getConsumers())
        {
            if (!node->getOnNextNodeList() && node->hasAllInputsReady())
            {
                addToNextNodeList(node);
            }
        }
    }

    if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
    {
        dumpNextNodeList(g_func_debug.func_debug_file);
    }

    if (g_func_config.dump_intermediates)
    {
        currNode->dumpNode(g_func_debug.func_debug_file);
        for (auto outs : currNode->getOutputs())
        {
            outs->dumpTensorParams(g_func_debug.func_debug_file);
            outs->dumpTensor(g_func_debug.func_debug_file);
            fprintf(g_func_debug.func_debug_file, "\n");
        }
    }

    return 0;
}

// Debug helper: print every node currently on the ready list.
int SubgraphTraverser::dumpNextNodeList(FILE* out) const
{

    // Dump next node list
    fprintf(out, "Next node list\n");

    if (nextNodeList.empty())
    {
        fprintf(out, "<empty>\n");
    }

    for (auto gn : nextNodeList)
    {
        gn->dumpNode(out);
    }

    fprintf(out, "Done.\n");
    return 0;
}

// Clear the 'marked' flag on every node.
// NOTE(review): returns false (0 == success) to match the int convention.
int SubgraphTraverser::clearAllNodeMarkings()
{
    for (GraphNode* currNode : nodes)
    {
        currNode->clearNodeMarked();
    }

    return false;
}

// Register a tensor with the graph, rejecting duplicates by pointer or name.
// Also files it under inputTensors/outputTensors if flagged as such.
int SubgraphTraverser::addTensor(TosaReference::Tensor* ct)
{
    // Enforce no duplicate tensors/tensor names
    // O(N), but the number of tensors is small
    for (TosaReference::Tensor* currTensor : tensors)
    {
        if (ct == currTensor || currTensor->getName() == ct->getName())
        {
            FATAL_ERROR("Error: Duplicate tensor or tensor name being added to graph: %s\n", ct->getName().c_str());
            return 1;
        }
    }

    tensors.push_back(ct);

    if (ct->getIsSubgraphInput())
    {
        inputTensors.push_back(ct);
    }

    if (ct->getIsSubgraphOutput())
    {
        outputTensors.push_back(ct);
    }

    return 0;
}

// Register a node with the graph, rejecting duplicate pointers.
int SubgraphTraverser::addNode(GraphNode* newNode)
{
    // Enforce no duplicate nodes
    for (GraphNode* currNode : nodes)
    {
        if (currNode == newNode)
        {
            FATAL_ERROR("Error: duplicate node being added to graph");
            return 1;
        }
    }

    nodes.push_back(newNode);

    return 0;
}

// Linear search over all graph tensors; warns and returns nullptr on miss.
TosaReference::Tensor* SubgraphTraverser::findTensorByName(const std::string& name) const
{
    for (TosaReference::Tensor* currTensor : tensors)
    {
        if (currTensor->getName() == name)
        {
            return currTensor;
        }
    }

    WARNING("Unable to find tensor with name: %s\n", name.c_str());

    return nullptr;
}

// Resolve each node's recorded input/output tensor names into actual Tensor
// pointers, wiring producer/consumer edges in both directions.
// Returns 0 on success (failures are fatal via FATAL_ERROR).
int SubgraphTraverser::linkTensorsAndNodes()
{
    // Nodes have a list of input/output tensor names
    // For each node, read this list, link up the tensors with their inputs/outputs
    for (GraphNode* currNode : nodes)
    {

        // Link inputs/consuming nodes
        for (std::string& name : currNode->getInputNames())
        {
            TosaReference::Tensor* t = findTensorByName(name);
            if (!t)
            {
                FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
                            currNode->getID());
                return 1;
            }

            if (currNode->addInputTensor(t))
            {
                FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
                            currNode->getID());
                return 1;
            }

            if (t->addConsumer(currNode))
            {
                FATAL_ERROR("linkTensorsAndNodes: cannot link consumer node %lu to tensor %s\n", currNode->getID(),
                            name.c_str());
                return 1;
            }
        }

        // Link outputs/producing nodes
        for (std::string& name : currNode->getOutputNames())
        {
            TosaReference::Tensor* t = findTensorByName(name);
            if (!t)
            {
                FATAL_ERROR("linkTensorsAndNodes: Cannot find tensor %s in node %lu\n", name.c_str(),
                            currNode->getID());
                return 1;
            }

            if (currNode->addOutputTensor(t))
            {
                FATAL_ERROR("linkTensorsAndNodes: cannot link tensor %s to node %lu\n", name.c_str(),
                            currNode->getID());
                return 1;
            }

            if (t->setProducer(currNode))
            {
                FATAL_ERROR("linkTensorsAndNodes: cannot link producer node %lu to tensor tensor %s\n",
                            currNode->getID(), name.c_str());
                return 1;
            }
        }
    }

    return 0;
}

// Consistency checks after linking: every tensor must be connected, subgraph
// inputs must not have a (non-placeholder) producer, dtypes must be legal for
// the selected TOSA profile, and per-node attribute checks must pass.
// Returns 1 on the first inconsistency found, 0 when the graph is valid.
int SubgraphTraverser::validateGraph()
{
    // Need to make sure that:
    // - each tensor is actually used
    // - input and output tensors truly are just input and just output
    // Graph building already determined that each node has found its input/output tensors

    for (TosaReference::Tensor* currTensor : tensors)
    {

        if (!currTensor->getProducer() && currTensor->getConsumers().empty())
        {
            WARNING("Graph inconsistency: TosaReference::Tensor %s has no producers or consumers\n",
                    currTensor->getName().c_str());
            return 1;
        }

        if (currTensor->getIsSubgraphInput())
        {
            if (currTensor->getProducer() && currTensor->getProducer()->getOp() != Op_PLACEHOLDER)
            {
                WARNING("Graph inconsistency: TosaReference::Tensor %s is a subgraph input and has a producer\n",
                        currTensor->getName().c_str());
                return 1;
            }
        }

        // comment this check out as this is possible when graph have multiple output
        // for example:
        //   %0 = add(%arg0, %arg1)
        //   %1 = mul(%arg0, %0)
        //   yields(%0, %1)
        //if (currTensor->getIsSubgraphOutput()) {
        //    if (!currTensor->getConsumers().empty()) {
        //        WARNING ("Graph inconsistency: TosaReference::Tensor %s is a subgraph output and has a consumer\n",
        //                 currTensor->getName().c_str());
        //        return 1;
        //    }
        //}

        if (g_func_config.tosa_profile == 0)
        {
            DType dtype = currTensor->getDtype();

            // Float-point disallowed
            if (dtype == DType_FLOAT)
            {
                WARNING("TOSA Base Inference profile selected: All floating point disabled, but %s tensor %s found\n",
                        EnumNamesDType()[dtype], currTensor->getName().c_str());
                return 1;
            }
        }
        else if (g_func_config.tosa_profile == 1 || g_func_config.tosa_profile == 2)
        {
            // Do nothing. All FP types allowed
            // Currently no implementation difference between Main Inference and Main Training modes
        }
        else
        {
            FATAL_ERROR("TOSA profile not recognized: %d", g_func_config.tosa_profile);
        }
    }

    for (GraphNode* currNode : nodes)
    {
        if (currNode->checkTensorAttributes())
        {
            WARNING("TosaReference::Tensor attribute check failed");
            return 1;
        }
    }

    if (outputTensors.size() <= 0)
    {
        DEBUG_MED(GT, "Graph output tensor empty");
        return 0;
    }

    return 0;
}

// Debug helper: print every node in the graph, in insertion order.
int SubgraphTraverser::dumpGraph(FILE* out) const
{
    int i = 0;

    fprintf(out, "Full graph dump:\n");
    for (GraphNode* currNode : nodes)
    {
        fprintf(out, "Node [%d]: ", i++);
        currNode->dumpNode(out);
    }

    return 0;
}

// Run the dataflow loop to completion.  Returns 1 if any node fails.
int SubgraphTraverser::evaluateAll()
{
    // evaluation loop
    while (!isFullyEvaluated())
    {
        if (evaluateNextNode())
        {
            return 1;
        }
    }

    return 0;
}
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
new file mode 100644
index 0000000..3f4eecf
--- /dev/null
+++ b/reference_model/src/subgraph_traverser.h
@@ -0,0 +1,90 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef SUBGRAPH_TRAVERSER_H
#define SUBGRAPH_TRAVERSER_H

#include "model_common.h"

#include "graph_node.h"
#include "ops/op_factory.h"
#include "tosa_serialization_handler.h"

namespace TosaReference
{

// Builds an executable graph (GraphNodes + Tensors) from one serialized
// TOSA basic block and evaluates it dataflow-style: nodes whose inputs are
// all ready sit on a work list and are evaluated until the block's outputs
// are produced.  Owns every node and tensor it creates.
class SubgraphTraverser
{
public:
    SubgraphTraverser(TosaSerializationBasicBlock* block, TosaSerializationHandler* tsh);
    ~SubgraphTraverser();

    // Graph construction/evaluation lifecycle (all return 0 on success).
    int initializeGraph();
    int isFullyEvaluated() const;
    int evaluateNextNode();
    int evaluateAll();

    int linkTensorsAndNodes();
    int validateGraph();

    // Debug/introspection helpers.
    int dumpGraph(FILE* out) const;
    int dumpNextNodeList(FILE* out) const;
    int clearAllNodeMarkings();

    // Accessors for the block's declared inputs/outputs.
    int getNumInputTensors() const;
    Tensor* getInputTensor(const unsigned int idx) const;
    Tensor* getInputTensorByName(const std::string name) const;
    int getNumOutputTensors() const;
    Tensor* getOutputTensor(const unsigned int idx) const;
    Tensor* getOutputTensorByName(const std::string name) const;
    int addToNextNodeList(GraphNode*);

private:
    int addTensor(Tensor* ct);
    int addNode(GraphNode* cn);

    Tensor* findTensorByName(const std::string& name) const;

    GraphNode* getNextNode();

    // pointer to serialization library and corresponding basic block
    TosaSerializationBasicBlock* block;
    TosaSerializationHandler* tsh;

    // The definitive list of all tensors
    std::vector<Tensor*> tensors;

    // The subset of tensors that are also input tensors
    std::vector<Tensor*> inputTensors;

    // The subset of tensors that are also output tensors
    std::vector<Tensor*> outputTensors;

    // The definitive list of all nodes in the graph
    std::vector<GraphNode*> nodes;

    // The subset of node that have all of their input tensors ready, but
    // have not yet been evaluated to produce their output tensors.
    // With control flow, a node may appear on this list more than once during its
    // lifetime, although the list itself should only contain unique nodes.
    std::list<GraphNode*> nextNodeList;

    // Maximum number of times to evalute a node before
    // warning.
    const int MAX_EVAL_COUNT = 10000;
};
};    // namespace TosaReference

#endif
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
new file mode 100644
index 0000000..179484e
--- /dev/null
+++ b/reference_model/src/tensor.cc
@@ -0,0 +1,3008 @@

// Copyright (c) 2020, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
+ +#include "tensor.h" +#include "arith_util.h" + +using namespace TosaReference; +using namespace Eigen; +using namespace tosa; + +TosaReference::Tensor::Tensor(std::string tensorName_, + DType tensorDtype_, + const std::vector<Usage>& tensorUsage_, + const std::vector<Format>& tensorFormat_, + std::vector<int> shape_, + int isConst_) +{ + tensorName = std::string(tensorName_); + tensorDtype = tensorDtype_; + tensorUsage = std::vector<Usage>(tensorUsage_); + tensorFormat = std::vector<Format>(tensorFormat_); + shape = std::vector<int>(shape_); + isConst = isConst_; + producer = nullptr; + isValid = false; + consumers.clear(); + isSubgraphInput = false; + isSubgraphOutput = false; +} + +TosaReference::Tensor::~Tensor() +{} + +int TosaReference::Tensor::setIsSubgraphInput() +{ + isSubgraphInput = true; + return 0; +} + +int TosaReference::Tensor::setIsSubgraphOutput() +{ + isSubgraphOutput = true; + return 0; +} + +int TosaReference::Tensor::setProducer(GraphNode* node) +{ + ASSERT_MSG(node, "Tensor::setProducer: no node passed in"); + ASSERT_MSG(!producer, "Tensor::setProducer: producer node already set, tensor %s", tensorName.c_str()); + producer = node; + + return 0; +} + +int TosaReference::Tensor::addConsumer(GraphNode* node) +{ + ASSERT_MSG(node, "Tensor::addConsumer: no node passed in"); + consumers.push_back(node); + + return 0; +} + +int TosaReference::Tensor::dumpTensorParams(FILE* out) const +{ + fprintf(out, "Name: %s DType=%s Usage=%s isValid=%d Rank=%d Shape=%s\n", tensorName.c_str(), + EnumNamesDType()[getDtype()], getUsageAsString().c_str(), getIsValid(), getRank(), + getShapeAsString().c_str()); + + return 0; +} + +int TosaReference::Tensor::dumpTensorParams(std::ostream& out) const +{ + out << "Name: " << getName() << " DType=" << EnumNamesDType()[getDtype()] << " Usage=" << getUsageAsString() + << " isValid=" << getIsValid() << " Rank=" << getRank() << " Shape=" << getShapeAsString() << "\n"; + + return 0; +} + +int 
TosaReference::Tensor::readFromNpyFile(const char* filename) +{ + uint32_t elements = getElementCount(); + float* fdatabuf = nullptr; + int32_t* i32databuf = nullptr; + int64_t* i64databuf = nullptr; + bool* bdatabuf = nullptr; + NumpyUtilities::NPError nperror; + + switch (getDtype()) + { + case DType_FLOAT: + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, fdatabuf); + break; + case DType_INT32: + case DType_AINT8: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); + ASSERT_MEM(i32databuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, i32databuf); + break; + case DType_INT48: + i64databuf = (int64_t*)calloc(sizeof(int64_t), elements); + ASSERT_MEM(i64databuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, i64databuf); + break; + case DType_BOOL: + bdatabuf = (bool*)calloc(sizeof(bool), elements); + ASSERT_MEM(bdatabuf); + + nperror = NumpyUtilities::readFromNpyFile(filename, elements, bdatabuf); + break; + default: + FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + } + + switch (nperror) + { + case NumpyUtilities::NO_ERROR: + break; + case NumpyUtilities::FILE_NOT_FOUND: + FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename); + case NumpyUtilities::FILE_IO_ERROR: + FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename); + case NumpyUtilities::FILE_TYPE_MISMATCH: + FATAL_ERROR("readFromNpyFile: Tensor type %s and Numpy file type mismatch for tensor %s filename %s", + EnumNamesDType()[getDtype()], getName().c_str(), filename); + case NumpyUtilities::HEADER_PARSE_ERROR: + FATAL_ERROR("Numpy header parsing error for file: %s", filename); + case NumpyUtilities::BUFFER_SIZE_MISMATCH: + FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(), + filename); + 
default: + FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); + } + + switch (getDtype()) + { + case DType_FLOAT: + if (setTensorValueFloat(elements, fdatabuf)) + { + free(fdatabuf); + return 1; + } + break; + case DType_INT32: + case DType_AINT8: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + if (setTensorValueInt32(elements, i32databuf)) + { + free(i32databuf); + return 1; + } + break; + case DType_INT48: + if (setTensorValueInt64(elements, i64databuf)) + { + free(i64databuf); + return 1; + } + break; + case DType_BOOL: + if (setTensorValueBool(elements, bdatabuf)) + { + free(i32databuf); + return 1; + } + break; + default: + FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + } + + setIsValid(); + + if (fdatabuf) + free(fdatabuf); + if (i32databuf) + free(i32databuf); + if (i64databuf) + free(i64databuf); + if (bdatabuf) + free(bdatabuf); + + return 0; +} + +int TosaReference::Tensor::writeToNpyFile(const char* filename) const +{ + float* fdatabuf = nullptr; + int32_t* i32databuf = nullptr; + int64_t* i64databuf = nullptr; + bool* bdatabuf = nullptr; + NumpyUtilities::NPError nperror; + int elements = getElementCount(); + + switch (getDtype()) + { + case DType_FLOAT: + fdatabuf = (float*)calloc(sizeof(float), elements); + ASSERT_MEM(fdatabuf); + + if (getTensorValueFloat(elements, fdatabuf)) + { + free(fdatabuf); + return 1; + } + + nperror = NumpyUtilities::writeToNpyFile(filename, shape, fdatabuf); + + free(fdatabuf); + break; + case DType_INT32: + case DType_AINT8: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + i32databuf = (int32_t*)calloc(sizeof(int32_t), elements); + ASSERT_MEM(i32databuf); + + if (getTensorValueInt32(elements, i32databuf)) + { + free(i32databuf); + return 1; + } + + nperror = NumpyUtilities::writeToNpyFile(filename, shape, i32databuf); + + free(i32databuf); + break; + case DType_INT48: + i64databuf = (int64_t*)calloc(sizeof(int64_t), 
elements); + ASSERT_MEM(i64databuf); + + if (getTensorValueInt64(elements, i64databuf)) + { + free(i64databuf); + return 1; + } + + nperror = NumpyUtilities::writeToNpyFile(filename, shape, i64databuf); + + free(i64databuf); + break; + case DType_BOOL: + bdatabuf = (bool*)calloc(sizeof(bool), elements); + ASSERT_MEM(bdatabuf); + + if (getTensorValueBool(elements, bdatabuf)) + { + free(bdatabuf); + return 1; + } + + nperror = NumpyUtilities::writeToNpyFile(filename, shape, bdatabuf); + + free(bdatabuf); + break; + default: + FATAL_ERROR("unsupported tensor type=%s", EnumNamesDType()[getDtype()]); + } + + switch (nperror) + { + case NumpyUtilities::NO_ERROR: + break; + case NumpyUtilities::FILE_NOT_FOUND: + FATAL_ERROR("writeToNpyFile: Cannot open output file %s", filename); + case NumpyUtilities::FILE_IO_ERROR: + FATAL_ERROR("writeToNpyFile: IO error writing file: %s", filename); + case NumpyUtilities::FILE_TYPE_MISMATCH: + FATAL_ERROR("writeToNpyFile: Tensor type and Numpy file type mismatch for tensor %s filename %s", + getName().c_str(), filename); + case NumpyUtilities::HEADER_PARSE_ERROR: + FATAL_ERROR("Numpy header parsing error for file: %s", filename); + case NumpyUtilities::BUFFER_SIZE_MISMATCH: + FATAL_ERROR("Buffer size does not match numpy file size for tensor %s filename %s", getName().c_str(), + filename); + default: + FATAL_ERROR("Unknown error writing Numpy file: %s", filename); + } + + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::copyValueFrom(TosaReference::Tensor* src) +{ + FATAL_ERROR("TensorTemplate<T>::copyValueFrom should not be called. 
" + "Implement template specialization version."); + return 0; +} + +#define DEF_CTENSOR_COPY_VALUE_FROM(RANK, TYPE) \ + template <> \ + int TosaReference::Tensor##RANK<TYPE>::copyValueFrom(TosaReference::Tensor* src) \ + { \ + TosaReference::Tensor##RANK<TYPE>* t = dynamic_cast<Tensor##RANK<TYPE>*>(src); \ + if (!t) \ + { \ + WARNING("tensor %s templated class does not match %s", src->getName().c_str(), this->getName().c_str()); \ + return 1; \ + } \ + \ + uint32_t src_rank = src->getRank(); \ + uint32_t dst_rank = this->getRank(); \ + DType src_dtype = src->getDtype(); \ + DType dst_dtype = this->getDtype(); \ + bool tensor_match = true; \ + \ + if ((src_rank != dst_rank) || (src_dtype != dst_dtype)) \ + { \ + tensor_match = false; \ + } \ + else \ + { \ + for (uint32_t i = 0; i < src_rank; i++) \ + { \ + int src_dim = src->getShape()[i]; \ + int dst_dim = this->getShape()[i]; \ + if (src_dim != dst_dim) \ + { \ + tensor_match = false; \ + } \ + } \ + } \ + \ + if (!tensor_match) \ + { \ + WARNING("source tensor %s (rank=%u, dtype=%s, shape=%s) doesn't match destination tensor %s (rank=%u, " \ + "dtype=%s, shape=%s)", \ + src->getName().c_str(), src_rank, EnumNamesDType()[src_dtype], src->getShapeAsString().c_str(), \ + this->getName().c_str(), dst_rank, EnumNamesDType()[dst_dtype], this->getShapeAsString().c_str()); \ + return 1; \ + } \ + \ + this->getTensor() = t->getTensor(); \ + return 0; \ + } + +DEF_CTENSOR_COPY_VALUE_FROM(0, float) +DEF_CTENSOR_COPY_VALUE_FROM(1, float) +DEF_CTENSOR_COPY_VALUE_FROM(2, float) +DEF_CTENSOR_COPY_VALUE_FROM(3, float) +DEF_CTENSOR_COPY_VALUE_FROM(4, float) +DEF_CTENSOR_COPY_VALUE_FROM(5, float) +DEF_CTENSOR_COPY_VALUE_FROM(6, float) +DEF_CTENSOR_COPY_VALUE_FROM(0, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(1, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(2, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(3, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(4, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(5, int32_t) +DEF_CTENSOR_COPY_VALUE_FROM(6, int32_t) 
DEF_CTENSOR_COPY_VALUE_FROM(0, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(1, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(2, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(3, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(4, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(5, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(6, int64_t)
DEF_CTENSOR_COPY_VALUE_FROM(0, bool)
DEF_CTENSOR_COPY_VALUE_FROM(1, bool)
DEF_CTENSOR_COPY_VALUE_FROM(2, bool)
DEF_CTENSOR_COPY_VALUE_FROM(3, bool)
DEF_CTENSOR_COPY_VALUE_FROM(4, bool)
DEF_CTENSOR_COPY_VALUE_FROM(5, bool)
DEF_CTENSOR_COPY_VALUE_FROM(6, bool)

#undef DEF_CTENSOR_COPY_VALUE_FROM

// setTensorValueFloat: copy a flat row-major buffer into the Eigen tensor,
// one specialization per rank (0..6).  The generic template is a guard that
// aborts if an unspecialized instantiation is ever invoked.
template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals)
{
    FATAL_ERROR("TensorTemplate<T>::setTensorValueFloat should not be called. "
                "Implement template specialization version.");
    return 0;
}

template <>
int TosaReference::Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    (*tensor)(0) = vals[0];

    return 0;
}

template <>
int TosaReference::Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        (*tensor)(i0) = vals[idx++];
    }

    return 0;
}

template <>
int TosaReference::Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            (*tensor)(i0, i1) = vals[idx++];
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                (*tensor)(i0, i1, i2) = vals[idx++];
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    (*tensor)(i0, i1, i2, i3) = vals[idx++];
                }
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    for (int i4 = 0; i4 < shape[4]; i4++)
                    {
                        (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
                    }
                }
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    for (int i4 = 0; i4 < shape[4]; i4++)
                    {
                        for (int i5 = 0; i5 < shape[5]; i5++)
                        {
                            (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
                        }
                    }
                }
            }
        }
    }
    return 0;
}

// setTensorValueInt32: same pattern as the float family above, for int32
// (also the staging type for INT4/INT8/AINT8/UINT8/INT16 data).
template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    FATAL_ERROR("TensorTemplate<T>::setTensorValueInt32 should not be called. "
                "Implement template specialization version.");
    return 0;
}

template <>
int TosaReference::Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    (*tensor)(0) = vals[0];

    return 0;
}

template <>
int TosaReference::Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        (*tensor)(i0) = vals[idx++];
    }

    return 0;
}

template <>
int TosaReference::Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            (*tensor)(i0, i1) = vals[idx++];
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                (*tensor)(i0, i1, i2) = vals[idx++];
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    (*tensor)(i0, i1, i2, i3) = vals[idx++];
                }
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    for (int i4 = 0; i4 < shape[4]; i4++)
                    {
                        (*tensor)(i0, i1, i2, i3, i4) = vals[idx++];
                    }
                }
            }
        }
    }

    return 0;
}

template <>
int TosaReference::Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals)
{
    uint32_t idx = 0;

    ASSERT_MSG(bufLen == getElementCount(), "Total elements must match");

    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    for (int i4 = 0; i4 < shape[4]; i4++)
                    {
                        for (int i5 = 0; i5 < shape[5]; i5++)
                        {
                            (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++];
                        }
                    }
                }
            }
        }
    }
    return 0;
}

// setTensorValueInt64 guard template (definition continues past this chunk).
template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueInt64(const size_t bufLen, const int64_t* vals)
{
    FATAL_ERROR("TensorTemplate<T>::setTensorValueInt64 should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + (*tensor)(0) = vals[0]; + + return 0; +} + +template <> +int TosaReference::Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + (*tensor)(i0) = vals[idx++]; + } + + return 0; +} + +template <> +int TosaReference::Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + (*tensor)(i0, i1) = vals[idx++]; + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + (*tensor)(i0, i1, i2) = vals[idx++]; + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + (*tensor)(i0, i1, i2, i3) = vals[idx++]; + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + 
ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + (*tensor)(i0, i1, i2, i3, i4) = vals[idx++]; + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++]; + } + } + } + } + } + } + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::setTensorValueBool(const size_t buflen, const bool* vals) +{ + FATAL_ERROR("TensorTemplate<T>::setTensorValueBool should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + (*tensor)(0) = vals[0]; + + return 0; +} + +template <> +int TosaReference::Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + (*tensor)(i0) = vals[idx++]; + } + + return 0; +} + +template <> +int TosaReference::Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + (*tensor)(i0, i1) = vals[idx++]; + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + (*tensor)(i0, i1, i2) = vals[idx++]; + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + (*tensor)(i0, i1, i2, i3) = vals[idx++]; + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total 
elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + (*tensor)(i0, i1, i2, i3, i4) = vals[idx++]; + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals) +{ + uint32_t idx = 0; + + ASSERT_MSG(bufLen == getElementCount(), "Total elements must match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + (*tensor)(i0, i1, i2, i3, i4, i5) = vals[idx++]; + } + } + } + } + } + } + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + FATAL_ERROR("TensorTemplate<T>::getTensorValueFloat should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + 
ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + FATAL_ERROR("TensorTemplate<T>::getTensorValueInt32 should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals 
*= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + FATAL_ERROR("TensorTemplate<T>::getTensorValueInt64 should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals 
*= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + FATAL_ERROR("TensorTemplate<T>::getTensorValueBool should not be called. 
" + "Implement template specialization version."); + return 0; +} + +template <> +int TosaReference::Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + int totalVals = 1; + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + vals[0] = (*tensor)(0); + + return 0; +} + +template <> +int TosaReference::Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + vals[idx++] = (*tensor)(i0); + } + + return 0; +} + +template <> +int TosaReference::Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + vals[idx++] = (*tensor)(i0, i1); + } + } + + return 0; +} + +template <> +int TosaReference::Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + vals[idx++] = (*tensor)(i0, i1, i2); + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + 
ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3); + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4); + } + } + } + } + } + + return 0; +} + +template <> +int TosaReference::Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const +{ + uint32_t idx = 0; + int totalVals = 1; + + for (size_t i = 0; i < shape.size(); i++) + { + totalVals *= shape[i]; + } + + ASSERT_MSG((size_t)totalVals == bufLen, "Output buffer and tensor size do not match"); + + for (int i0 = 0; i0 < shape[0]; i0++) + { + for (int i1 = 0; i1 < shape[1]; i1++) + { + for (int i2 = 0; i2 < shape[2]; i2++) + { + for (int i3 = 0; i3 < shape[3]; i3++) + { + for (int i4 = 0; i4 < shape[4]; i4++) + { + for (int i5 = 0; i5 < shape[5]; i5++) + { + vals[idx++] = (*tensor)(i0, i1, i2, i3, i4, i5); + } + } + } + } + } + } + return 0; +} + +template <> +int TosaReference::Tensor0<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor0<float>(); + + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor1<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen 
tensor"); + tensor = new ETensor1<float>(shape[0]); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor2<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor2<float>(shape[0], shape[1]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor3<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor3<float>(shape[0], shape[1], shape[2]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor4<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor4<float>(shape[0], shape[1], shape[2], shape[3]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor5<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor5<float>(shape[0], shape[1], shape[2], shape[3], shape[4]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor6<float>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor6<float>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor0<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor0<int32_t>(); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor1<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor1<int32_t>(shape[0]); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor2<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new 
ETensor2<int32_t>(shape[0], shape[1]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor3<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor3<int32_t>(shape[0], shape[1], shape[2]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor4<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor4<int32_t>(shape[0], shape[1], shape[2], shape[3]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor5<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor5<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor6<int32_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor6<int32_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor0<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor0<int64_t>(); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor1<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor1<int64_t>(shape[0]); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor2<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor2<int64_t>(shape[0], shape[1]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor3<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new 
ETensor3<int64_t>(shape[0], shape[1], shape[2]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor4<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor4<int64_t>(shape[0], shape[1], shape[2], shape[3]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor5<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor5<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor6<int64_t>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor6<int64_t>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor0<bool>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor0<bool>(); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor1<bool>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor1<bool>(shape[0]); + if (tensor) + return 0; + else + return 1; +} +template <> +int TosaReference::Tensor2<bool>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor2<bool>(shape[0], shape[1]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor3<bool>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new ETensor3<bool>(shape[0], shape[1], shape[2]); + if (tensor) + return 0; + else + return 1; +} + +template <> +int TosaReference::Tensor4<bool>::allocate() +{ + ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor"); + tensor = new 
ETensor4<bool>(shape[0], shape[1], shape[2], shape[3]);
    if (tensor)
        return 0;
    else
        return 1;
}

// Allocate the rank-5 bool Eigen tensor backing store sized by `shape`.
// Returns 0 on success, 1 on failure.
template <>
int TosaReference::Tensor5<bool>::allocate()
{
    ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
    tensor = new ETensor5<bool>(shape[0], shape[1], shape[2], shape[3], shape[4]);
    // NOTE(review): plain `new` throws std::bad_alloc rather than returning
    // nullptr, so the failure branch below is effectively unreachable; kept
    // to match the convention used by every allocate() in this file.
    if (tensor)
        return 0;
    else
        return 1;
}

// Allocate the rank-6 bool Eigen tensor backing store sized by `shape`.
// Returns 0 on success, 1 on failure.
template <>
int TosaReference::Tensor6<bool>::allocate()
{
    ASSERT_MSG(tensor == nullptr, "Error: double allocate Eigen tensor");
    tensor = new ETensor6<bool>(shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]);
    if (tensor)
        return 0;
    else
        return 1;
}

// Dump a rank-0 (scalar) float tensor to `out` as "[ <value> ]".
// The float conversion is built at runtime from g_func_config.fp_format
// (e.g. "0.5" would yield "%0.5f") so output precision is configurable.
template <>
int TosaReference::Tensor0<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    // "%%%sf" -> literal '%' + fp_format string + 'f', i.e. a printf float spec.
    snprintf(fp_fmt, FOF_STR_LEN, "[ %%%sf ]\n", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, fp_fmt, (*tensor)(0));

    return 0;
}

// Dump a rank-1 float tensor to `out` as a single bracketed row.
template <>
int TosaReference::Tensor1<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, "[");
    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        fprintf(out, fp_fmt, (*tensor)(i0));
    }
    fprintf(out, "]\n");

    return 0;
}

// Dump a rank-2 float tensor to `out` as nested bracketed rows,
// one inner "[...]" row per index of the outer dimension.
template <>
int TosaReference::Tensor2<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, "[");
    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        fprintf(out, "[");
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            fprintf(out, fp_fmt, (*tensor)(i0, i1));
        }
        fprintf(out, "]\n");
    }
    fprintf(out, "]\n");

    return 0;
}

// Dump a rank-3 float tensor: same nested-bracket scheme, one loop per rank.
template <>
int TosaReference::Tensor3<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, "[");
    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        fprintf(out, "[");
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            fprintf(out, "[");
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                fprintf(out, fp_fmt, (*tensor)(i0, i1, i2));
            }
            fprintf(out, "]\n");
        }
        fprintf(out, "]\n");
    }
    fprintf(out, "]\n");

    return 0;
}

// Dump a rank-4 float tensor: same nested-bracket scheme, one loop per rank.
template <>
int TosaReference::Tensor4<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, "[");
    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        fprintf(out, "[");
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            fprintf(out, "[");
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                fprintf(out, "[");
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3));
                }
                fprintf(out, "]\n");
            }
            fprintf(out, "]\n");
        }
        fprintf(out, "]\n");
    }
    fprintf(out, "]\n");

    return 0;
}

// Dump a rank-5 float tensor: same nested-bracket scheme, one loop per rank.
template <>
int TosaReference::Tensor5<float>::dumpTensor(FILE* out) const
{
    char fp_fmt[FOF_STR_LEN];
    snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format);

    if (tensor == nullptr)
    {
        fprintf(out, "<Not allocated>\n");
        return 0;
    }

    fprintf(out, "[");
    for (int i0 = 0; i0 < shape[0]; i0++)
    {
        fprintf(out, "[");
        for (int i1 = 0; i1 < shape[1]; i1++)
        {
            fprintf(out, "[");
            for (int i2 = 0; i2 < shape[2]; i2++)
            {
                fprintf(out, "[");
                for (int i3 = 0; i3 < shape[3]; i3++)
                {
                    fprintf(out, "[");
                    for (int i4 = 0; i4 < shape[4]; i4++)
                    {
                        fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4));
                    }
                    fprintf(out, "]\n");
                }
                fprintf(out, "]\n");
            }
            fprintf(out, "]\n");
        }
        fprintf(out, "]\n");
    }
    fprintf(out, "]\n");

    return 0;
}

+template <> +int TosaReference::Tensor6<float>::dumpTensor(FILE* out) const +{ + char fp_fmt[FOF_STR_LEN]; + snprintf(fp_fmt, FOF_STR_LEN, " %%%sf ", g_func_config.fp_format); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, "["); + for (int i5 = 0; i5 < shape[5]; i5++) + { + fprintf(out, fp_fmt, (*tensor)(i0, i1, i2, i3, i4, i5)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor0<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, "[ %%ld ]\n"); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, i64_fmt, (*tensor)(0)); + + return 0; +} + +template <> +int TosaReference::Tensor1<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, i64_fmt, (*tensor)(i0)); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor2<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, i64_fmt, (*tensor)(i0, 
i1)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor3<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, i64_fmt, (*tensor)(i0, i1, i2)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor4<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor5<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4)); + } + fprintf(out, "]\n"); + } + 
fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor6<int64_t>::dumpTensor(FILE* out) const +{ + char i64_fmt[FOF_STR_LEN]; + snprintf(i64_fmt, FOF_STR_LEN, " %%ld "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, "["); + for (int i5 = 0; i5 < shape[5]; i5++) + { + fprintf(out, i64_fmt, (*tensor)(i0, i1, i2, i3, i4, i5)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor0<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, "[ %%d ]\n"); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, i32_fmt, (*tensor)(0)); + + return 0; +} + +template <> +int TosaReference::Tensor1<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, i32_fmt, (*tensor)(i0)); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor2<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + 
return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, i32_fmt, (*tensor)(i0, i1)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor3<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, i32_fmt, (*tensor)(i0, i1, i2)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor4<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor5<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); 
+ for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor6<int32_t>::dumpTensor(FILE* out) const +{ + char i32_fmt[FOF_STR_LEN]; + snprintf(i32_fmt, FOF_STR_LEN, " %%d "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, "["); + for (int i5 = 0; i5 < shape[5]; i5++) + { + fprintf(out, i32_fmt, (*tensor)(i0, i1, i2, i3, i4, i5)); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor0<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, "[ %%s ]\n"); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, bool_fmt, bool_to_str((*tensor)(0))); + + return 0; +} + +template <> +int TosaReference::Tensor1<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0))); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int 
TosaReference::Tensor2<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1))); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor3<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2))); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor4<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3))); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor5<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 
0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4))); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <> +int TosaReference::Tensor6<bool>::dumpTensor(FILE* out) const +{ + char bool_fmt[FOF_STR_LEN]; + snprintf(bool_fmt, FOF_STR_LEN, " %%s "); + + if (tensor == nullptr) + { + fprintf(out, "<Not allocated>\n"); + return 0; + } + + fprintf(out, "["); + for (int i0 = 0; i0 < shape[0]; i0++) + { + fprintf(out, "["); + for (int i1 = 0; i1 < shape[1]; i1++) + { + fprintf(out, "["); + for (int i2 = 0; i2 < shape[2]; i2++) + { + fprintf(out, "["); + for (int i3 = 0; i3 < shape[3]; i3++) + { + fprintf(out, "["); + for (int i4 = 0; i4 < shape[4]; i4++) + { + fprintf(out, "["); + for (int i5 = 0; i5 < shape[5]; i5++) + { + fprintf(out, bool_fmt, bool_to_str((*tensor)(i0, i1, i2, i3, i4, i5))); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + } + fprintf(out, "]\n"); + + return 0; +} + +template <class T> +int TosaReference::TensorTemplate<T>::dumpTensor(FILE* out) const +{ + return 0; +} + +// template explicit specialization +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 0>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 1>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 2>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 3>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 4>>; +template class 
TosaReference::TensorTemplate<Eigen::Tensor<float, 5>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<float, 6>>; + +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 0>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 1>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 2>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 3>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 4>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 5>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int32_t, 6>>; + +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 0>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 1>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 2>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 3>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 4>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 5>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<int64_t, 6>>; + +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 0>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 1>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 2>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 3>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 4>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 5>>; +template class TosaReference::TensorTemplate<Eigen::Tensor<bool, 6>>; diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h new file mode 100644 index 0000000..2fd37cd --- /dev/null +++ b/reference_model/src/tensor.h @@ -0,0 +1,815 @@ + +// Copyright (c) 2020, ARM Limited. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef TOSA_REFERENCE_TENSOR_H +#define TOSA_REFERENCE_TENSOR_H + +#include "model_common.h" +#include "ops/template_types.h" +#include "tosa_generated.h" +#include "tosa_serialization_handler.h" +#include <Eigen/CXX11/Tensor> +#include <list> +#include <vector> + +using namespace tosa; + +namespace TosaReference +{ +class GraphNode; + +class Tensor +{ +public: + Tensor(std::string tensorName_, + DType tensorDtype__, + const std::vector<Usage>& tensorUsage_, + const std::vector<Format>& tensorFormat_, + std::vector<int> shape_, + int isConst_); + + virtual ~Tensor(); + + int setIsSubgraphInput(); + int setIsSubgraphOutput(); + + int getIsSubgraphInput() const + { + return isSubgraphInput; + } + + int getIsSubgraphOutput() const + { + return isSubgraphOutput; + } + + int setProducer(GraphNode* node); + int addConsumer(GraphNode* node); + + int setIsValid() + { + isValid = 1; + return 0; + } + + int clearIsValid() + { + isValid = 0; + return 0; + } + + int getIsValid() const + { + return isValid; + } + + int getIsConst() const + { + return isConst; + } + + GraphNode* getProducer() + { + return producer; + } + + std::vector<GraphNode*>& getConsumers() + { + return consumers; + } + + const std::string& getName() const + { + return tensorName; + } + + const std::vector<int>& getShape() const + { + return shape; + } + + std::string getShapeAsString() const + { + std::string shape_str("["); + for (auto& dim : 
shape) + { + shape_str += (std::to_string(dim) + ", "); + } + shape_str.append("]"); + return shape_str; + } + + const std::vector<Usage>& getUsage() const + { + return tensorUsage; + } + + bool hasUsage(Usage usage) const + { + for (auto& usg : tensorUsage) + { + if (usg == usage) + { + return true; + } + } + return false; + } + + std::string getUsageAsString() const + { + std::string usage_str("["); + for (auto& usg : tensorUsage) + { + usage_str += (std::string(EnumNamesUsage()[usg]) + ", "); + } + usage_str.append("]"); + return usage_str; + } + + const std::vector<Format>& getFormat() const + { + return tensorFormat; + } + + bool hasFormat(Format format) const + { + for (auto& fmt : tensorFormat) + { + if (fmt == format) + { + return true; + } + } + return false; + } + + std::string getFormatAsString() const + { + std::string format_str("["); + for (auto& fmt : tensorFormat) + { + format_str += (std::string(EnumNamesFormat()[fmt]) + ", "); + } + format_str.append("]"); + return format_str; + } + + const uint32_t getElementCount() const + { + uint32_t elements = 1; + for (size_t i = 0; i < shape.size(); i++) + elements *= shape[i]; + + return elements; + } + + // Comparison of rank and type with other tensors + const int matchRank(const Tensor& ref) const + { + return (ref.shape.size() == shape.size()) ? 0 : 1; + } + + const int matchType(const Tensor& ref) const + { + return (ref.tensorDtype == tensorDtype) ? 
0 : 1; + } + + const int matchRankType(const Tensor& ref) const + { + return (matchType(ref) || matchRank(ref)); + } + + const int matchRankTypeShape(const Tensor& ref, const bool broadcastOk = false) const + { + if (matchRankType(ref)) + return 1; + + for (size_t i = 0; i < shape.size(); i++) + { + if (shape[i] != ref.shape[i]) + { + if (!broadcastOk || + // For broadcasts, at least one operand must have size 1 + // if they don't both match + (broadcastOk && (shape[i] != 1 && ref.shape[i] != 1))) + { + return 1; + } + } + } + + return 0; + } + + // Sometimes we might want to match several semi-compatible types, + // so just check rank and size here + const int matchRankSize(const Tensor& ref) const + { + if (matchRank(ref)) + return 1; + + for (size_t i = 0; i < shape.size(); i++) + { + if (shape[i] != ref.shape[i]) + return 1; + } + + return 0; + } + + // Unary check to make sure rank matches + const int checkRequiredRank(const int exactRank) const + { + return (shape.size() == (size_t)exactRank) ? 0 : 1; + } + + const int checkRequiredRank(const int minRank, const int maxRank) const + { + return (shape.size() >= (size_t)minRank && shape.size() <= (size_t)maxRank) ? 
0 : 1; + } + + const int getRank() const + { + return shape.size(); + } + + const DType getDtype() const + { + return tensorDtype; + } + + virtual int dumpTensor(FILE* out) const = 0; + virtual int dumpTensorParams(FILE* out) const; + virtual int dumpTensorParams(std::ostream& out) const; + + virtual int setTensorValueFloat(const size_t bufLen, const float* vals) = 0; + virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals) = 0; + virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals) = 0; + virtual int setTensorValueBool(const size_t bufLen, const bool* vals) = 0; + virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const = 0; + virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const = 0; + virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const = 0; + virtual int getTensorValueBool(const size_t bufLen, bool* ibuf) const = 0; + + virtual int readFromNpyFile(const char* filename); + virtual int writeToNpyFile(const char* filename) const; + virtual int copyValueFrom(Tensor* tensor) = 0; + + const char* bool_to_str(bool in) const + { + static const char* true_str = "true"; + static const char* false_str = "false"; + return in ? true_str : false_str; + } + + virtual int allocate() = 0; + virtual int deallocate() = 0; + virtual bool is_allocated() = 0; + +protected: + std::string tensorName; + DType tensorDtype; + std::vector<Usage> tensorUsage; + std::vector<Format> tensorFormat; + int isConst; + int isValid; + std::vector<int> shape; + int isSubgraphInput; + int isSubgraphOutput; + bool isAllocated; + + GraphNode* producer; + std::vector<GraphNode*> consumers; + + // Note: the Eigen::Tensor is not declared in Tensor + // Instead, the TensorTemplate class keeps the templated tensor + // declaration so that the graph manipulation tools are isolated + // from the templated tensor type. 
+ // + // Operators need to be aware of the TensorTemplate<EigenTensor<type, rank>> type + // so that they can operate on the right types. +}; + +template <class T> +class TensorTemplate : public Tensor +{ +public: + TensorTemplate(std::string tensorName_, + DType tensorDtype_, + const std::vector<Usage>& tensorUsage_, + const std::vector<Format>& tensorFormat_, + std::vector<int> shape_, + int isConst_) + : Tensor(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, isConst_) + { + tensor = nullptr; + } + + virtual ~TensorTemplate() + { + deallocate(); + } + + virtual int allocate() + { + tensor = new T(); + if (tensor) + return 0; + else + return 1; + } + + virtual int deallocate() + { + if (tensor) + { + delete tensor; + } + tensor = nullptr; + return 0; + } + + virtual bool is_allocated() + { + if (tensor) + { + return true; + } + return false; + } + + T& getTensor() + { + return *tensor; + } + + virtual int dumpTensor(FILE* out) const; + + virtual int setTensorValueFloat(const size_t bufLen, const float* vals); + virtual int setTensorValueInt32(const size_t bufLen, const int32_t* vals); + virtual int setTensorValueInt64(const size_t bufLen, const int64_t* vals); + virtual int setTensorValueBool(const size_t bufLen, const bool* vals); + virtual int getTensorValueFloat(const size_t bufLen, float* fbuf) const; + virtual int getTensorValueInt32(const size_t bufLen, int32_t* ibuf) const; + virtual int getTensorValueInt64(const size_t bufLen, int64_t* ibuf) const; + virtual int getTensorValueBool(const size_t bufLen, bool* bbuf) const; + + virtual int copyValueFrom(Tensor* tensor); + +protected: + T* tensor; +}; + +// allocate() template specializations to allocate the different tensor sizes +// Let the compiler know here before the factory uses them, but define them in the .cc file. 
+template <> +int Tensor0<float>::allocate(); +template <> +int Tensor1<float>::allocate(); +template <> +int Tensor2<float>::allocate(); +template <> +int Tensor3<float>::allocate(); +template <> +int Tensor4<float>::allocate(); +template <> +int Tensor5<float>::allocate(); +template <> +int Tensor6<float>::allocate(); + +template <> +int Tensor0<int32_t>::allocate(); +template <> +int Tensor1<int32_t>::allocate(); +template <> +int Tensor2<int32_t>::allocate(); +template <> +int Tensor3<int32_t>::allocate(); +template <> +int Tensor4<int32_t>::allocate(); +template <> +int Tensor5<int32_t>::allocate(); +template <> +int Tensor6<int32_t>::allocate(); + +template <> +int Tensor0<int64_t>::allocate(); +template <> +int Tensor1<int64_t>::allocate(); +template <> +int Tensor2<int64_t>::allocate(); +template <> +int Tensor3<int64_t>::allocate(); +template <> +int Tensor4<int64_t>::allocate(); +template <> +int Tensor5<int64_t>::allocate(); +template <> +int Tensor6<int64_t>::allocate(); + +template <> +int Tensor0<bool>::allocate(); +template <> +int Tensor1<bool>::allocate(); +template <> +int Tensor2<bool>::allocate(); +template <> +int Tensor3<bool>::allocate(); +template <> +int Tensor4<bool>::allocate(); +template <> +int Tensor5<bool>::allocate(); +template <> +int Tensor6<bool>::allocate(); + +template <> +int Tensor0<float>::copyValueFrom(Tensor* src); +template <> +int Tensor1<float>::copyValueFrom(Tensor* src); +template <> +int Tensor2<float>::copyValueFrom(Tensor* src); +template <> +int Tensor3<float>::copyValueFrom(Tensor* src); +template <> +int Tensor4<float>::copyValueFrom(Tensor* src); +template <> +int Tensor5<float>::copyValueFrom(Tensor* src); +template <> +int Tensor6<float>::copyValueFrom(Tensor* src); + +template <> +int Tensor0<int32_t>::copyValueFrom(Tensor* src); +template <> +int Tensor1<int32_t>::copyValueFrom(Tensor* src); +template <> +int Tensor2<int32_t>::copyValueFrom(Tensor* src); +template <> +int 
Tensor3<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int32_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int32_t>::copyValueFrom(Tensor* src);

// copyValueFrom(): copy the contents of 'src' into this tensor.  Explicit
// specializations are declared per element type for every supported rank
// (Tensor0 .. Tensor6); the definitions live out of line.
template <>
int Tensor0<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor1<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor2<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor3<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor4<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor5<int64_t>::copyValueFrom(Tensor* src);
template <>
int Tensor6<int64_t>::copyValueFrom(Tensor* src);

template <>
int Tensor0<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor1<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor2<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor3<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor4<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor5<bool>::copyValueFrom(Tensor* src);
template <>
int Tensor6<bool>::copyValueFrom(Tensor* src);

// setTensorValue*() / getTensorValue*(): exchange data between the tensor's
// internal storage and a caller-supplied flat buffer of 'bufLen' elements,
// one specialization per rank of the matching element type.
// NOTE(review): bufLen is presumably required to match the tensor's element
// count — confirm against the definitions.
template <>
int Tensor0<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor1<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor2<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor3<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor4<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor5<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);
template <>
int Tensor6<int32_t>::setTensorValueInt32(const size_t bufLen, const int32_t* vals);

template <>
int Tensor0<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor1<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor2<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor3<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor4<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor5<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;
template <>
int Tensor6<int32_t>::getTensorValueInt32(const size_t bufLen, int32_t* vals) const;

template <>
int Tensor0<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor1<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor2<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor3<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor4<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor5<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);
template <>
int Tensor6<int64_t>::setTensorValueInt64(const size_t bufLen, const int64_t* vals);

template <>
int Tensor0<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor1<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor2<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor3<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor4<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor5<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;
template <>
int Tensor6<int64_t>::getTensorValueInt64(const size_t bufLen, int64_t* vals) const;

template <>
int Tensor0<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor1<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor2<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor3<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor4<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor5<float>::setTensorValueFloat(const size_t bufLen, const float* vals);
template <>
int Tensor6<float>::setTensorValueFloat(const size_t bufLen, const float* vals);

template <>
int Tensor0<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor1<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor2<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor3<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor4<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor5<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;
template <>
int Tensor6<float>::getTensorValueFloat(const size_t bufLen, float* vals) const;

template <>
int Tensor0<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor1<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor2<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor3<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor4<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor5<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);
template <>
int Tensor6<bool>::setTensorValueBool(const size_t bufLen, const bool* vals);

template <>
int Tensor0<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor1<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor2<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor3<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor4<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor5<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;
template <>
int Tensor6<bool>::getTensorValueBool(const size_t bufLen, bool* vals) const;

// dumpTensor(): write the tensor contents to 'out'.  Despite the original
// "assume we only dump float type tensor now" note, specializations are
// declared here for float, int32_t, int64_t and bool at every rank.
template <>
int Tensor0<float>::dumpTensor(FILE* out) const;
template <>
int Tensor1<float>::dumpTensor(FILE* out) const;
template <>
int Tensor2<float>::dumpTensor(FILE* out) const;
template <>
int Tensor3<float>::dumpTensor(FILE* out) const;
template <>
int Tensor4<float>::dumpTensor(FILE* out) const;
template <>
int Tensor5<float>::dumpTensor(FILE* out) const;
template <>
int Tensor6<float>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int32_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor1<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor2<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor3<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor4<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor5<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor6<int64_t>::dumpTensor(FILE* out) const;
template <>
int Tensor0<bool>::dumpTensor(FILE* out) const;
+template <> +int Tensor1<bool>::dumpTensor(FILE* out) const; +template <> +int Tensor2<bool>::dumpTensor(FILE* out) const; +template <> +int Tensor3<bool>::dumpTensor(FILE* out) const; +template <> +int Tensor4<bool>::dumpTensor(FILE* out) const; +template <> +int Tensor5<bool>::dumpTensor(FILE* out) const; +template <> +int Tensor6<bool>::dumpTensor(FILE* out) const; + +class TensorFactory +{ +public: + static Tensor* newTensor(std::string tensorName_, + DType tensorDtype_, + const std::vector<Usage>& tensorUsage_, + const std::vector<Format>& tensorFormat_, + std::vector<int> shape_, + int isConst_, + const uint32_t rank) + { + switch (tensorDtype_) + { + case DType_FLOAT: + switch (rank) + { + case 0: + return new Tensor0<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 1: + return new Tensor1<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 2: + return new Tensor2<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 3: + return new Tensor3<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 4: + return new Tensor4<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 5: + return new Tensor5<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 6: + return new Tensor6<float>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + default: + goto done; + } + case DType_INT32: + case DType_AINT8: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + switch (rank) + { + case 0: + return new Tensor0<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 1: + return new Tensor1<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 2: + return new Tensor2<int32_t>(tensorName_, tensorDtype_, tensorUsage_, 
tensorFormat_, shape_, + isConst_); + case 3: + return new Tensor3<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 4: + return new Tensor4<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 5: + return new Tensor5<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 6: + return new Tensor6<int32_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + default: + goto done; + } + case DType_INT48: + switch (rank) + { + case 0: + return new Tensor0<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 1: + return new Tensor1<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 2: + return new Tensor2<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 3: + return new Tensor3<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 4: + return new Tensor4<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 5: + return new Tensor5<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 6: + return new Tensor6<int64_t>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + default: + goto done; + } + case DType_BOOL: + switch (rank) + { + case 0: + return new Tensor0<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 1: + return new Tensor1<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 2: + return new Tensor2<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 3: + return new Tensor3<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 4: + return new Tensor4<bool>(tensorName_, tensorDtype_, tensorUsage_, 
tensorFormat_, shape_, + isConst_); + case 5: + return new Tensor5<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + case 6: + return new Tensor6<bool>(tensorName_, tensorDtype_, tensorUsage_, tensorFormat_, shape_, + isConst_); + default: + goto done; + } + default: + goto done; + } + + done: + FATAL_ERROR("Unsupported tensor name=%s, type=%s, rank=%d", tensorName_.c_str(), EnumNamesDType()[tensorDtype_], + rank); + } + + static Tensor* newTensor(DType type, const std::vector<int> shape); +}; +}; // namespace TosaReference + +#endif |