diff options
author | Matthew Sloyan <matthew.sloyan@arm.com> | 2022-09-26 13:31:43 +0100 |
---|---|---|
committer | Matthew Sloyan <matthew.sloyan@arm.com> | 2022-10-07 16:02:07 +0100 |
commit | ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (patch) | |
tree | 10e3e127c091da90d591253fd55e8566c0b61e7a /reference_model/src | |
parent | a0848c6edbf37034e280a670bdd2f990fdf796da (diff) | |
download | reference_model-ba5fad356a926d5e1c6e0fe6b546a310230cc5a8.tar.gz |
Add IModelRunner interface to TOSA Reference Model
* Added IModelRunner interface using pimpl idiom, which allows a user to
initialize, configure and run the model.
* Added unit tests for IModelRunner.
* Added doctest as third-party submodule.
* Added user options to specify paths for dependencies.
* Moved general func_config functions to separate utility, which removes
cxxopts dependency.
Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com>
Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2
Diffstat (limited to 'reference_model/src')
-rw-r--r-- | reference_model/src/command_line_utils.h (renamed from reference_model/src/func_config.cc) | 19 | ||||
-rw-r--r-- | reference_model/src/debug_modes.def | 20 | ||||
-rw-r--r-- | reference_model/src/debug_types.h | 57 | ||||
-rw-r--r-- | reference_model/src/func_config.h | 46 | ||||
-rw-r--r-- | reference_model/src/func_debug.h | 244 | ||||
-rw-r--r-- | reference_model/src/general_utils.h | 68 | ||||
-rw-r--r-- | reference_model/src/main.cpp | 25 | ||||
-rw-r--r-- | reference_model/src/model_common.h | 28 | ||||
-rw-r--r-- | reference_model/src/model_runner.cc | 76 | ||||
-rw-r--r-- | reference_model/src/model_runner_impl.cc | 277 | ||||
-rw-r--r-- | reference_model/src/model_runner_impl.h | 66 | ||||
-rw-r--r-- | reference_model/src/subgraph_traverser.h | 8 | ||||
-rw-r--r-- | reference_model/src/tensor.cc | 203 | ||||
-rw-r--r-- | reference_model/src/tensor.h | 10 |
14 files changed, 717 insertions, 430 deletions
diff --git a/reference_model/src/func_config.cc b/reference_model/src/command_line_utils.h index 6bd809e..1bd1639 100644 --- a/reference_model/src/func_config.cc +++ b/reference_model/src/command_line_utils.h @@ -13,19 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <cxxopts.hpp> -#include <ctype.h> -#include <signal.h> -#include <stdarg.h> -#include <stdint.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/types.h> +#ifndef COMMAND_LINE_UTILS_H_ +#define COMMAND_LINE_UTILS_H_ #include "func_config.h" #include "func_debug.h" +#include <stdint.h> +#include <cxxopts.hpp> + // Read the command line arguments int func_model_parse_cmd_line( func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version) @@ -102,7 +98,4 @@ int func_model_parse_cmd_line( return 0; } -void func_model_print_help() -{ - -} +#endif diff --git a/reference_model/src/debug_modes.def b/reference_model/src/debug_modes.def deleted file mode 100644 index 51b151d..0000000 --- a/reference_model/src/debug_modes.def +++ /dev/null @@ -1,20 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Defines the debugging printing modes - -DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization -DEBUG_MODE(GT,1) // Graph traverser -DEBUG_MODE(OP,2) // Operation diff --git a/reference_model/src/debug_types.h b/reference_model/src/debug_types.h deleted file mode 100644 index bd93f19..0000000 --- a/reference_model/src/debug_types.h +++ /dev/null @@ -1,57 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - * Filename: src/debug_types.h - * Description: - * Defines fundamental debugger datatypes for the functional model - */ - -#ifndef DEBUG_TYPES_H_ -#define DEBUG_TYPES_H_ - -#ifdef __cplusplus -extern "C" -{ -#endif - - // Debug verbosity mask - typedef enum func_debug_verbosity_e - { - DEBUG_VERB_NONE = 0x00, - DEBUG_VERB_INFO = 0x01, // Informational debugging messages - DEBUG_VERB_IFACE = 0x02, // Interface debugging support - DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout - DEBUG_VERB_MED = 0x08, - DEBUG_VERB_HIGH = 0x10 - } func_debug_verbosity_e; - - // Generated debug modes enumeration - typedef enum func_debug_mode_e - { - DEBUG_NONE = 0x0, -#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT), -#include "debug_modes.def" -#undef DEBUG_MODE - DEBUG_ALL = 0xffffffffffffffffUL - } func_debug_mode_e; - -#define DEBUG_INST_ALL 0xffffffffffffffffUL - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/reference_model/src/func_config.h b/reference_model/src/func_config.h deleted file mode 100644 index 49f03e9..0000000 --- a/reference_model/src/func_config.h +++ /dev/null @@ -1,46 +0,0 @@ - -// Copyright (c) 2020-2022, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef FUNC_CONFIG_H_ -#define FUNC_CONFIG_H_ - -struct func_config_t -{ - std::string operator_fbs = "tosa.fbs"; - std::string test_desc = "desc.json"; - std::string flatbuffer_dir = ""; - std::string output_dir = ""; - std::string tosa_file = ""; - std::string ifm_name = ""; - std::string ifm_file = ""; - std::string ofm_name = ""; - std::string ofm_file = ""; - uint32_t eval = 1; - uint32_t validate_only = 0; - uint32_t output_tensors = 1; - uint32_t tosa_profile = 1; - uint32_t dump_intermediates = 0; - std::string fp_format = "0.5"; -}; - -// Forward declaration -struct func_debug_t; - -int func_model_parse_cmd_line( - func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version); -int func_model_parse_flat_config_file(func_config_t*, const char* filename); -void func_model_print_help(); - -#endif diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h deleted file mode 100644 index ee89935..0000000 --- a/reference_model/src/func_debug.h +++ /dev/null @@ -1,244 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef FUNC_DEBUG_H -#define FUNC_DEBUG_H - -#include "debug_types.h" -#include <assert.h> -#include <cinttypes> -#include <signal.h> -#include <stdio.h> -#include <vector> - -void func_print_backtrace(FILE* out, int sig = SIGABRT); - -void func_enable_signal_handlers(); - -// STRINGIFY2 is needed expand expression passed to STRINGIFY -#define STRINGIFY2(s) #s -#define STRINGIFY(s) STRINGIFY2(s) - -// If TRACED_LOG is defined, add file:line to log messages -#if defined(TRACED_LOG) -#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__) -#else -#define WHERE -#endif - -#if defined(COLORIZED_LOG) -#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m" -#define COL_FATAL(fmt) COL("1;41", fmt) -#define COL_WARN(fmt) COL("1;43", fmt) -#define COL_INFO(fmt) COL("2", fmt) -#define COL_IFACE(fmt) fmt -#define COL_LOW(fmt) COL("35", fmt) -#define COL_MED(fmt) COL("2;33", fmt) -#define COL_HIGH(fmt) COL("2;32", fmt) -#else -#define COL_FATAL(fmt) fmt -#define COL_WARN(fmt) fmt -#define COL_INFO(fmt) fmt -#define COL_IFACE(fmt) fmt -#define COL_LOW(fmt) fmt -#define COL_MED(fmt) fmt -#define COL_HIGH(fmt) fmt -#endif - -struct func_debug_t -{ - uint32_t func_debug_verbosity = 0; // What verbosity level is set? (bitmask) - uint64_t func_debug_mask = 0; // Which units have debugging enabled? (bitmask) - uint64_t func_debug_inst_mask = 0; // Which instances have debugging enabled (bitmask) - uint64_t inst_id = 0; // The instance id for multiple model instances - FILE* func_debug_file = stderr; // Output file - bool is_output_unbuffered = false; // should log files be opened with unbuffered I/O. - - int init_debug(uint64_t inst_id); - int fini_debug(); - int set_file(const std::string& filename); - void set_mask(const std::string& str); - void set_mask(const uint64_t mask); - void print_masks(FILE* out); - void set_verbosity(const std::string& str); - void set_verbosity(const uint32_t verb); - void set_inst_mask(const char* mask); - void set_inst_mask(const uint64_t mask); - void set_output_unbuffered(const bool is_unbuffered); - std::string get_debug_mask_help_string(); - std::string get_debug_verbosity_help_string(); -}; - -#ifndef ASSERT -#define ASSERT(COND) \ - if (!(COND)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ - func_print_backtrace(stderr); \ - assert(COND); \ - } -#endif - -#ifndef ASSERT_MSG -#define ASSERT_MSG(COND, fmt, ...) \ - if (!(COND)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ - fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - func_print_backtrace(stderr); \ - assert(COND); \ - } -#endif - -#ifndef REQUIRE -#define REQUIRE(COND, fmt, ...) \ - if (!(COND)) \ - { \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("REQUIRE() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ - __func__, #COND); \ - fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - this->parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); \ - } -#endif - -#ifndef ERROR_IF -#define ERROR_IF(COND, fmt, ...) \ - if ((COND)) \ - { \ - if (this->parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \ - { \ - this->parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR); \ - } \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ - __func__, #COND); \ - fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - this->dumpNode(g_func_debug.func_debug_file); \ - func_print_backtrace(g_func_debug.func_debug_file); \ - return 1; \ - } -#endif - -// Assertion specific to allocating memory -#ifndef ASSERT_MEM -#define ASSERT_MEM(OBJ) \ - if (!(OBJ)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \ - __func__); \ - func_print_backtrace(stderr); \ - assert(OBJ); \ - } -#endif - -#ifndef FATAL_ERROR -#define FATAL_ERROR(fmt, ...) \ - fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ - fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - func_print_backtrace(stderr); \ - abort(); -#endif - -void func_debug_warning( - func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...); -#ifndef WARNING -#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__) -#endif - -#ifndef WARNING_STDERR -#define WARNING_STDERR(fmt, ...) \ - fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ - fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__); -#endif - - - -// Is this debug verbosity and unit level enabled? -// Provide compiler hints that this is unlikely -// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not -// -// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evalute -// to the instance ID variable. The use of this define in header files is discouraged. - -#ifdef DEBUG_INSTANCE_EXPR -// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts -#ifdef DEBUG_INSTANCE_EXPR_2 -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \ - (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) \ - fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2)) - -#else // !DEBUG_INSTANCE_EXPR_2 - -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR)) - -#endif // DEBUG_INSTANCE_EXPR_2 - -#else // !DEBUG_INSTANCE_EXPR - -// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ - ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ") - -#endif - -// Macros for different verbosity levels -#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__) -#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__) -#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__) -#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__) -#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__) - -#endif diff --git a/reference_model/src/general_utils.h b/reference_model/src/general_utils.h new file mode 100644 index 0000000..12f831e --- /dev/null +++ b/reference_model/src/general_utils.h @@ -0,0 +1,68 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERAL_UTILS_H_ +#define GENERAL_UTILS_H_ + +#include "func_debug.h" + +#include "numpy_utils.h" + +namespace TosaReference +{ + +const uint32_t getElementCount(std::vector<int32_t>& shape) +{ + uint32_t elements = 1; + for (size_t i = 0; i < shape.size(); i++) + { + elements *= shape[i]; + } + + return elements; +} + +template <typename T> +std::vector<T> readFromNpyFile(const char* filename, std::vector<int32_t>& shape) +{ + uint32_t elements = getElementCount(shape); + std::vector<T> data(elements, 0); + + NumpyUtilities::NPError nperror = NumpyUtilities::readFromNpyFile(filename, elements, data.data()); + + switch (nperror) + { + case NumpyUtilities::NO_ERROR: + break; + case NumpyUtilities::FILE_NOT_FOUND: + FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename); + case NumpyUtilities::FILE_IO_ERROR: + FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename); + case NumpyUtilities::FILE_TYPE_MISMATCH: + FATAL_ERROR("readFromNpyFile: Tensor type and Numpy file type mismatch for filename %s", filename); + case NumpyUtilities::HEADER_PARSE_ERROR: + FATAL_ERROR("Numpy header parsing error for file: %s", filename); + case NumpyUtilities::BUFFER_SIZE_MISMATCH: + FATAL_ERROR("Buffer size does not match numpy file size for filename %s", filename); + default: + FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); + } + + return data; +} + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 643a351..776fbf3 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -13,31 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <stdio.h> +#include "model_runner.h" +#include "version.h" -#include "model_common.h" +#include "command_line_utils.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" #include "tosa_serialization_handler.h" -#include <Eigen/CXX11/Tensor> -#include <iostream> #include <fstream> +#include <iostream> +#include <stdio.h> +#include <Eigen/CXX11/Tensor> #include <nlohmann/json.hpp> -#define MODEL_VERSION_MAJOR 0 -#define MODEL_VERSION_MINOR 41 -#define MODEL_VERSION_PATCH 0 -#define MODEL_VERSION_DRAFT true - using namespace TosaReference; using namespace tosa; using json = nlohmann::json; -// Global instantiation of configuration and debug objects -func_config_t g_func_config; -func_debug_t g_func_debug; - int initTestDesc(json& test_desc); int readInputTensors(SubgraphTraverser& gt, json test_desc); int writeFinalTensors(SubgraphTraverser& gt, json test_desc); @@ -45,7 +38,10 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc); int main(int argc, char** argv) { - TosaVersion model_version(MODEL_VERSION_MAJOR, MODEL_VERSION_MINOR, MODEL_VERSION_PATCH, MODEL_VERSION_DRAFT); + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); // Initialize configuration and debug subsystems g_func_debug.init_debug(0); @@ -203,7 +199,6 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc) if (strlen(graph_fullname) <= 2) { - func_model_print_help(); FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc="); } diff --git a/reference_model/src/model_common.h b/reference_model/src/model_common.h deleted file mode 100644 index d6dab6d..0000000 --- a/reference_model/src/model_common.h +++ /dev/null @@ -1,28 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MODEL_COMMON_H -#define MODEL_COMMON_H - -#include <iostream> -#include <stdio.h> - -#include "func_config.h" -#include "func_debug.h" - -extern func_config_t g_func_config; -extern func_debug_t g_func_debug; - -#endif diff --git a/reference_model/src/model_runner.cc b/reference_model/src/model_runner.cc new file mode 100644 index 0000000..2395a85 --- /dev/null +++ b/reference_model/src/model_runner.cc @@ -0,0 +1,76 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +// Global instantiation of configuration and debug objects +func_config_t g_func_config; +func_debug_t g_func_debug; + +IModelRunner::IModelRunner() : model_runner_impl(new ModelRunnerImpl()) +{} + +IModelRunner::IModelRunner(const func_config_t& func_config, + const func_debug_t& func_debug) + : model_runner_impl(new ModelRunnerImpl(func_config, func_debug)) +{} + +IModelRunner::~IModelRunner() +{} + +void IModelRunner::setFuncConfig(func_config_t& func_config) +{ + model_runner_impl->setFuncConfig(func_config); +} + +void IModelRunner::setFuncDebug(func_debug_t& func_debug) +{ + model_runner_impl->setFuncDebug(func_debug); +} + +GraphStatus IModelRunner::initialize(tosa::TosaSerializationHandler& serialization_handler) +{ + return model_runner_impl->initialize(serialization_handler); +} + +GraphStatus IModelRunner::run() +{ + return model_runner_impl->run(); +} + +template <typename T> +int IModelRunner::setInput(std::string input_name, std::vector<T> vals) +{ + return model_runner_impl->setInput<T>(input_name, vals); +} + +template <typename T> +std::vector<T> IModelRunner::getOutput(std::string output_name) +{ + return model_runner_impl->getOutput<T>(output_name); +} + +// Template explicit specialization +template int IModelRunner::setInput<float>(std::string input_name, std::vector<float> vals); +template int IModelRunner::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals); +template int IModelRunner::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals); +template int IModelRunner::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals); + +template std::vector<float> IModelRunner::getOutput<float>(std::string output_name); +template std::vector<int32_t> IModelRunner::getOutput<int32_t>(std::string output_name); +template std::vector<int64_t> IModelRunner::getOutput<int64_t>(std::string output_name); +template std::vector<unsigned char> IModelRunner::getOutput<unsigned char>(std::string output_name);
\ No newline at end of file diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc new file mode 100644 index 0000000..e0fdc49 --- /dev/null +++ b/reference_model/src/model_runner_impl.cc @@ -0,0 +1,277 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +ModelRunnerImpl::ModelRunnerImpl() +{} + +ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config, + const func_debug_t& func_debug) +{ + g_func_config = func_config; + g_func_debug = func_debug; +} + +ModelRunnerImpl::~ModelRunnerImpl() +{ + g_func_debug.fini_debug(); + delete _main_gt; +}; + +void ModelRunnerImpl::setFuncConfig(func_config_t& func_config) +{ + g_func_config = func_config; +} +void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) +{ + g_func_debug = func_debug; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) +{ + validateTosaVersion(serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); + + if(_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + +GraphStatus ModelRunnerImpl::run() +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + // Validate the number of inputs matches the + if (static_cast<uint32_t>(_main_gt->getNumInputTensors()) != n_input_tensors) + { + FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). " + "setInput() must be called for each input.", + n_input_tensors, _main_gt->getNumInputTensors()); + } + + if (g_func_config.eval) + { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. + if (_main_gt->evaluateAll()) + { + ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID || + _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); + } + + // Only generate output tensor if graph is valid. + if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID) + { + // Make sure output tensor is evaluated and show its value + int num_output_tensors = _main_gt->getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = _main_gt->getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + _main_gt->dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + } + } + +done: + // Print status if not valid and do cleanup. + checkGraphStatus(*_main_gt); + g_func_debug.fini_debug(); + + return _main_gt->getGraphStatus(); +} + +template <typename T> +int ModelRunnerImpl::setInput(std::string input_name, std::vector<T> vals) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readfromVector(vals)) + { + WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + _main_gt->addToNextNodeList(gn); + } + } + + n_input_tensors++; + return 0; +} + +template <typename T> +std::vector<T> ModelRunnerImpl::getOutput(std::string output_name) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return std::vector<T>(); + } + + std::vector<T> outputs(tensor->getElementCount(), 0); + + if (tensor->writeToVector(outputs)) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return std::vector<T>(); + } + + return outputs; +} + +void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) +{ + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); + + TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion()); + switch (is_compat) + { + case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: + break; + case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: + WARNING("Reference model version %s is partially compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + break; + case TosaVersion::compat_t::NOT_COMPATIBLE: + FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + } +} + +void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) +{ + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + WARNING("Graph result: UNPREDICTABLE."); + break; + case GraphStatus::TOSA_ERROR: + WARNING("Graph result: ERROR."); + break; + default: + WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus()); + } +} + +// Template explicit specialization +template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float> vals); +template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals); +template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals); +template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals); + +template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name); +template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name); +template std::vector<int64_t> ModelRunnerImpl::getOutput<int64_t>(std::string output_name); +template std::vector<unsigned char> ModelRunnerImpl::getOutput<unsigned char>(std::string output_name);
\ No newline at end of file diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h new file mode 100644 index 0000000..7a91bfe --- /dev/null +++ b/reference_model/src/model_runner_impl.h @@ -0,0 +1,66 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_RUNNER_IMPL_H_ +#define MODEL_RUNNER_IMPL_H_ + +#include "model_runner.h" +#include "graph_status.h" +#include "version.h" + +#include "ops/op_factory.h" +#include "subgraph_traverser.h" +#include "tosa_serialization_handler.h" + +namespace TosaReference +{ + +/* + * This is a private implementation of the IModelRunner class. + * See documented IModelRunner for usage. + */ +class ModelRunnerImpl +{ +public: + ModelRunnerImpl(); + ModelRunnerImpl(const func_config_t& func_config, const func_debug_t& func_debug); + + ~ModelRunnerImpl(); + + void setFuncConfig(func_config_t& func_config); + void setFuncDebug(func_debug_t& func_debug); + + GraphStatus initialize(TosaSerializationHandler& serialization_handler); + GraphStatus run(); + + template <typename T> + int setInput(std::string input_name, std::vector<T> vals); + + template <typename T> + std::vector<T> getOutput(std::string output_name); + +private: + SubgraphTraverser* _main_gt = nullptr; + + // Used to determine if all input tensors have been set correctly. + uint32_t n_input_tensors = 0; + + void validateTosaVersion(TosaSerializationHandler& serialization_handler); + void checkGraphStatus(SubgraphTraverser& main_gt); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 8c66d73..7940ee4 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -16,6 +16,7 @@ #ifndef SUBGRAPH_TRAVERSER_H #define SUBGRAPH_TRAVERSER_H +#include "graph_status.h" #include "graph_node.h" #include "model_common.h" #include "ops/op_factory.h" @@ -26,13 +27,6 @@ namespace TosaReference { -enum class GraphStatus : int -{ - TOSA_VALID = 0, - TOSA_UNPREDICTABLE = 1, - TOSA_ERROR = 2, -}; - class SubgraphTraverser { public: diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 90aee05..7cbeb13 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -382,6 +382,209 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM +int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT48: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_BOOL: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueBool(elements, reinterpret_cast<const bool*>(vals.data())); + break; + default: + WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<float>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_INT48: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_BOOL: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueBool(elements, reinterpret_cast<bool*>(vals.data())); + break; + default: + WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + template <class T> int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals) { diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index ede42a9..6b7d5f1 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -228,6 +228,16 @@ public: virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; + virtual int readfromVector(const std::vector<float>& vals); + virtual int readfromVector(const std::vector<int32_t>& vals); + virtual int readfromVector(const std::vector<int64_t>& vals); + virtual int readfromVector(const std::vector<unsigned char>& vals); + + virtual int writeToVector(std::vector<float>& vals); + virtual int writeToVector(std::vector<int32_t>& vals); + virtual int writeToVector(std::vector<int64_t>& vals); + virtual int writeToVector(std::vector<unsigned char>& vals); + const char* bool_to_str(bool in) const { static const char* true_str = "true"; |