From ba5fad356a926d5e1c6e0fe6b546a310230cc5a8 Mon Sep 17 00:00:00 2001 From: Matthew Sloyan Date: Mon, 26 Sep 2022 13:31:43 +0100 Subject: Add IModelRunner interface to TOSA Reference Model * Added IModelRunner interface using pimpl idiom, which allows a user to initialize, configure and run the model. * Added unit tests for IModelRunner. * Added doctest as third-party submodule. * Added user options to specify paths for dependencies. * Moved general func_config functions to separate utility, which removes cxxopts dependency. Signed-off-by: Matthew Sloyan Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2 --- .gitmodules | 3 + CMakeLists.txt | 11 +- README.md | 35 ++- reference_model/CMakeLists.txt | 221 +++++++++++++--- reference_model/include/debug_modes.def | 20 ++ reference_model/include/debug_types.h | 57 +++++ reference_model/include/func_config.h | 41 +++ reference_model/include/func_debug.h | 245 ++++++++++++++++++ reference_model/include/graph_status.h | 25 ++ reference_model/include/model_common.h | 28 +++ reference_model/include/model_runner.h | 87 +++++++ reference_model/include/version.h | 23 ++ .../samples/model_runner_simple_sample.cpp | 97 ++++++++ reference_model/src/command_line_utils.h | 101 ++++++++ reference_model/src/debug_modes.def | 20 -- reference_model/src/debug_types.h | 57 ----- reference_model/src/func_config.cc | 108 -------- reference_model/src/func_config.h | 46 ---- reference_model/src/func_debug.h | 244 ------------------ reference_model/src/general_utils.h | 68 +++++ reference_model/src/main.cpp | 25 +- reference_model/src/model_common.h | 28 --- reference_model/src/model_runner.cc | 76 ++++++ reference_model/src/model_runner_impl.cc | 277 +++++++++++++++++++++ reference_model/src/model_runner_impl.h | 66 +++++ reference_model/src/subgraph_traverser.h | 8 +- reference_model/src/tensor.cc | 203 +++++++++++++++ reference_model/src/tensor.h | 10 + reference_model/test/model_runner_tests.cpp | 154 ++++++++++++ thirdparty/CMakeLists.txt | 11 +- thirdparty/doctest | 1 + 31 files changed, 1826 insertions(+), 570 deletions(-) create mode 100644 reference_model/include/debug_modes.def create mode 100644 reference_model/include/debug_types.h create mode 100644 reference_model/include/func_config.h create mode 100644 reference_model/include/func_debug.h create mode 100644 reference_model/include/graph_status.h create mode 100644 reference_model/include/model_common.h create mode 100644 reference_model/include/model_runner.h create mode 100644 reference_model/include/version.h create mode 100644 reference_model/samples/model_runner_simple_sample.cpp create mode 100644 reference_model/src/command_line_utils.h delete mode 100644 reference_model/src/debug_modes.def delete mode 100644 reference_model/src/debug_types.h delete mode 100644 reference_model/src/func_config.cc delete mode 100644 reference_model/src/func_config.h delete mode 100644 reference_model/src/func_debug.h create mode 100644 reference_model/src/general_utils.h delete mode 100644 reference_model/src/model_common.h create mode 100644 reference_model/src/model_runner.cc create mode 100644 reference_model/src/model_runner_impl.cc create mode 100644 reference_model/src/model_runner_impl.h create mode 100644 reference_model/test/model_runner_tests.cpp create mode 160000 thirdparty/doctest diff --git a/.gitmodules b/.gitmodules index e06268d..87ce1ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "thirdparty/json"] path = thirdparty/json url = https://github.com/nlohmann/json.git +[submodule "thirdparty/doctest"] + path = thirdparty/doctest + url = https://github.com/doctest/doctest.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 04141aa..d3281c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,16 @@ cmake_minimum_required (VERSION 3.4) set(CMAKE_INSTALL_PREFIX ".") project(tosa_tools LANGUAGES CXX) -option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of Tosa Reference Model" ON) +option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of TOSA Reference Model" ON) +option(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE "Enable building of TOSA Reference Model executable" ON) +option(BUILD_TOSA_REFERENCE_MODEL_TESTS "Enable building of TOSA Reference Model unit tests" ON) +option(BUILD_MODEL_RUNNER_SAMPLE "Enable building of ModelRunner sample executable" OFF) + +# Custom path options for third party dependencies +option(SERIALIZATION_DIR "Location where the TOSA Serialization Library 'include' folder is found" Off) +option(FLATBUFFERS_DIR "Location where the FlatBuffers 'include' and 'lib' folders is found" Off) +option(EIGEN_DIR "Location where the Eigen folder is found" Off) +option(DOCTEST_DIR "Location where the doctest folder is found (If building unit tests)" Off) add_subdirectory(thirdparty) diff --git a/README.md b/README.md index 88aeaf2..0e97d10 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,7 @@ The model includes the following git submodules: * TOSA Serialization Library * JSON for Modern C++ - 3.8.0 * Eigen 3.3.7 +* doctest 2.4.9 (When building unit tests) The model is written using C++17 and has been primarily tested on Ubuntu x86_64 18.04 LTS Linux @@ -84,12 +85,18 @@ make ``` The resulting executable will be named: -`reference_model/tosa_reference_model`. CMake only needs to be re-run -if the build environment changes (e.g., new dependencies or source -files). Code changes that do not affect these build rules can be +`reference_model/tosa_reference_model`. This executable can be disabled with +`-DBUILD_TOSA_REFERENCE_MODEL_EXECUTABLE=NO`. +A static library will also be generated by default as follows: +`reference_model/libtosa_reference_model_lib.a`. +To make this a shared library (.so) add the following option +`-DBUILD_SHARED_LIBS=YES`. + +CMake only needs to be re-run if the build environment changes (e.g., new dependencies or source +files). Code changes that do not affect these build rules can be rebuilt simply using `make`. -## Usage +## Executable Usage The inputs to the *TOSA Reference Model* consist of a FlatBuffers file containing the serialized subgraph, a JSON test descriptor that describes @@ -145,7 +152,7 @@ format into output tensors specified by "ofm_file". For example, you can generate new output .npy by: ``` bash ./build/reference_model/tosa_reference_model \ - --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json + --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json \ --ofm_file=out.npy ``` @@ -170,6 +177,22 @@ may cause small differences in output for floating-point tests and differences in quantized scaling between TensorFlow Lite and the TOSA Specification may cause differences in quantized integer tests. +## ModelRunner API + +As an alternative to the executable described above, +the model_runner.h is provided which can be used to invoke the +TOSA Reference Model easily within C++. +A sample of this class and how it can used can be found in +[model_runner_simple_sample.cpp](reference_model/samples/model_runner_simple_sample.cpp). +This sample can be compiled by adding `-DBUILD_MODEL_RUNNER_SAMPLE=YES` to the CMake command +and executed by running `./build/reference_model/model_runner_sample`. + +### ModelRunner API Unit Tests +Unit test are generated by default for the ModelRunner. +This executable can be disabled by adding`-DBUILD_TOSA_REFERENCE_MODEL_TESTS=NO` to the CMake command. +This executable can be run using +`./build/reference_model/unit_tests` and requires the submodule doctest. + ## Debugging The debugging facility can be enabled by setting a debug scope and @@ -219,7 +242,7 @@ programatically compared with output of SUTs to validate them. The test infrastructure needs installing before being used. It is recommended to create a [python virtual environment](https://docs.python.org/3/library/venv.html) -and then install the TOSA Unit Test infrastruture from the root of the +and then install the TOSA Unit Test infrastructure from the root of the reference model: ``` bash diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index 6fdaa1c..a790968 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.4) -# Copyright (c) 2020, ARM Limited. +# Copyright (c) 2020-2022, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ cmake_minimum_required (VERSION 3.4) # See the License for the specific language governing permissions and # limitations under the License. - project(tosa_reference_model LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) @@ -26,52 +25,202 @@ else() set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes") endif() -set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/") -set(SERIALIZATION_DIR "../thirdparty/serialization_lib/") - -set (CXX_SOURCE - src/main.cpp - src/tensor.cc - src/graph_node.cc - src/subgraph_traverser.cc - src/func_debug.cc - src/func_config.cc - src/ops/op_factory.cc - src/ops/tensor_ops.cc - src/ops/activation_funcs.cc - src/ops/ewise_binary.cc - src/ops/ewise_unary.cc - src/ops/ewise_ternary.cc - src/ops/comparison.cc - src/ops/reduction.cc - src/ops/data_layout.cc - src/ops/scatter_gather.cc - src/ops/image.cc - src/ops/type_conversion.cc - src/ops/data_nodes.cc - src/ops/custom.cc - src/ops/control_flow.cc +# If Serialization Library path is specified, look for library so it doesn't have to be built again. +# Otherwise, set the Serialization Library related paths to thirdparty directory. +if(SERIALIZATION_DIR) + find_library(SERIALIZATION_LIB + NAMES libtosa_serialization_lib.a tosa_serialization_lib + NO_DEFAULT_PATH + HINTS ${SERIALIZATION_DIR} + PATH_SUFFIXES lib) + + if(NOT SERIALIZATION_LIB) + message(FATAL_ERROR "TOSA Serialization Library location was specified but not found at: ${SERIALIZATION_LIB_DIR}") + endif() +else() + # Build from third party directory if not found. + set(SERIALIZATION_LIB tosa_serialization_lib) + set(SERIALIZATION_DIR "../thirdparty/serialization_lib/") +endif() + +# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory. +if(NOT FLATBUFFERS_DIR) + set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/") +endif() + +if(NOT EIGEN_DIR) + set(EIGEN_DIR "../thirdparty/eigen/") +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +# Common sources required for TOSA Reference Model library, executable and unit tests +set(CXX_SOURCE + src/model_runner.cc + src/model_runner_impl.cc + src/tensor.cc + src/graph_node.cc + src/subgraph_traverser.cc + src/func_debug.cc + src/ops/op_factory.cc + src/ops/tensor_ops.cc + src/ops/activation_funcs.cc + src/ops/ewise_binary.cc + src/ops/ewise_unary.cc + src/ops/ewise_ternary.cc + src/ops/comparison.cc + src/ops/reduction.cc + src/ops/data_layout.cc + src/ops/scatter_gather.cc + src/ops/image.cc + src/ops/type_conversion.cc + src/ops/data_nodes.cc + src/ops/custom.cc + src/ops/control_flow.cc ) -add_executable(tosa_reference_model ${CXX_SOURCE}) +# Build TOSA Reference Model library +add_library(tosa_reference_model_lib ${CXX_SOURCE}) -target_include_directories(tosa_reference_model +target_include_directories(tosa_reference_model_lib PUBLIC $ $ PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src ${FLATBUFFERS_DIR}/include - ../thirdparty/eigen/ - ../thirdparty/eigen/unsupported/ + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include ) -target_link_libraries(tosa_reference_model +target_link_libraries(tosa_reference_model_lib PRIVATE - tosa_serialization_lib - nlohmann_json::nlohmann_json - cxxopts + ${SERIALIZATION_LIB} +) + +set(PUBLIC_HEADERS) +list(APPEND PUBLIC_HEADERS + include/debug_modes.def + include/debug_types.h + include/func_config.h + include/func_debug.h + include/graph_status.h + include/model_common.h + include/model_runner.h + include/version.h +) + +set_target_properties(tosa_reference_model_lib PROPERTIES PUBLIC_HEADER "${PUBLIC_HEADERS}") + +# Build TOSA Refererence Model executable +if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) + set(CXX_SOURCE_EX src/main.cpp) + list(APPEND CXX_SOURCE_EX ${CXX_SOURCE}) + + add_executable(tosa_reference_model ${CXX_SOURCE_EX}) + + target_include_directories(tosa_reference_model + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ) + + target_link_libraries(tosa_reference_model + PRIVATE + ${SERIALIZATION_LIB} + nlohmann_json::nlohmann_json + cxxopts + ) + + install(TARGETS tosa_reference_model DESTINATION bin) +endif() + +if(BUILD_TOSA_REFERENCE_MODEL_TESTS) + # Set definition so unit tests can find examples directory. + add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\") + + # Set doctest location if not specified. + if(NOT DOCTEST_DIR) + set(DOCTEST_DIR "../thirdparty/doctest/doctest") + endif() + + # Sources only required for unit tests. + set(CXX_SOURCE_TESTS + test/model_runner_tests.cpp + ${DOCTEST_DIR}/doctest.h + ) + + list(APPEND CXX_SOURCE_TESTS ${CXX_SOURCE}) + + add_executable(unit_tests ${CXX_SOURCE_TESTS}) + + target_include_directories(unit_tests + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ${DOCTEST_DIR} + ) + + target_link_libraries(unit_tests + PRIVATE + ${SERIALIZATION_LIB} + ) +endif() + +if(BUILD_MODEL_RUNNER_SAMPLE) + # Set definition so sample executable can find examples directory. + add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\") + + # Sources only required for example executable. + set(CXX_SOURCE_SAMPLE + samples/model_runner_simple_sample.cpp + ) + + list(APPEND CXX_SOURCE_SAMPLE ${CXX_SOURCE}) + + add_executable(model_runner_sample ${CXX_SOURCE_SAMPLE}) + + target_include_directories(model_runner_sample + PUBLIC + $ + $ + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ) + + target_link_libraries(model_runner_sample + PRIVATE + ${SERIALIZATION_LIB} + ) +endif() + +# Follow GNU packaging norms for installation directory structure. +include(GNUInstallDirs) +install( + TARGETS tosa_reference_model_lib EXPORT TosaReferenceModelLibTargets + PUBLIC_HEADER + ARCHIVE ) -install (TARGETS tosa_reference_model DESTINATION bin) +install(EXPORT TosaReferenceModelLibTargets + FILE TosaReferenceModelLibTargets.cmake + NAMESPACE TosaReference:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/tosa_reference_model_lib" +) \ No newline at end of file diff --git a/reference_model/include/debug_modes.def b/reference_model/include/debug_modes.def new file mode 100644 index 0000000..51b151d --- /dev/null +++ b/reference_model/include/debug_modes.def @@ -0,0 +1,20 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Defines the debugging printing modes + +DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization +DEBUG_MODE(GT,1) // Graph traverser +DEBUG_MODE(OP,2) // Operation diff --git a/reference_model/include/debug_types.h b/reference_model/include/debug_types.h new file mode 100644 index 0000000..bd93f19 --- /dev/null +++ b/reference_model/include/debug_types.h @@ -0,0 +1,57 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +/* + * Filename: src/debug_types.h + * Description: + * Defines fundamental debugger datatypes for the functional model + */ + +#ifndef DEBUG_TYPES_H_ +#define DEBUG_TYPES_H_ + +#ifdef __cplusplus +extern "C" +{ +#endif + + // Debug verbosity mask + typedef enum func_debug_verbosity_e + { + DEBUG_VERB_NONE = 0x00, + DEBUG_VERB_INFO = 0x01, // Informational debugging messages + DEBUG_VERB_IFACE = 0x02, // Interface debugging support + DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout + DEBUG_VERB_MED = 0x08, + DEBUG_VERB_HIGH = 0x10 + } func_debug_verbosity_e; + + // Generated debug modes enumeration + typedef enum func_debug_mode_e + { + DEBUG_NONE = 0x0, +#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT), +#include "debug_modes.def" +#undef DEBUG_MODE + DEBUG_ALL = 0xffffffffffffffffUL + } func_debug_mode_e; + +#define DEBUG_INST_ALL 0xffffffffffffffffUL + +#ifdef __cplusplus +} +#endif + +#endif diff --git a/reference_model/include/func_config.h b/reference_model/include/func_config.h new file mode 100644 index 0000000..41df135 --- /dev/null +++ b/reference_model/include/func_config.h @@ -0,0 +1,41 @@ + +// Copyright (c) 2020-2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUNC_CONFIG_H_ +#define FUNC_CONFIG_H_ + +#include +#include + +struct func_config_t +{ + std::string operator_fbs = "tosa.fbs"; + std::string test_desc = "desc.json"; + std::string flatbuffer_dir = ""; + std::string output_dir = ""; + std::string tosa_file = ""; + std::string ifm_name = ""; + std::string ifm_file = ""; + std::string ofm_name = ""; + std::string ofm_file = ""; + uint32_t eval = 1; + uint32_t validate_only = 0; + uint32_t output_tensors = 1; + uint32_t tosa_profile = 1; + uint32_t dump_intermediates = 0; + std::string fp_format = "0.5"; +}; + +#endif diff --git a/reference_model/include/func_debug.h b/reference_model/include/func_debug.h new file mode 100644 index 0000000..d762026 --- /dev/null +++ b/reference_model/include/func_debug.h @@ -0,0 +1,245 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef FUNC_DEBUG_H +#define FUNC_DEBUG_H + +#include "debug_types.h" +#include +#include +#include +#include +#include +#include + +void func_print_backtrace(FILE* out, int sig = SIGABRT); + +void func_enable_signal_handlers(); + +// STRINGIFY2 is needed expand expression passed to STRINGIFY +#define STRINGIFY2(s) #s +#define STRINGIFY(s) STRINGIFY2(s) + +// If TRACED_LOG is defined, add file:line to log messages +#if defined(TRACED_LOG) +#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__) +#else +#define WHERE +#endif + +#if defined(COLORIZED_LOG) +#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m" +#define COL_FATAL(fmt) COL("1;41", fmt) +#define COL_WARN(fmt) COL("1;43", fmt) +#define COL_INFO(fmt) COL("2", fmt) +#define COL_IFACE(fmt) fmt +#define COL_LOW(fmt) COL("35", fmt) +#define COL_MED(fmt) COL("2;33", fmt) +#define COL_HIGH(fmt) COL("2;32", fmt) +#else +#define COL_FATAL(fmt) fmt +#define COL_WARN(fmt) fmt +#define COL_INFO(fmt) fmt +#define COL_IFACE(fmt) fmt +#define COL_LOW(fmt) fmt +#define COL_MED(fmt) fmt +#define COL_HIGH(fmt) fmt +#endif + +struct func_debug_t +{ + uint32_t func_debug_verbosity = 0; // What verbosity level is set? (bitmask) + uint64_t func_debug_mask = 0; // Which units have debugging enabled? (bitmask) + uint64_t func_debug_inst_mask = 0; // Which instances have debugging enabled (bitmask) + uint64_t inst_id = 0; // The instance id for multiple model instances + FILE* func_debug_file = stderr; // Output file + bool is_output_unbuffered = false; // should log files be opened with unbuffered I/O. + + int init_debug(uint64_t inst_id); + int fini_debug(); + int set_file(const std::string& filename); + void set_mask(const std::string& str); + void set_mask(const uint64_t mask); + void print_masks(FILE* out); + void set_verbosity(const std::string& str); + void set_verbosity(const uint32_t verb); + void set_inst_mask(const char* mask); + void set_inst_mask(const uint64_t mask); + void set_output_unbuffered(const bool is_unbuffered); + std::string get_debug_mask_help_string(); + std::string get_debug_verbosity_help_string(); +}; + +#ifndef ASSERT +#define ASSERT(COND) \ + if (!(COND)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ + func_print_backtrace(stderr); \ + assert(COND); \ + } +#endif + +#ifndef ASSERT_MSG +#define ASSERT_MSG(COND, fmt, ...) \ + if (!(COND)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ + fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + func_print_backtrace(stderr); \ + assert(COND); \ + } +#endif + +#ifndef REQUIRE +#define REQUIRE(COND, fmt, ...) \ + if (!(COND)) \ + { \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("REQUIRE() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ + __func__, #COND); \ + fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + this->parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); \ + } +#endif + +#ifndef ERROR_IF +#define ERROR_IF(COND, fmt, ...) \ + if ((COND)) \ + { \ + if (this->parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \ + { \ + this->parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR); \ + } \ + fprintf(g_func_debug.func_debug_file, COL_FATAL("ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ + __func__, #COND); \ + fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + this->dumpNode(g_func_debug.func_debug_file); \ + func_print_backtrace(g_func_debug.func_debug_file); \ + return 1; \ + } +#endif + +// Assertion specific to allocating memory +#ifndef ASSERT_MEM +#define ASSERT_MEM(OBJ) \ + if (!(OBJ)) \ + { \ + fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \ + __func__); \ + func_print_backtrace(stderr); \ + assert(OBJ); \ + } +#endif + +#ifndef FATAL_ERROR +#define FATAL_ERROR(fmt, ...) \ + fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ + fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ + func_print_backtrace(stderr); \ + abort(); +#endif + +void func_debug_warning( + func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...); +#ifndef WARNING +#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__) +#endif + +#ifndef WARNING_STDERR +#define WARNING_STDERR(fmt, ...) \ + fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ + fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__); +#endif + + + +// Is this debug verbosity and unit level enabled? +// Provide compiler hints that this is unlikely +// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not +// +// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evalute +// to the instance ID variable. The use of this define in header files is discouraged. + +#ifdef DEBUG_INSTANCE_EXPR +// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts +#ifdef DEBUG_INSTANCE_EXPR_2 +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) \ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \ + (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) \ + fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2)) + +#else // !DEBUG_INSTANCE_EXPR_2 + +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) \ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ + (int)(DEBUG_INSTANCE_EXPR)) + +#endif // DEBUG_INSTANCE_EXPR_2 + +#else // !DEBUG_INSTANCE_EXPR + +// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts +#define DEBUG_ENABLED(VERB, LEVEL) \ + (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ + (g_func_debug.func_debug_verbosity & (VERB)), \ + 0)) +// Debug printing macro +#define DEBUG(VERB, LEVEL, FMT, ...) \ + if (DEBUG_ENABLED(VERB, LEVEL)) \ + { \ + fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ + ##__VA_ARGS__); \ + } + +// Prints just the debugging prefix for properly marking free-form printouts +#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ") + +#endif + +// Macros for different verbosity levels +#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__) +#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__) +#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__) +#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__) +#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__) + +#endif diff --git a/reference_model/include/graph_status.h b/reference_model/include/graph_status.h new file mode 100644 index 0000000..f3be004 --- /dev/null +++ b/reference_model/include/graph_status.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRAPH_STATUS_H +#define GRAPH_STATUS_H + +enum class GraphStatus : int +{ + TOSA_VALID = 0, + TOSA_UNPREDICTABLE = 1, + TOSA_ERROR = 2, +}; + +#endif // GRAPH_STATUS_H diff --git a/reference_model/include/model_common.h b/reference_model/include/model_common.h new file mode 100644 index 0000000..d6dab6d --- /dev/null +++ b/reference_model/include/model_common.h @@ -0,0 +1,28 @@ + +// Copyright (c) 2020, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_COMMON_H +#define MODEL_COMMON_H + +#include +#include + +#include "func_config.h" +#include "func_debug.h" + +extern func_config_t g_func_config; +extern func_debug_t g_func_debug; + +#endif diff --git a/reference_model/include/model_runner.h b/reference_model/include/model_runner.h new file mode 100644 index 0000000..4629467 --- /dev/null +++ b/reference_model/include/model_runner.h @@ -0,0 +1,87 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_RUNNER_H_ +#define MODEL_RUNNER_H_ + +#include "model_common.h" +#include "graph_status.h" + +#include "tosa_serialization_handler.h" + +namespace TosaReference +{ + +class ModelRunnerImpl; + +/* + * This interface allows a user to initialize, run and get the output from a model. + * See model_runner_simple_sample.cpp for example on how this interface can be used. + */ +class IModelRunner +{ +public: + IModelRunner(); + IModelRunner(const func_config_t& func_config, const func_debug_t& func_debug); + + ~IModelRunner(); + + /* + * Functional and debug configurations can also be set. + * See func_config.h and func_debug.h for possible options. + */ + void setFuncConfig(func_config_t& func_config); + void setFuncDebug(func_debug_t& func_debug); + + /* + * Initialize the model. + * The TosaSerializationHandler is validated and then converted to a SubgraphTraverser internally. + * This SubgraphTraverser is initialized, allocated and validated. + * The status of the graph will be returned upon completion. + */ + GraphStatus initialize(tosa::TosaSerializationHandler& serialization_handler); + + /* + * Run the model using the internal SubgraphTraverser created during initialization. + * If validate_only is specified run() will simply return the graph status. + * Otherwise, the graph will be run and the output tensors will be generated if the graph is valid. + * The output tensors can then be retrieved with getOutput(). + * NOTE: initialize() must be called before run(). Also, setInput() must be called for all inputs in the model. + */ + GraphStatus run(); + + /* + * Set the input tensors for the model. + * The input_name much match the input tensor name in the model. + * NOTE: setInput() must be called for each input tensor before run() is called. + */ + template + int setInput(std::string input_name, std::vector vals); + + /* + * Retrieve the output tensors from the graph after running. + * The output_name much match the output tensor name in the model. + * NOTE: run() must be called before outputs are retrieved. + */ + template + std::vector getOutput(std::string output_name); + +private: + std::unique_ptr model_runner_impl; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/include/version.h b/reference_model/include/version.h new file mode 100644 index 0000000..cd1d598 --- /dev/null +++ b/reference_model/include/version.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef VERSION_H +#define VERSION_H + +#define TOSA_REFERENCE_MODEL_VERSION_MAJOR 0 +#define TOSA_REFERENCE_MODEL_VERSION_MINOR 41 +#define TOSA_REFERENCE_MODEL_VERSION_PATCH 0 +#define TOSA_REFERENCE_MODEL_VERSION_DRAFT true + +#endif //VERSION_H diff --git a/reference_model/samples/model_runner_simple_sample.cpp b/reference_model/samples/model_runner_simple_sample.cpp new file mode 100644 index 0000000..2eebca6 --- /dev/null +++ b/reference_model/samples/model_runner_simple_sample.cpp @@ -0,0 +1,97 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "general_utils.h" +#include "model_runner.h" + +int main() +{ + using namespace TosaReference; + + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa"); + std::string input0_file(test_root + "placeholder_0.npy"); + std::string input1_file(test_root + "placeholder_1.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::vector input_names = { "TosaInput_0", "TosaInput_1" }; + std::string output_name = "TosaOutput_0"; + + std::vector input0_shape = { 1, 4, 4, 1 }; + std::vector input1_shape = { 1, 4, 4, 4 }; + std::vector output_shape = { 1, 4, 4, 4 }; + + std::vector> inputs(input_names.size()); + std::vector expected_outputs = { }; + std::vector actual_outputs = { }; + + // Read in inputs and expected outputs for sample purposes. + inputs[0] = readFromNpyFile(input0_file.c_str(), input0_shape); + inputs[1] = readFromNpyFile(input1_file.c_str(), input1_shape); + expected_outputs = readFromNpyFile(expected_output_file.c_str(), output_shape); + + tosa::TosaSerializationHandler handler; + tosa::tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + if(error != tosa::TOSA_OK) + { + WARNING("An error occurred while loading the model from file."); + return 1; + } + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + if(status != GraphStatus::TOSA_VALID) + { + WARNING("An error occurred while initializing."); + return 1; + } + + // Set the model inputs using the input names and input data. + runner.setInput(input_names[0], inputs[0]); + runner.setInput(input_names[1], inputs[1]); + + // Run the ModelRunner using test inputs. + status = runner.run(); + if(status != GraphStatus::TOSA_VALID) + { + WARNING("An error occurred when running the model."); + return 1; + } + + // Get the outputs from the model. + actual_outputs = runner.getOutput(output_name); + + // Compare the actual output to the expected output. + bool if_accurate = true; + for (size_t i = 0; i < expected_outputs.size(); ++i) + { + if(actual_outputs[i] != expected_outputs[i]) + { + WARNING("Actual output (%f) doesn't match expected output (%f)."); + if_accurate = false; + } + } + + if(!if_accurate) + { + WARNING("There were mismatches in actual vs expected output, see above output for more details."); + return 1; + } + + printf("The model ran successfully without errors and matched the expected output.\n"); + return 0; +} \ No newline at end of file diff --git a/reference_model/src/command_line_utils.h b/reference_model/src/command_line_utils.h new file mode 100644 index 0000000..1bd1639 --- /dev/null +++ b/reference_model/src/command_line_utils.h @@ -0,0 +1,101 @@ + +// Copyright (c) 2020-2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef COMMAND_LINE_UTILS_H_ +#define COMMAND_LINE_UTILS_H_ + +#include "func_config.h" +#include "func_debug.h" + +#include +#include + +// Read the command line arguments +int func_model_parse_cmd_line( + func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version) +{ + try + { + cxxopts::Options options("tosa_reference_model", "The TOSA reference model"); + + // clang-format off + options.add_options() + ("operator_fbs", "Flat buffer schema file", cxxopts::value(func_config.operator_fbs), "") + ("test_desc", "Json test descriptor", cxxopts::value(func_config.test_desc), "") + ("flatbuffer_dir", "Flatbuffer directory to load. If not specified, it will be overwritten by dirname(test_desc)", + cxxopts::value(func_config.flatbuffer_dir)) + ("output_dir", "Output directory to write. If not specified, it will be overwritten by dirname(test_desc)", + cxxopts::value(func_config.output_dir)) + ("tosa_file", "Flatbuffer file. Support .json or .tosa. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.tosa_file)) + ("ifm_name", "Input tensor name. Comma(,) separated. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.ifm_name)) + ("ifm_file", "Input tensor numpy Comma(,) separated. file to initialize with placeholder. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.ifm_file)) + ("ofm_name", "Output tensor name. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.ofm_name)) + ("ofm_file", "Output tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", + cxxopts::value(func_config.ofm_file)) + ("eval", "Evaluate the network (0/1)", cxxopts::value(func_config.eval)) + ("fp_format", "Floating-point number dump format string (printf-style format, e.g. 0.5)", + cxxopts::value(func_config.fp_format)) + ("validate_only", "Validate the network, but do not read inputs or evaluate (0/1)", + cxxopts::value(func_config.validate_only)) + ("output_tensors", "Output tensors to a file (0/1)", cxxopts::value(func_config.output_tensors)) + ("tosa_profile", "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)", + cxxopts::value(func_config.tosa_profile)) + ("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value(func_config.dump_intermediates)) + ("v,version", "print model version") + ("i,input_tensor_file", "specify input tensor files", cxxopts::value>()) + ("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value()) + ("o,logfile", "output log file", cxxopts::value()) + ("d,debugmask", func_debug.get_debug_mask_help_string(), cxxopts::value>()) + ("h,help", "print help"); + // clang-format on + + auto result = options.parse(argc, argv); + if (result.count("help")) { + std::cout << options.help() << std::endl; + return 1; + } + if (result.count("debugmask")) { + auto& v = result["debugmask"].as>(); + for (const std::string& s : v) + func_debug.set_mask(s); + } + if (result.count("loglevel")) { + const std::string& levelstr = result["loglevel"].as(); + func_debug.set_verbosity(levelstr); + } + if (result.count("logfile")) { + func_debug.set_file(result["logfile"].as()); + } + if (result.count("input_tensor_file")) { + func_config.ifm_name = result["input_tensor_file"].as(); + } + if (result.count("version")) { + std::cout << "Model version " << version << std::endl; + } + } + catch(const std::exception& e) + { + std::cerr << e.what() << '\n'; + return 1; + } + + return 0; +} + +#endif diff --git a/reference_model/src/debug_modes.def b/reference_model/src/debug_modes.def deleted file mode 100644 index 51b151d..0000000 --- a/reference_model/src/debug_modes.def +++ /dev/null @@ -1,20 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Defines the debugging printing modes - -DEBUG_MODE(CONFIG,0) // Configuration parsing/initialization -DEBUG_MODE(GT,1) // Graph traverser -DEBUG_MODE(OP,2) // Operation diff --git a/reference_model/src/debug_types.h b/reference_model/src/debug_types.h deleted file mode 100644 index bd93f19..0000000 --- a/reference_model/src/debug_types.h +++ /dev/null @@ -1,57 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -/* - * Filename: src/debug_types.h - * Description: - * Defines fundamental debugger datatypes for the functional model - */ - -#ifndef DEBUG_TYPES_H_ -#define DEBUG_TYPES_H_ - -#ifdef __cplusplus -extern "C" -{ -#endif - - // Debug verbosity mask - typedef enum func_debug_verbosity_e - { - DEBUG_VERB_NONE = 0x00, - DEBUG_VERB_INFO = 0x01, // Informational debugging messages - DEBUG_VERB_IFACE = 0x02, // Interface debugging support - DEBUG_VERB_LOW = 0x04, // Low, medium, and high levels of debug printout - DEBUG_VERB_MED = 0x08, - DEBUG_VERB_HIGH = 0x10 - } func_debug_verbosity_e; - - // Generated debug modes enumeration - typedef enum func_debug_mode_e - { - DEBUG_NONE = 0x0, -#define DEBUG_MODE(NAME, BIT) DEBUG_##NAME = (1UL << BIT), -#include "debug_modes.def" -#undef DEBUG_MODE - DEBUG_ALL = 0xffffffffffffffffUL - } func_debug_mode_e; - -#define DEBUG_INST_ALL 0xffffffffffffffffUL - -#ifdef __cplusplus -} -#endif - -#endif diff --git a/reference_model/src/func_config.cc b/reference_model/src/func_config.cc deleted file mode 100644 index 6bd809e..0000000 --- a/reference_model/src/func_config.cc +++ /dev/null @@ -1,108 +0,0 @@ - -// Copyright (c) 2020-2022, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include "func_config.h" -#include "func_debug.h" - -// Read the command line arguments -int func_model_parse_cmd_line( - func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version) -{ - try - { - cxxopts::Options options("tosa_reference_model", "The TOSA reference model"); - - // clang-format off - options.add_options() - ("operator_fbs", "Flat buffer schema file", cxxopts::value(func_config.operator_fbs), "") - ("test_desc", "Json test descriptor", cxxopts::value(func_config.test_desc), "") - ("flatbuffer_dir", "Flatbuffer directory to load. If not specified, it will be overwritten by dirname(test_desc)", - cxxopts::value(func_config.flatbuffer_dir)) - ("output_dir", "Output directory to write. If not specified, it will be overwritten by dirname(test_desc)", - cxxopts::value(func_config.output_dir)) - ("tosa_file", "Flatbuffer file. Support .json or .tosa. Specifying this will overwrite the one initialized by --test_desc.", - cxxopts::value(func_config.tosa_file)) - ("ifm_name", "Input tensor name. Comma(,) separated. Specifying this will overwrite the one initialized by --test_desc.", - cxxopts::value(func_config.ifm_name)) - ("ifm_file", "Input tensor numpy Comma(,) separated. file to initialize with placeholder. Specifying this will overwrite the one initialized by --test_desc.", - cxxopts::value(func_config.ifm_file)) - ("ofm_name", "Output tensor name. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", - cxxopts::value(func_config.ofm_name)) - ("ofm_file", "Output tensor numpy file to be generated. Comma(,) seperated. Specifying this will overwrite the one initialized by --test_desc.", - cxxopts::value(func_config.ofm_file)) - ("eval", "Evaluate the network (0/1)", cxxopts::value(func_config.eval)) - ("fp_format", "Floating-point number dump format string (printf-style format, e.g. 0.5)", - cxxopts::value(func_config.fp_format)) - ("validate_only", "Validate the network, but do not read inputs or evaluate (0/1)", - cxxopts::value(func_config.validate_only)) - ("output_tensors", "Output tensors to a file (0/1)", cxxopts::value(func_config.output_tensors)) - ("tosa_profile", "Set TOSA profile (0 = Base Inference, 1 = Main Inference, 2 = Main Training)", - cxxopts::value(func_config.tosa_profile)) - ("dump_intermediates", "Dump intermediate tensors (0/1)", cxxopts::value(func_config.dump_intermediates)) - ("v,version", "print model version") - ("i,input_tensor_file", "specify input tensor files", cxxopts::value>()) - ("l,loglevel", func_debug.get_debug_verbosity_help_string(), cxxopts::value()) - ("o,logfile", "output log file", cxxopts::value()) - ("d,debugmask", func_debug.get_debug_mask_help_string(), cxxopts::value>()) - ("h,help", "print help"); - // clang-format on - - auto result = options.parse(argc, argv); - if (result.count("help")) { - std::cout << options.help() << std::endl; - return 1; - } - if (result.count("debugmask")) { - auto& v = result["debugmask"].as>(); - for (const std::string& s : v) - func_debug.set_mask(s); - } - if (result.count("loglevel")) { - const std::string& levelstr = result["loglevel"].as(); - func_debug.set_verbosity(levelstr); - } - if (result.count("logfile")) { - func_debug.set_file(result["logfile"].as()); - } - if (result.count("input_tensor_file")) { - func_config.ifm_name = result["input_tensor_file"].as(); - } - if (result.count("version")) { - std::cout << "Model version " << version << std::endl; - } - } - catch(const std::exception& e) - { - std::cerr << e.what() << '\n'; - return 1; - } - - return 0; -} - -void func_model_print_help() -{ - -} diff --git a/reference_model/src/func_config.h b/reference_model/src/func_config.h deleted file mode 100644 index 49f03e9..0000000 --- a/reference_model/src/func_config.h +++ /dev/null @@ -1,46 +0,0 @@ - -// Copyright (c) 2020-2022, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef FUNC_CONFIG_H_ -#define FUNC_CONFIG_H_ - -struct func_config_t -{ - std::string operator_fbs = "tosa.fbs"; - std::string test_desc = "desc.json"; - std::string flatbuffer_dir = ""; - std::string output_dir = ""; - std::string tosa_file = ""; - std::string ifm_name = ""; - std::string ifm_file = ""; - std::string ofm_name = ""; - std::string ofm_file = ""; - uint32_t eval = 1; - uint32_t validate_only = 0; - uint32_t output_tensors = 1; - uint32_t tosa_profile = 1; - uint32_t dump_intermediates = 0; - std::string fp_format = "0.5"; -}; - -// Forward declaration -struct func_debug_t; - -int func_model_parse_cmd_line( - func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version); -int func_model_parse_flat_config_file(func_config_t*, const char* filename); -void func_model_print_help(); - -#endif diff --git a/reference_model/src/func_debug.h b/reference_model/src/func_debug.h deleted file mode 100644 index ee89935..0000000 --- a/reference_model/src/func_debug.h +++ /dev/null @@ -1,244 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef FUNC_DEBUG_H -#define FUNC_DEBUG_H - -#include "debug_types.h" -#include -#include -#include -#include -#include - -void func_print_backtrace(FILE* out, int sig = SIGABRT); - -void func_enable_signal_handlers(); - -// STRINGIFY2 is needed expand expression passed to STRINGIFY -#define STRINGIFY2(s) #s -#define STRINGIFY(s) STRINGIFY2(s) - -// If TRACED_LOG is defined, add file:line to log messages -#if defined(TRACED_LOG) -#define WHERE "@" __FILE__ ":" STRINGIFY(__LINE__) -#else -#define WHERE -#endif - -#if defined(COLORIZED_LOG) -#define COL(col, fmt) "\x1b[3" col "m" fmt "\x1b[0m" -#define COL_FATAL(fmt) COL("1;41", fmt) -#define COL_WARN(fmt) COL("1;43", fmt) -#define COL_INFO(fmt) COL("2", fmt) -#define COL_IFACE(fmt) fmt -#define COL_LOW(fmt) COL("35", fmt) -#define COL_MED(fmt) COL("2;33", fmt) -#define COL_HIGH(fmt) COL("2;32", fmt) -#else -#define COL_FATAL(fmt) fmt -#define COL_WARN(fmt) fmt -#define COL_INFO(fmt) fmt -#define COL_IFACE(fmt) fmt -#define COL_LOW(fmt) fmt -#define COL_MED(fmt) fmt -#define COL_HIGH(fmt) fmt -#endif - -struct func_debug_t -{ - uint32_t func_debug_verbosity = 0; // What verbosity level is set? (bitmask) - uint64_t func_debug_mask = 0; // Which units have debugging enabled? (bitmask) - uint64_t func_debug_inst_mask = 0; // Which instances have debugging enabled (bitmask) - uint64_t inst_id = 0; // The instance id for multiple model instances - FILE* func_debug_file = stderr; // Output file - bool is_output_unbuffered = false; // should log files be opened with unbuffered I/O. - - int init_debug(uint64_t inst_id); - int fini_debug(); - int set_file(const std::string& filename); - void set_mask(const std::string& str); - void set_mask(const uint64_t mask); - void print_masks(FILE* out); - void set_verbosity(const std::string& str); - void set_verbosity(const uint32_t verb); - void set_inst_mask(const char* mask); - void set_inst_mask(const uint64_t mask); - void set_output_unbuffered(const bool is_unbuffered); - std::string get_debug_mask_help_string(); - std::string get_debug_verbosity_help_string(); -}; - -#ifndef ASSERT -#define ASSERT(COND) \ - if (!(COND)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ - func_print_backtrace(stderr); \ - assert(COND); \ - } -#endif - -#ifndef ASSERT_MSG -#define ASSERT_MSG(COND, fmt, ...) \ - if (!(COND)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, __func__, #COND); \ - fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - func_print_backtrace(stderr); \ - assert(COND); \ - } -#endif - -#ifndef REQUIRE -#define REQUIRE(COND, fmt, ...) \ - if (!(COND)) \ - { \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("REQUIRE() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ - __func__, #COND); \ - fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - this->parent_sgt->setGraphStatus(GraphStatus::TOSA_UNPREDICTABLE); \ - } -#endif - -#ifndef ERROR_IF -#define ERROR_IF(COND, fmt, ...) \ - if ((COND)) \ - { \ - if (this->parent_sgt->getGraphStatus() != GraphStatus::TOSA_UNPREDICTABLE) \ - { \ - this->parent_sgt->setGraphStatus(GraphStatus::TOSA_ERROR); \ - } \ - fprintf(g_func_debug.func_debug_file, COL_FATAL("ERROR_IF() fails AT %s:%d %s(): (%s)\n"), __FILE__, __LINE__, \ - __func__, #COND); \ - fprintf(g_func_debug.func_debug_file, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - this->dumpNode(g_func_debug.func_debug_file); \ - func_print_backtrace(g_func_debug.func_debug_file); \ - return 1; \ - } -#endif - -// Assertion specific to allocating memory -#ifndef ASSERT_MEM -#define ASSERT_MEM(OBJ) \ - if (!(OBJ)) \ - { \ - fprintf(stderr, COL_FATAL("ASSERTION AT %s:%d %s(): (" #OBJ "): out of memory\n"), __FILE__, __LINE__, \ - __func__); \ - func_print_backtrace(stderr); \ - assert(OBJ); \ - } -#endif - -#ifndef FATAL_ERROR -#define FATAL_ERROR(fmt, ...) \ - fprintf(stderr, COL_FATAL("FATAL ERROR AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ - fprintf(stderr, COL_FATAL(fmt) "\n", ##__VA_ARGS__); \ - func_print_backtrace(stderr); \ - abort(); -#endif - -void func_debug_warning( - func_debug_t* func_debug, const char* file, const char* func, const int line, const char* fmt, ...); -#ifndef WARNING -#define WARNING(...) func_debug_warning(&g_func_debug, __FILE__, __func__, __LINE__, __VA_ARGS__) -#endif - -#ifndef WARNING_STDERR -#define WARNING_STDERR(fmt, ...) \ - fprintf(stderr, COL_WARN("WARNING AT %s:%d %s():\n"), __FILE__, __LINE__, __func__); \ - fprintf(stderr, COL_WARN(fmt) "\n", ##__VA_ARGS__); -#endif - - - -// Is this debug verbosity and unit level enabled? -// Provide compiler hints that this is unlikely -// Two versions, depending on whether DEBUG_INSTANCE_EXPR is defined in a file or not -// -// For .cpp files whose units have discrete instance IDs, define DEBUG_INSTANCE_EXPR to evalute -// to the instance ID variable. The use of this define in header files is discouraged. - -#ifdef DEBUG_INSTANCE_EXPR -// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts -#ifdef DEBUG_INSTANCE_EXPR_2 -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d_%02d" WHERE "]: " FMT "\n", \ - (int)g_func_debug.inst_id, (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2), ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) \ - fprintf(g_func_debug.func_debug_file, "[%d" #LEVEL "_%02d_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR), (int)(DEBUG_INSTANCE_EXPR_2)) - -#else // !DEBUG_INSTANCE_EXPR_2 - -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_inst_mask & (uint64_t(1) << (DEBUG_INSTANCE_EXPR))) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR), ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL "_%02d" WHERE "]: ", (int)g_func_debug.inst_id, \ - (int)(DEBUG_INSTANCE_EXPR)) - -#endif // DEBUG_INSTANCE_EXPR_2 - -#else // !DEBUG_INSTANCE_EXPR - -// Expression for whether the debugging verbosity + debugging unit is enabled for free-form printouts -#define DEBUG_ENABLED(VERB, LEVEL) \ - (__builtin_expect((g_func_debug.func_debug_mask == DEBUG_ALL || g_func_debug.func_debug_mask & (DEBUG_##LEVEL)) && \ - (g_func_debug.func_debug_verbosity & (VERB)), \ - 0)) -// Debug printing macro -#define DEBUG(VERB, LEVEL, FMT, ...) \ - if (DEBUG_ENABLED(VERB, LEVEL)) \ - { \ - fprintf(g_func_debug.func_debug_file, "[%d:" #LEVEL WHERE "]: " FMT "\n", (int)g_func_debug.inst_id, \ - ##__VA_ARGS__); \ - } - -// Prints just the debugging prefix for properly marking free-form printouts -#define DEBUG_PREFIX(LEVEL) fprintf(g_func_debug.func_debug_file, "[" #LEVEL WHERE "]: ") - -#endif - -// Macros for different verbosity levels -#define DEBUG_INFO(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_INFO, LEVEL, COL_INFO(FMT), ##__VA_ARGS__) -#define DEBUG_IFACE(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_IFACE, LEVEL, COL_IFACE(FMT), ##__VA_ARGS__) -#define DEBUG_LOW(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_LOW, LEVEL, COL_LOW(FMT), ##__VA_ARGS__) -#define DEBUG_MED(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_MED, LEVEL, COL_MED(FMT), ##__VA_ARGS__) -#define DEBUG_HIGH(LEVEL, FMT, ...) DEBUG(DEBUG_VERB_HIGH, LEVEL, COL_HIGH(FMT), ##__VA_ARGS__) - -#endif diff --git a/reference_model/src/general_utils.h b/reference_model/src/general_utils.h new file mode 100644 index 0000000..12f831e --- /dev/null +++ b/reference_model/src/general_utils.h @@ -0,0 +1,68 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERAL_UTILS_H_ +#define GENERAL_UTILS_H_ + +#include "func_debug.h" + +#include "numpy_utils.h" + +namespace TosaReference +{ + +const uint32_t getElementCount(std::vector& shape) +{ + uint32_t elements = 1; + for (size_t i = 0; i < shape.size(); i++) + { + elements *= shape[i]; + } + + return elements; +} + +template +std::vector readFromNpyFile(const char* filename, std::vector& shape) +{ + uint32_t elements = getElementCount(shape); + std::vector data(elements, 0); + + NumpyUtilities::NPError nperror = NumpyUtilities::readFromNpyFile(filename, elements, data.data()); + + switch (nperror) + { + case NumpyUtilities::NO_ERROR: + break; + case NumpyUtilities::FILE_NOT_FOUND: + FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename); + case NumpyUtilities::FILE_IO_ERROR: + FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename); + case NumpyUtilities::FILE_TYPE_MISMATCH: + FATAL_ERROR("readFromNpyFile: Tensor type and Numpy file type mismatch for filename %s", filename); + case NumpyUtilities::HEADER_PARSE_ERROR: + FATAL_ERROR("Numpy header parsing error for file: %s", filename); + case NumpyUtilities::BUFFER_SIZE_MISMATCH: + FATAL_ERROR("Buffer size does not match numpy file size for filename %s", filename); + default: + FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); + } + + return data; +} + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 643a351..776fbf3 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -13,31 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include +#include "model_runner.h" +#include "version.h" -#include "model_common.h" +#include "command_line_utils.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" #include "tosa_serialization_handler.h" -#include -#include #include +#include +#include +#include #include -#define MODEL_VERSION_MAJOR 0 -#define MODEL_VERSION_MINOR 41 -#define MODEL_VERSION_PATCH 0 -#define MODEL_VERSION_DRAFT true - using namespace TosaReference; using namespace tosa; using json = nlohmann::json; -// Global instantiation of configuration and debug objects -func_config_t g_func_config; -func_debug_t g_func_debug; - int initTestDesc(json& test_desc); int readInputTensors(SubgraphTraverser& gt, json test_desc); int writeFinalTensors(SubgraphTraverser& gt, json test_desc); @@ -45,7 +38,10 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc); int main(int argc, char** argv) { - TosaVersion model_version(MODEL_VERSION_MAJOR, MODEL_VERSION_MINOR, MODEL_VERSION_PATCH, MODEL_VERSION_DRAFT); + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); // Initialize configuration and debug subsystems g_func_debug.init_debug(0); @@ -203,7 +199,6 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc) if (strlen(graph_fullname) <= 2) { - func_model_print_help(); FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc="); } diff --git a/reference_model/src/model_common.h b/reference_model/src/model_common.h deleted file mode 100644 index d6dab6d..0000000 --- a/reference_model/src/model_common.h +++ /dev/null @@ -1,28 +0,0 @@ - -// Copyright (c) 2020, ARM Limited. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifndef MODEL_COMMON_H -#define MODEL_COMMON_H - -#include -#include - -#include "func_config.h" -#include "func_debug.h" - -extern func_config_t g_func_config; -extern func_debug_t g_func_debug; - -#endif diff --git a/reference_model/src/model_runner.cc b/reference_model/src/model_runner.cc new file mode 100644 index 0000000..2395a85 --- /dev/null +++ b/reference_model/src/model_runner.cc @@ -0,0 +1,76 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +// Global instantiation of configuration and debug objects +func_config_t g_func_config; +func_debug_t g_func_debug; + +IModelRunner::IModelRunner() : model_runner_impl(new ModelRunnerImpl()) +{} + +IModelRunner::IModelRunner(const func_config_t& func_config, + const func_debug_t& func_debug) + : model_runner_impl(new ModelRunnerImpl(func_config, func_debug)) +{} + +IModelRunner::~IModelRunner() +{} + +void IModelRunner::setFuncConfig(func_config_t& func_config) +{ + model_runner_impl->setFuncConfig(func_config); +} + +void IModelRunner::setFuncDebug(func_debug_t& func_debug) +{ + model_runner_impl->setFuncDebug(func_debug); +} + +GraphStatus IModelRunner::initialize(tosa::TosaSerializationHandler& serialization_handler) +{ + return model_runner_impl->initialize(serialization_handler); +} + +GraphStatus IModelRunner::run() +{ + return model_runner_impl->run(); +} + +template +int IModelRunner::setInput(std::string input_name, std::vector vals) +{ + return model_runner_impl->setInput(input_name, vals); +} + +template +std::vector IModelRunner::getOutput(std::string output_name) +{ + return model_runner_impl->getOutput(output_name); +} + +// Template explicit specialization +template int IModelRunner::setInput(std::string input_name, std::vector vals); +template int IModelRunner::setInput(std::string input_name, std::vector vals); +template int IModelRunner::setInput(std::string input_name, std::vector vals); +template int IModelRunner::setInput(std::string input_name, std::vector vals); + +template std::vector IModelRunner::getOutput(std::string output_name); +template std::vector IModelRunner::getOutput(std::string output_name); +template std::vector IModelRunner::getOutput(std::string output_name); +template std::vector IModelRunner::getOutput(std::string output_name); \ No newline at end of file diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc new file mode 100644 index 0000000..e0fdc49 --- /dev/null +++ b/reference_model/src/model_runner_impl.cc @@ -0,0 +1,277 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +ModelRunnerImpl::ModelRunnerImpl() +{} + +ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config, + const func_debug_t& func_debug) +{ + g_func_config = func_config; + g_func_debug = func_debug; +} + +ModelRunnerImpl::~ModelRunnerImpl() +{ + g_func_debug.fini_debug(); + delete _main_gt; +}; + +void ModelRunnerImpl::setFuncConfig(func_config_t& func_config) +{ + g_func_config = func_config; +} +void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) +{ + g_func_debug = func_debug; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) +{ + validateTosaVersion(serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); + + if(_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + +GraphStatus ModelRunnerImpl::run() +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + // Validate the number of inputs matches the + if (static_cast(_main_gt->getNumInputTensors()) != n_input_tensors) + { + FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). " + "setInput() must be called for each input.", + n_input_tensors, _main_gt->getNumInputTensors()); + } + + if (g_func_config.eval) + { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. + if (_main_gt->evaluateAll()) + { + ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID || + _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); + } + + // Only generate output tensor if graph is valid. + if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID) + { + // Make sure output tensor is evaluated and show its value + int num_output_tensors = _main_gt->getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = _main_gt->getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + _main_gt->dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + } + } + +done: + // Print status if not valid and do cleanup. + checkGraphStatus(*_main_gt); + g_func_debug.fini_debug(); + + return _main_gt->getGraphStatus(); +} + +template +int ModelRunnerImpl::setInput(std::string input_name, std::vector vals) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readfromVector(vals)) + { + WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + _main_gt->addToNextNodeList(gn); + } + } + + n_input_tensors++; + return 0; +} + +template +std::vector ModelRunnerImpl::getOutput(std::string output_name) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return std::vector(); + } + + std::vector outputs(tensor->getElementCount(), 0); + + if (tensor->writeToVector(outputs)) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return std::vector(); + } + + return outputs; +} + +void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) +{ + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); + + TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion()); + switch (is_compat) + { + case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: + break; + case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: + WARNING("Reference model version %s is partially compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + break; + case TosaVersion::compat_t::NOT_COMPATIBLE: + FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + } +} + +void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) +{ + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + WARNING("Graph result: UNPREDICTABLE."); + break; + case GraphStatus::TOSA_ERROR: + WARNING("Graph result: ERROR."); + break; + default: + WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus()); + } +} + +// Template explicit specialization +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); +template int ModelRunnerImpl::setInput(std::string input_name, std::vector vals); + +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); +template std::vector ModelRunnerImpl::getOutput(std::string output_name); \ No newline at end of file diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h new file mode 100644 index 0000000..7a91bfe --- /dev/null +++ b/reference_model/src/model_runner_impl.h @@ -0,0 +1,66 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_RUNNER_IMPL_H_ +#define MODEL_RUNNER_IMPL_H_ + +#include "model_runner.h" +#include "graph_status.h" +#include "version.h" + +#include "ops/op_factory.h" +#include "subgraph_traverser.h" +#include "tosa_serialization_handler.h" + +namespace TosaReference +{ + +/* + * This is a private implementation of the IModelRunner class. + * See documented IModelRunner for usage. + */ +class ModelRunnerImpl +{ +public: + ModelRunnerImpl(); + ModelRunnerImpl(const func_config_t& func_config, const func_debug_t& func_debug); + + ~ModelRunnerImpl(); + + void setFuncConfig(func_config_t& func_config); + void setFuncDebug(func_debug_t& func_debug); + + GraphStatus initialize(TosaSerializationHandler& serialization_handler); + GraphStatus run(); + + template + int setInput(std::string input_name, std::vector vals); + + template + std::vector getOutput(std::string output_name); + +private: + SubgraphTraverser* _main_gt = nullptr; + + // Used to determine if all input tensors have been set correctly. + uint32_t n_input_tensors = 0; + + void validateTosaVersion(TosaSerializationHandler& serialization_handler); + void checkGraphStatus(SubgraphTraverser& main_gt); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 8c66d73..7940ee4 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -16,6 +16,7 @@ #ifndef SUBGRAPH_TRAVERSER_H #define SUBGRAPH_TRAVERSER_H +#include "graph_status.h" #include "graph_node.h" #include "model_common.h" #include "ops/op_factory.h" @@ -26,13 +27,6 @@ namespace TosaReference { -enum class GraphStatus : int -{ - TOSA_VALID = 0, - TOSA_UNPREDICTABLE = 1, - TOSA_ERROR = 2, -}; - class SubgraphTraverser { public: diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 90aee05..7cbeb13 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -382,6 +382,209 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM +int TosaReference::Tensor::readfromVector(const std::vector& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT48: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_BOOL: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueBool(elements, reinterpret_cast(vals.data())); + break; + default: + WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_INT48: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_BOOL: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueBool(elements, reinterpret_cast(vals.data())); + break; + default: + WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + template int TosaReference::TensorTemplate::setTensorValueFloat(const size_t buflen, const float* vals) { diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index ede42a9..6b7d5f1 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -228,6 +228,16 @@ public: virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; + virtual int readfromVector(const std::vector& vals); + virtual int readfromVector(const std::vector& vals); + virtual int readfromVector(const std::vector& vals); + virtual int readfromVector(const std::vector& vals); + + virtual int writeToVector(std::vector& vals); + virtual int writeToVector(std::vector& vals); + virtual int writeToVector(std::vector& vals); + virtual int writeToVector(std::vector& vals); + const char* bool_to_str(bool in) const { static const char* true_str = "true"; diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp new file mode 100644 index 0000000..cc295d9 --- /dev/null +++ b/reference_model/test/model_runner_tests.cpp @@ -0,0 +1,154 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#endif + +#include "model_runner.h" +#include "general_utils.h" + +#include "doctest.h" + +using namespace TosaReference; +using namespace tosa; + +template +void compareOutput(std::vector& tensor1, std::vector& tensor2, size_t size) +{ + for (size_t i = 0; i < size; ++i) + { + CHECK((tensor1[i] == tensor2[i])); + } +} + +TEST_SUITE("model_runner") +{ + +TEST_CASE("simple_add_f32_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa"); + std::string input0_file(test_root + "placeholder_0.npy"); + std::string input1_file(test_root + "placeholder_1.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::vector input_names = { "TosaInput_0", "TosaInput_1" }; + std::string output_name = "TosaOutput_0"; + + std::vector input0_shape = { 1, 4, 4, 1 }; + std::vector input1_shape = { 1, 4, 4, 4 }; + std::vector output_shape = { 1, 4, 4, 4 }; + + std::vector> inputs(input_names.size()); + std::vector actual_outputs = { }; + std::vector expected_outputs = { }; + + // Read in inputs and expected outputs. + inputs[0] = readFromNpyFile(input0_file.c_str(), input0_shape); + inputs[1] = readFromNpyFile(input1_file.c_str(), input1_shape); + expected_outputs = readFromNpyFile(expected_output_file.c_str(), output_shape); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + runner.setInput(input_names[0], inputs[0]); + runner.setInput(input_names[1], inputs[1]); + + // Run the ModelRunner using test inputs. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); + + actual_outputs = runner.getOutput(output_name); + CHECK(!actual_outputs.empty()); + + compareOutput(expected_outputs, actual_outputs, expected_outputs.size()); +} + +TEST_CASE("conv2d_f32_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa"); + std::string input_file(test_root + "placeholder_0.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::string input_name = "TosaInput_0"; + std::string output_name = "TosaOutput_0"; + + std::vector input_shape = { 1, 32, 32, 8 }; + std::vector output_shape = { 1, 32, 32, 16 }; + + // Read in inputs and expected outputs. + std::vector inputs = readFromNpyFile(input_file.c_str(), input_shape); + std::vector expected_outputs = readFromNpyFile(expected_output_file.c_str(), output_shape); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + runner.setInput(input_name, inputs); + + // Run the ModelRunner using test inputs. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); + + std::vector actual_outputs = runner.getOutput(output_name); + CHECK(!actual_outputs.empty()); + + compareOutput(expected_outputs, actual_outputs, expected_outputs.size()); +} + +TEST_CASE("conv2d_f32_validate_only_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa"); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + func_debug_t funcDebug; + + func_config_t funcConfig; + funcConfig.validate_only = 1; + + // Initialize the ModelRunner with configurations. + IModelRunner runner = IModelRunner(funcConfig, funcDebug); + runner.setFuncConfig(funcConfig); + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + // Run the ModelRunner using no inputs, as validate_only is specified run() should still work. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); +} + +} diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index abcc52c..aa6c43d 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -12,6 +12,13 @@ set(CMAKE_INSTALL_PREFIX "./thirdparty" CACHE PATH "..." FORCE) project(thirdparty LANGUAGES CXX) -add_subdirectory(cxxopts) add_subdirectory(serialization_lib EXCLUDE_FROM_ALL) -add_subdirectory(json EXCLUDE_FROM_ALL) + +if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) + add_subdirectory(cxxopts) + add_subdirectory(json EXCLUDE_FROM_ALL) +endif() + +if(BUILD_TOSA_REFERENCE_MODEL_TESTS) + add_subdirectory(doctest EXCLUDE_FROM_ALL) +endif() diff --git a/thirdparty/doctest b/thirdparty/doctest new file mode 160000 index 0000000..86892fc --- /dev/null +++ b/thirdparty/doctest @@ -0,0 +1 @@ +Subproject commit 86892fc480f80fb57d9a3926cb506c0e974489d8 -- cgit v1.2.1