diff options
25 files changed, 1343 insertions, 88 deletions
diff --git a/.gitmodules b/.gitmodules index e06268d..87ce1ef 100644 --- a/.gitmodules +++ b/.gitmodules @@ -7,3 +7,6 @@ [submodule "thirdparty/json"] path = thirdparty/json url = https://github.com/nlohmann/json.git +[submodule "thirdparty/doctest"] + path = thirdparty/doctest + url = https://github.com/doctest/doctest.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 04141aa..d3281c6 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -3,7 +3,16 @@ cmake_minimum_required (VERSION 3.4) set(CMAKE_INSTALL_PREFIX ".") project(tosa_tools LANGUAGES CXX) -option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of Tosa Reference Model" ON) +option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of TOSA Reference Model" ON) +option(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE "Enable building of TOSA Reference Model executable" ON) +option(BUILD_TOSA_REFERENCE_MODEL_TESTS "Enable building of TOSA Reference Model unit tests" ON) +option(BUILD_MODEL_RUNNER_SAMPLE "Enable building of ModelRunner sample executable" OFF) + +# Custom path options for third party dependencies +option(SERIALIZATION_DIR "Location where the TOSA Serialization Library 'include' folder is found" Off) +option(FLATBUFFERS_DIR "Location where the FlatBuffers 'include' and 'lib' folders is found" Off) +option(EIGEN_DIR "Location where the Eigen folder is found" Off) +option(DOCTEST_DIR "Location where the doctest folder is found (If building unit tests)" Off) add_subdirectory(thirdparty) @@ -33,6 +33,7 @@ The model includes the following git submodules: * TOSA Serialization Library * JSON for Modern C++ - 3.8.0 * Eigen 3.3.7 +* doctest 2.4.9 (When building unit tests) The model is written using C++17 and has been primarily tested on Ubuntu x86_64 18.04 LTS Linux @@ -84,12 +85,18 @@ make ``` The resulting executable will be named: -`reference_model/tosa_reference_model`. CMake only needs to be re-run -if the build environment changes (e.g., new dependencies or source -files). Code changes that do not affect these build rules can be +`reference_model/tosa_reference_model`. This executable can be disabled with +`-DBUILD_TOSA_REFERENCE_MODEL_EXECUTABLE=NO`. +A static library will also be generated by default as follows: +`reference_model/libtosa_reference_model_lib.a`. +To make this a shared library (.so) add the following option +`-DBUILD_SHARED_LIBS=YES`. + +CMake only needs to be re-run if the build environment changes (e.g., new dependencies or source +files). Code changes that do not affect these build rules can be rebuilt simply using `make`. -## Usage +## Executable Usage The inputs to the *TOSA Reference Model* consist of a FlatBuffers file containing the serialized subgraph, a JSON test descriptor that describes @@ -145,7 +152,7 @@ format into output tensors specified by "ofm_file". For example, you can generate new output .npy by: ``` bash ./build/reference_model/tosa_reference_model \ - --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json + --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json \ --ofm_file=out.npy ``` @@ -170,6 +177,22 @@ may cause small differences in output for floating-point tests and differences in quantized scaling between TensorFlow Lite and the TOSA Specification may cause differences in quantized integer tests. +## ModelRunner API + +As an alternative to the executable described above, +the model_runner.h is provided which can be used to invoke the +TOSA Reference Model easily within C++. +A sample of this class and how it can used can be found in +[model_runner_simple_sample.cpp](reference_model/samples/model_runner_simple_sample.cpp). +This sample can be compiled by adding `-DBUILD_MODEL_RUNNER_SAMPLE=YES` to the CMake command +and executed by running `./build/reference_model/model_runner_sample`. + +### ModelRunner API Unit Tests +Unit test are generated by default for the ModelRunner. +This executable can be disabled by adding`-DBUILD_TOSA_REFERENCE_MODEL_TESTS=NO` to the CMake command. +This executable can be run using +`./build/reference_model/unit_tests` and requires the submodule doctest. + ## Debugging The debugging facility can be enabled by setting a debug scope and @@ -219,7 +242,7 @@ programatically compared with output of SUTs to validate them. The test infrastructure needs installing before being used. It is recommended to create a [python virtual environment](https://docs.python.org/3/library/venv.html) -and then install the TOSA Unit Test infrastruture from the root of the +and then install the TOSA Unit Test infrastructure from the root of the reference model: ``` bash diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt index 6fdaa1c..a790968 100644 --- a/reference_model/CMakeLists.txt +++ b/reference_model/CMakeLists.txt @@ -1,6 +1,6 @@ cmake_minimum_required (VERSION 3.4) -# Copyright (c) 2020, ARM Limited. +# Copyright (c) 2020-2022, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,7 +14,6 @@ cmake_minimum_required (VERSION 3.4) # See the License for the specific language governing permissions and # limitations under the License. - project(tosa_reference_model LANGUAGES CXX) set(CMAKE_CXX_STANDARD 17) @@ -26,52 +25,202 @@ else() set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes") endif() -set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/") -set(SERIALIZATION_DIR "../thirdparty/serialization_lib/") - -set (CXX_SOURCE - src/main.cpp - src/tensor.cc - src/graph_node.cc - src/subgraph_traverser.cc - src/func_debug.cc - src/func_config.cc - src/ops/op_factory.cc - src/ops/tensor_ops.cc - src/ops/activation_funcs.cc - src/ops/ewise_binary.cc - src/ops/ewise_unary.cc - src/ops/ewise_ternary.cc - src/ops/comparison.cc - src/ops/reduction.cc - src/ops/data_layout.cc - src/ops/scatter_gather.cc - src/ops/image.cc - src/ops/type_conversion.cc - src/ops/data_nodes.cc - src/ops/custom.cc - src/ops/control_flow.cc +# If Serialization Library path is specified, look for library so it doesn't have to be built again. +# Otherwise, set the Serialization Library related paths to thirdparty directory. +if(SERIALIZATION_DIR) + find_library(SERIALIZATION_LIB + NAMES libtosa_serialization_lib.a tosa_serialization_lib + NO_DEFAULT_PATH + HINTS ${SERIALIZATION_DIR} + PATH_SUFFIXES lib) + + if(NOT SERIALIZATION_LIB) + message(FATAL_ERROR "TOSA Serialization Library location was specified but not found at: ${SERIALIZATION_LIB_DIR}") + endif() +else() + # Build from third party directory if not found. + set(SERIALIZATION_LIB tosa_serialization_lib) + set(SERIALIZATION_DIR "../thirdparty/serialization_lib/") +endif() + +# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory. +if(NOT FLATBUFFERS_DIR) + set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/") +endif() + +if(NOT EIGEN_DIR) + set(EIGEN_DIR "../thirdparty/eigen/") +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include) + +# Common sources required for TOSA Reference Model library, executable and unit tests +set(CXX_SOURCE + src/model_runner.cc + src/model_runner_impl.cc + src/tensor.cc + src/graph_node.cc + src/subgraph_traverser.cc + src/func_debug.cc + src/ops/op_factory.cc + src/ops/tensor_ops.cc + src/ops/activation_funcs.cc + src/ops/ewise_binary.cc + src/ops/ewise_unary.cc + src/ops/ewise_ternary.cc + src/ops/comparison.cc + src/ops/reduction.cc + src/ops/data_layout.cc + src/ops/scatter_gather.cc + src/ops/image.cc + src/ops/type_conversion.cc + src/ops/data_nodes.cc + src/ops/custom.cc + src/ops/control_flow.cc ) -add_executable(tosa_reference_model ${CXX_SOURCE}) +# Build TOSA Reference Model library +add_library(tosa_reference_model_lib ${CXX_SOURCE}) -target_include_directories(tosa_reference_model +target_include_directories(tosa_reference_model_lib PUBLIC $<INSTALL_INTERFACE:include> $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include> PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src ${FLATBUFFERS_DIR}/include - ../thirdparty/eigen/ - ../thirdparty/eigen/unsupported/ + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ ${SERIALIZATION_DIR}/include ) -target_link_libraries(tosa_reference_model +target_link_libraries(tosa_reference_model_lib PRIVATE - tosa_serialization_lib - nlohmann_json::nlohmann_json - cxxopts + ${SERIALIZATION_LIB} +) + +set(PUBLIC_HEADERS) +list(APPEND PUBLIC_HEADERS + include/debug_modes.def + include/debug_types.h + include/func_config.h + include/func_debug.h + include/graph_status.h + include/model_common.h + include/model_runner.h + include/version.h +) + +set_target_properties(tosa_reference_model_lib PROPERTIES PUBLIC_HEADER "${PUBLIC_HEADERS}") + +# Build TOSA Refererence Model executable +if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) + set(CXX_SOURCE_EX src/main.cpp) + list(APPEND CXX_SOURCE_EX ${CXX_SOURCE}) + + add_executable(tosa_reference_model ${CXX_SOURCE_EX}) + + target_include_directories(tosa_reference_model + PUBLIC + $<INSTALL_INTERFACE:include> + $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include> + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ) + + target_link_libraries(tosa_reference_model + PRIVATE + ${SERIALIZATION_LIB} + nlohmann_json::nlohmann_json + cxxopts + ) + + install(TARGETS tosa_reference_model DESTINATION bin) +endif() + +if(BUILD_TOSA_REFERENCE_MODEL_TESTS) + # Set definition so unit tests can find examples directory. + add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\") + + # Set doctest location if not specified. + if(NOT DOCTEST_DIR) + set(DOCTEST_DIR "../thirdparty/doctest/doctest") + endif() + + # Sources only required for unit tests. + set(CXX_SOURCE_TESTS + test/model_runner_tests.cpp + ${DOCTEST_DIR}/doctest.h + ) + + list(APPEND CXX_SOURCE_TESTS ${CXX_SOURCE}) + + add_executable(unit_tests ${CXX_SOURCE_TESTS}) + + target_include_directories(unit_tests + PUBLIC + $<INSTALL_INTERFACE:include> + $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include> + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ${DOCTEST_DIR} + ) + + target_link_libraries(unit_tests + PRIVATE + ${SERIALIZATION_LIB} + ) +endif() + +if(BUILD_MODEL_RUNNER_SAMPLE) + # Set definition so sample executable can find examples directory. + add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\") + + # Sources only required for example executable. + set(CXX_SOURCE_SAMPLE + samples/model_runner_simple_sample.cpp + ) + + list(APPEND CXX_SOURCE_SAMPLE ${CXX_SOURCE}) + + add_executable(model_runner_sample ${CXX_SOURCE_SAMPLE}) + + target_include_directories(model_runner_sample + PUBLIC + $<INSTALL_INTERFACE:include> + $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include> + PRIVATE + ${CMAKE_CURRENT_SOURCE_DIR}/src + ${FLATBUFFERS_DIR}/include + ${EIGEN_DIR} + ${EIGEN_DIR}/unsupported/ + ${SERIALIZATION_DIR}/include + ) + + target_link_libraries(model_runner_sample + PRIVATE + ${SERIALIZATION_LIB} + ) +endif() + +# Follow GNU packaging norms for installation directory structure. +include(GNUInstallDirs) +install( + TARGETS tosa_reference_model_lib EXPORT TosaReferenceModelLibTargets + PUBLIC_HEADER + ARCHIVE ) -install (TARGETS tosa_reference_model DESTINATION bin) +install(EXPORT TosaReferenceModelLibTargets + FILE TosaReferenceModelLibTargets.cmake + NAMESPACE TosaReference:: + DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/tosa_reference_model_lib" +)
\ No newline at end of file diff --git a/reference_model/src/debug_modes.def b/reference_model/include/debug_modes.def index 51b151d..51b151d 100644 --- a/reference_model/src/debug_modes.def +++ b/reference_model/include/debug_modes.def diff --git a/reference_model/src/debug_types.h b/reference_model/include/debug_types.h index bd93f19..bd93f19 100644 --- a/reference_model/src/debug_types.h +++ b/reference_model/include/debug_types.h diff --git a/reference_model/src/func_config.h b/reference_model/include/func_config.h index 49f03e9..41df135 100644 --- a/reference_model/src/func_config.h +++ b/reference_model/include/func_config.h @@ -16,6 +16,9 @@ #ifndef FUNC_CONFIG_H_ #define FUNC_CONFIG_H_ +#include <iostream> +#include <stdio.h> + struct func_config_t { std::string operator_fbs = "tosa.fbs"; @@ -35,12 +38,4 @@ struct func_config_t std::string fp_format = "0.5"; }; -// Forward declaration -struct func_debug_t; - -int func_model_parse_cmd_line( - func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version); -int func_model_parse_flat_config_file(func_config_t*, const char* filename); -void func_model_print_help(); - #endif diff --git a/reference_model/src/func_debug.h b/reference_model/include/func_debug.h index ee89935..d762026 100644 --- a/reference_model/src/func_debug.h +++ b/reference_model/include/func_debug.h @@ -21,6 +21,7 @@ #include <cinttypes> #include <signal.h> #include <stdio.h> +#include <string> #include <vector> void func_print_backtrace(FILE* out, int sig = SIGABRT); diff --git a/reference_model/include/graph_status.h b/reference_model/include/graph_status.h new file mode 100644 index 0000000..f3be004 --- /dev/null +++ b/reference_model/include/graph_status.h @@ -0,0 +1,25 @@ +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GRAPH_STATUS_H +#define GRAPH_STATUS_H + +enum class GraphStatus : int +{ + TOSA_VALID = 0, + TOSA_UNPREDICTABLE = 1, + TOSA_ERROR = 2, +}; + +#endif // GRAPH_STATUS_H diff --git a/reference_model/src/model_common.h b/reference_model/include/model_common.h index d6dab6d..d6dab6d 100644 --- a/reference_model/src/model_common.h +++ b/reference_model/include/model_common.h diff --git a/reference_model/include/model_runner.h b/reference_model/include/model_runner.h new file mode 100644 index 0000000..4629467 --- /dev/null +++ b/reference_model/include/model_runner.h @@ -0,0 +1,87 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_RUNNER_H_ +#define MODEL_RUNNER_H_ + +#include "model_common.h" +#include "graph_status.h" + +#include "tosa_serialization_handler.h" + +namespace TosaReference +{ + +class ModelRunnerImpl; + +/* + * This interface allows a user to initialize, run and get the output from a model. + * See model_runner_simple_sample.cpp for example on how this interface can be used. + */ +class IModelRunner +{ +public: + IModelRunner(); + IModelRunner(const func_config_t& func_config, const func_debug_t& func_debug); + + ~IModelRunner(); + + /* + * Functional and debug configurations can also be set. + * See func_config.h and func_debug.h for possible options. + */ + void setFuncConfig(func_config_t& func_config); + void setFuncDebug(func_debug_t& func_debug); + + /* + * Initialize the model. + * The TosaSerializationHandler is validated and then converted to a SubgraphTraverser internally. + * This SubgraphTraverser is initialized, allocated and validated. + * The status of the graph will be returned upon completion. + */ + GraphStatus initialize(tosa::TosaSerializationHandler& serialization_handler); + + /* + * Run the model using the internal SubgraphTraverser created during initialization. + * If validate_only is specified run() will simply return the graph status. + * Otherwise, the graph will be run and the output tensors will be generated if the graph is valid. + * The output tensors can then be retrieved with getOutput(). + * NOTE: initialize() must be called before run(). Also, setInput() must be called for all inputs in the model. + */ + GraphStatus run(); + + /* + * Set the input tensors for the model. + * The input_name much match the input tensor name in the model. + * NOTE: setInput() must be called for each input tensor before run() is called. + */ + template <typename T> + int setInput(std::string input_name, std::vector<T> vals); + + /* + * Retrieve the output tensors from the graph after running. + * The output_name much match the output tensor name in the model. + * NOTE: run() must be called before outputs are retrieved. + */ + template <typename T> + std::vector<T> getOutput(std::string output_name); + +private: + std::unique_ptr<ModelRunnerImpl> model_runner_impl; +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/include/version.h b/reference_model/include/version.h new file mode 100644 index 0000000..cd1d598 --- /dev/null +++ b/reference_model/include/version.h @@ -0,0 +1,23 @@ +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef VERSION_H +#define VERSION_H + +#define TOSA_REFERENCE_MODEL_VERSION_MAJOR 0 +#define TOSA_REFERENCE_MODEL_VERSION_MINOR 41 +#define TOSA_REFERENCE_MODEL_VERSION_PATCH 0 +#define TOSA_REFERENCE_MODEL_VERSION_DRAFT true + +#endif //VERSION_H diff --git a/reference_model/samples/model_runner_simple_sample.cpp b/reference_model/samples/model_runner_simple_sample.cpp new file mode 100644 index 0000000..2eebca6 --- /dev/null +++ b/reference_model/samples/model_runner_simple_sample.cpp @@ -0,0 +1,97 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "general_utils.h" +#include "model_runner.h" + +int main() +{ + using namespace TosaReference; + + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa"); + std::string input0_file(test_root + "placeholder_0.npy"); + std::string input1_file(test_root + "placeholder_1.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::vector<std::string> input_names = { "TosaInput_0", "TosaInput_1" }; + std::string output_name = "TosaOutput_0"; + + std::vector<int32_t> input0_shape = { 1, 4, 4, 1 }; + std::vector<int32_t> input1_shape = { 1, 4, 4, 4 }; + std::vector<int32_t> output_shape = { 1, 4, 4, 4 }; + + std::vector<std::vector<float>> inputs(input_names.size()); + std::vector<float> expected_outputs = { }; + std::vector<float> actual_outputs = { }; + + // Read in inputs and expected outputs for sample purposes. + inputs[0] = readFromNpyFile<float>(input0_file.c_str(), input0_shape); + inputs[1] = readFromNpyFile<float>(input1_file.c_str(), input1_shape); + expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape); + + tosa::TosaSerializationHandler handler; + tosa::tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + if(error != tosa::TOSA_OK) + { + WARNING("An error occurred while loading the model from file."); + return 1; + } + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + if(status != GraphStatus::TOSA_VALID) + { + WARNING("An error occurred while initializing."); + return 1; + } + + // Set the model inputs using the input names and input data. + runner.setInput(input_names[0], inputs[0]); + runner.setInput(input_names[1], inputs[1]); + + // Run the ModelRunner using test inputs. + status = runner.run(); + if(status != GraphStatus::TOSA_VALID) + { + WARNING("An error occurred when running the model."); + return 1; + } + + // Get the outputs from the model. + actual_outputs = runner.getOutput<float>(output_name); + + // Compare the actual output to the expected output. + bool if_accurate = true; + for (size_t i = 0; i < expected_outputs.size(); ++i) + { + if(actual_outputs[i] != expected_outputs[i]) + { + WARNING("Actual output (%f) doesn't match expected output (%f)."); + if_accurate = false; + } + } + + if(!if_accurate) + { + WARNING("There were mismatches in actual vs expected output, see above output for more details."); + return 1; + } + + printf("The model ran successfully without errors and matched the expected output.\n"); + return 0; +}
\ No newline at end of file diff --git a/reference_model/src/func_config.cc b/reference_model/src/command_line_utils.h index 6bd809e..1bd1639 100644 --- a/reference_model/src/func_config.cc +++ b/reference_model/src/command_line_utils.h @@ -13,19 +13,15 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <cxxopts.hpp> -#include <ctype.h> -#include <signal.h> -#include <stdarg.h> -#include <stdint.h> -#include <stdio.h> -#include <stdlib.h> -#include <string.h> -#include <sys/types.h> +#ifndef COMMAND_LINE_UTILS_H_ +#define COMMAND_LINE_UTILS_H_ #include "func_config.h" #include "func_debug.h" +#include <stdint.h> +#include <cxxopts.hpp> + // Read the command line arguments int func_model_parse_cmd_line( func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version) @@ -102,7 +98,4 @@ int func_model_parse_cmd_line( return 0; } -void func_model_print_help() -{ - -} +#endif diff --git a/reference_model/src/general_utils.h b/reference_model/src/general_utils.h new file mode 100644 index 0000000..12f831e --- /dev/null +++ b/reference_model/src/general_utils.h @@ -0,0 +1,68 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef GENERAL_UTILS_H_ +#define GENERAL_UTILS_H_ + +#include "func_debug.h" + +#include "numpy_utils.h" + +namespace TosaReference +{ + +const uint32_t getElementCount(std::vector<int32_t>& shape) +{ + uint32_t elements = 1; + for (size_t i = 0; i < shape.size(); i++) + { + elements *= shape[i]; + } + + return elements; +} + +template <typename T> +std::vector<T> readFromNpyFile(const char* filename, std::vector<int32_t>& shape) +{ + uint32_t elements = getElementCount(shape); + std::vector<T> data(elements, 0); + + NumpyUtilities::NPError nperror = NumpyUtilities::readFromNpyFile(filename, elements, data.data()); + + switch (nperror) + { + case NumpyUtilities::NO_ERROR: + break; + case NumpyUtilities::FILE_NOT_FOUND: + FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename); + case NumpyUtilities::FILE_IO_ERROR: + FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename); + case NumpyUtilities::FILE_TYPE_MISMATCH: + FATAL_ERROR("readFromNpyFile: Tensor type and Numpy file type mismatch for filename %s", filename); + case NumpyUtilities::HEADER_PARSE_ERROR: + FATAL_ERROR("Numpy header parsing error for file: %s", filename); + case NumpyUtilities::BUFFER_SIZE_MISMATCH: + FATAL_ERROR("Buffer size does not match numpy file size for filename %s", filename); + default: + FATAL_ERROR("Unknown error parsing Numpy file: %s", filename); + } + + return data; +} + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp index 643a351..776fbf3 100644 --- a/reference_model/src/main.cpp +++ b/reference_model/src/main.cpp @@ -13,31 +13,24 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include <stdio.h> +#include "model_runner.h" +#include "version.h" -#include "model_common.h" +#include "command_line_utils.h" #include "ops/op_factory.h" #include "subgraph_traverser.h" #include "tosa_serialization_handler.h" -#include <Eigen/CXX11/Tensor> -#include <iostream> #include <fstream> +#include <iostream> +#include <stdio.h> +#include <Eigen/CXX11/Tensor> #include <nlohmann/json.hpp> -#define MODEL_VERSION_MAJOR 0 -#define MODEL_VERSION_MINOR 41 -#define MODEL_VERSION_PATCH 0 -#define MODEL_VERSION_DRAFT true - using namespace TosaReference; using namespace tosa; using json = nlohmann::json; -// Global instantiation of configuration and debug objects -func_config_t g_func_config; -func_debug_t g_func_debug; - int initTestDesc(json& test_desc); int readInputTensors(SubgraphTraverser& gt, json test_desc); int writeFinalTensors(SubgraphTraverser& gt, json test_desc); @@ -45,7 +38,10 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc); int main(int argc, char** argv) { - TosaVersion model_version(MODEL_VERSION_MAJOR, MODEL_VERSION_MINOR, MODEL_VERSION_PATCH, MODEL_VERSION_DRAFT); + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); // Initialize configuration and debug subsystems g_func_debug.init_debug(0); @@ -203,7 +199,6 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc) if (strlen(graph_fullname) <= 2) { - func_model_print_help(); FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc="); } diff --git a/reference_model/src/model_runner.cc b/reference_model/src/model_runner.cc new file mode 100644 index 0000000..2395a85 --- /dev/null +++ b/reference_model/src/model_runner.cc @@ -0,0 +1,76 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +// Global instantiation of configuration and debug objects +func_config_t g_func_config; +func_debug_t g_func_debug; + +IModelRunner::IModelRunner() : model_runner_impl(new ModelRunnerImpl()) +{} + +IModelRunner::IModelRunner(const func_config_t& func_config, + const func_debug_t& func_debug) + : model_runner_impl(new ModelRunnerImpl(func_config, func_debug)) +{} + +IModelRunner::~IModelRunner() +{} + +void IModelRunner::setFuncConfig(func_config_t& func_config) +{ + model_runner_impl->setFuncConfig(func_config); +} + +void IModelRunner::setFuncDebug(func_debug_t& func_debug) +{ + model_runner_impl->setFuncDebug(func_debug); +} + +GraphStatus IModelRunner::initialize(tosa::TosaSerializationHandler& serialization_handler) +{ + return model_runner_impl->initialize(serialization_handler); +} + +GraphStatus IModelRunner::run() +{ + return model_runner_impl->run(); +} + +template <typename T> +int IModelRunner::setInput(std::string input_name, std::vector<T> vals) +{ + return model_runner_impl->setInput<T>(input_name, vals); +} + +template <typename T> +std::vector<T> IModelRunner::getOutput(std::string output_name) +{ + return model_runner_impl->getOutput<T>(output_name); +} + +// Template explicit specialization +template int IModelRunner::setInput<float>(std::string input_name, std::vector<float> vals); +template int IModelRunner::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals); +template int IModelRunner::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals); +template int IModelRunner::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals); + +template std::vector<float> IModelRunner::getOutput<float>(std::string output_name); +template std::vector<int32_t> IModelRunner::getOutput<int32_t>(std::string output_name); +template std::vector<int64_t> IModelRunner::getOutput<int64_t>(std::string output_name); +template std::vector<unsigned char> IModelRunner::getOutput<unsigned char>(std::string output_name);
\ No newline at end of file diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc new file mode 100644 index 0000000..e0fdc49 --- /dev/null +++ b/reference_model/src/model_runner_impl.cc @@ -0,0 +1,277 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "model_runner_impl.h" + +using namespace TosaReference; + +ModelRunnerImpl::ModelRunnerImpl() +{} + +ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config, + const func_debug_t& func_debug) +{ + g_func_config = func_config; + g_func_debug = func_debug; +} + +ModelRunnerImpl::~ModelRunnerImpl() +{ + g_func_debug.fini_debug(); + delete _main_gt; +}; + +void ModelRunnerImpl::setFuncConfig(func_config_t& func_config) +{ + g_func_config = func_config; +} +void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug) +{ + g_func_debug = func_debug; +} + +GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler) +{ + validateTosaVersion(serialization_handler); + + // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph. + _main_gt = nullptr; + _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler); + + if(_main_gt == nullptr) + { + WARNING("An error occurred when generating main graph traverser."); + return GraphStatus::TOSA_ERROR; + } + + if (_main_gt->initializeGraph()) + { + WARNING("Unable to initialize main graph traverser."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->linkTensorsAndNodes()) + { + WARNING("Failed to link tensors and nodes"); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->validateGraph()) + { + WARNING("Failed to validate graph."); + return _main_gt->getGraphStatus(); + } + + if (_main_gt->allocateTensor()) + { + WARNING("Failed to allocate tensor."); + return _main_gt->getGraphStatus(); + } + + return _main_gt->getGraphStatus(); +} + +GraphStatus ModelRunnerImpl::run() +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()"); + } + + if (g_func_config.validate_only) + { + goto done; + } + + // Validate the number of inputs matches the + if (static_cast<uint32_t>(_main_gt->getNumInputTensors()) != n_input_tensors) + { + FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). " + "setInput() must be called for each input.", + n_input_tensors, _main_gt->getNumInputTensors()); + } + + if (g_func_config.eval) + { + // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier. + if (_main_gt->evaluateAll()) + { + ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID, + "Upon evaluateAll() returning 1, graph can not be VALID."); + } + else + { + ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID || + _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE, + "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE."); + } + + // Only generate output tensor if graph is valid. + if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID) + { + // Make sure output tensor is evaluated and show its value + int num_output_tensors = _main_gt->getNumOutputTensors(); + bool all_output_valid = true; + for (int i = 0; i < num_output_tensors; i++) + { + const Tensor* ct = _main_gt->getOutputTensor(i); + ASSERT_MEM(ct); + if (!ct->getIsValid()) + { + ct->dumpTensorParams(g_func_debug.func_debug_file); + if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT)) + { + ct->dumpTensor(g_func_debug.func_debug_file); + } + all_output_valid = false; + } + } + if (!all_output_valid) + { + _main_gt->dumpGraph(g_func_debug.func_debug_file); + FATAL_ERROR( + "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation."); + } + } + } + +done: + // Print status if not valid and do cleanup. + checkGraphStatus(*_main_gt); + g_func_debug.fini_debug(); + + return _main_gt->getGraphStatus(); +} + +template <typename T> +int ModelRunnerImpl::setInput(std::string input_name, std::vector<T> vals) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()"); + } + + Tensor* tensor; + tensor = _main_gt->getInputTensorByName(input_name); + + if (!tensor) + { + WARNING("Unable to find input tensor %s", input_name.c_str()); + return 1; + } + + if (!tensor->is_allocated()) + { + WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str()); + return 1; + } + + if (tensor->readfromVector(vals)) + { + WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str()); + return 1; + } + + // Push ready consumers to the next node list + for (auto gn : tensor->getConsumers()) + { + if (gn->hasAllInputsReady() && !gn->getOnNextNodeList()) + { + _main_gt->addToNextNodeList(gn); + } + } + + n_input_tensors++; + return 0; +} + +template <typename T> +std::vector<T> ModelRunnerImpl::getOutput(std::string output_name) +{ + if (_main_gt == nullptr) + { + FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()"); + } + + Tensor* tensor; + tensor = _main_gt->getOutputTensorByName(output_name); + + if (!tensor) + { + WARNING("Unable to find output tensor %s", output_name.c_str()); + return std::vector<T>(); + } + + std::vector<T> outputs(tensor->getElementCount(), 0); + + if (tensor->writeToVector(outputs)) + { + WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str()); + return std::vector<T>(); + } + + return outputs; +} + +void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler) +{ + TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR, + TOSA_REFERENCE_MODEL_VERSION_MINOR, + TOSA_REFERENCE_MODEL_VERSION_PATCH, + TOSA_REFERENCE_MODEL_VERSION_DRAFT); + + TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion()); + switch (is_compat) + { + case TosaVersion::compat_t::COMPLETELY_COMPATIBLE: + break; + case TosaVersion::compat_t::PARTIALLY_COMPATIBLE: + WARNING("Reference model version %s is partially compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + break; + case TosaVersion::compat_t::NOT_COMPATIBLE: + FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.", + model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str()); + } +} + +void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt) +{ + switch (main_gt.getGraphStatus()) + { + case GraphStatus::TOSA_VALID: + // Result is valid. + break; + case GraphStatus::TOSA_UNPREDICTABLE: + WARNING("Graph result: UNPREDICTABLE."); + break; + case GraphStatus::TOSA_ERROR: + WARNING("Graph result: ERROR."); + break; + default: + WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus()); + } +} + +// Template explicit specialization +template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float> vals); +template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals); +template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals); +template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals); + +template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name); +template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name); +template std::vector<int64_t> ModelRunnerImpl::getOutput<int64_t>(std::string output_name); +template std::vector<unsigned char> ModelRunnerImpl::getOutput<unsigned char>(std::string output_name);
\ No newline at end of file diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h new file mode 100644 index 0000000..7a91bfe --- /dev/null +++ b/reference_model/src/model_runner_impl.h @@ -0,0 +1,66 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef MODEL_RUNNER_IMPL_H_ +#define MODEL_RUNNER_IMPL_H_ + +#include "model_runner.h" +#include "graph_status.h" +#include "version.h" + +#include "ops/op_factory.h" +#include "subgraph_traverser.h" +#include "tosa_serialization_handler.h" + +namespace TosaReference +{ + +/* + * This is a private implementation of the IModelRunner class. + * See documented IModelRunner for usage. + */ +class ModelRunnerImpl +{ +public: + ModelRunnerImpl(); + ModelRunnerImpl(const func_config_t& func_config, const func_debug_t& func_debug); + + ~ModelRunnerImpl(); + + void setFuncConfig(func_config_t& func_config); + void setFuncDebug(func_debug_t& func_debug); + + GraphStatus initialize(TosaSerializationHandler& serialization_handler); + GraphStatus run(); + + template <typename T> + int setInput(std::string input_name, std::vector<T> vals); + + template <typename T> + std::vector<T> getOutput(std::string output_name); + +private: + SubgraphTraverser* _main_gt = nullptr; + + // Used to determine if all input tensors have been set correctly. + uint32_t n_input_tensors = 0; + + void validateTosaVersion(TosaSerializationHandler& serialization_handler); + void checkGraphStatus(SubgraphTraverser& main_gt); +}; + +}; // namespace TosaReference + +#endif diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h index 8c66d73..7940ee4 100644 --- a/reference_model/src/subgraph_traverser.h +++ b/reference_model/src/subgraph_traverser.h @@ -16,6 +16,7 @@ #ifndef SUBGRAPH_TRAVERSER_H #define SUBGRAPH_TRAVERSER_H +#include "graph_status.h" #include "graph_node.h" #include "model_common.h" #include "ops/op_factory.h" @@ -26,13 +27,6 @@ namespace TosaReference { -enum class GraphStatus : int -{ - TOSA_VALID = 0, - TOSA_UNPREDICTABLE = 1, - TOSA_ERROR = 2, -}; - class SubgraphTraverser { public: diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc index 90aee05..7cbeb13 100644 --- a/reference_model/src/tensor.cc +++ b/reference_model/src/tensor.cc @@ -382,6 +382,209 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool) #undef DEF_CTENSOR_COPY_VALUE_FROM +int TosaReference::Tensor::readfromVector(const std::vector<float>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals) +{ + uint32_t elements = getElementCount(); + switch (getDtype()) + { + case DType_INT48: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The input type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_BOOL: + if (vals.size() != elements) + { + WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + setTensorValueBool(elements, reinterpret_cast<const bool*>(vals.data())); + break; + default: + WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + setIsValid(); + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<float>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_FLOAT: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueFloat(elements, vals.data()); + break; + default: + WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case DType_INT32: + case DType_UINT8: + case DType_INT4: + case DType_INT8: + case DType_INT16: + case DType_UINT16: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt32(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_INT48: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueInt64(elements, vals.data()); + break; + default: + WARNING("The output type doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + +int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals) +{ + uint32_t elements = getElementCount(); + + switch (getDtype()) + { + case tosa::DType_BOOL: + if (vals.size() != elements) + { + WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.", + vals.size(), elements); + return -1; + } + + getTensorValueBool(elements, reinterpret_cast<bool*>(vals.data())); + break; + default: + WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).", + EnumNameDType(getDtype())); + return -2; + } + return 0; +} + template <class T> int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals) { diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h index ede42a9..6b7d5f1 100644 --- a/reference_model/src/tensor.h +++ b/reference_model/src/tensor.h @@ -228,6 +228,16 @@ public: virtual int writeToNpyFile(const char* filename) const; virtual int copyValueFrom(Tensor* tensor) = 0; + virtual int readfromVector(const std::vector<float>& vals); + virtual int readfromVector(const std::vector<int32_t>& vals); + virtual int readfromVector(const std::vector<int64_t>& vals); + virtual int readfromVector(const std::vector<unsigned char>& vals); + + virtual int writeToVector(std::vector<float>& vals); + virtual int writeToVector(std::vector<int32_t>& vals); + virtual int writeToVector(std::vector<int64_t>& vals); + virtual int writeToVector(std::vector<unsigned char>& vals); + const char* bool_to_str(bool in) const { static const char* true_str = "true"; diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp new file mode 100644 index 0000000..cc295d9 --- /dev/null +++ b/reference_model/test/model_runner_tests.cpp @@ -0,0 +1,154 @@ + +// Copyright (c) 2022, ARM Limited. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN +#endif + +#include "model_runner.h" +#include "general_utils.h" + +#include "doctest.h" + +using namespace TosaReference; +using namespace tosa; + +template <typename T> +void compareOutput(std::vector<T>& tensor1, std::vector<T>& tensor2, size_t size) +{ + for (size_t i = 0; i < size; ++i) + { + CHECK((tensor1[i] == tensor2[i])); + } +} + +TEST_SUITE("model_runner") +{ + +TEST_CASE("simple_add_f32_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa"); + std::string input0_file(test_root + "placeholder_0.npy"); + std::string input1_file(test_root + "placeholder_1.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::vector<std::string> input_names = { "TosaInput_0", "TosaInput_1" }; + std::string output_name = "TosaOutput_0"; + + std::vector<int32_t> input0_shape = { 1, 4, 4, 1 }; + std::vector<int32_t> input1_shape = { 1, 4, 4, 4 }; + std::vector<int32_t> output_shape = { 1, 4, 4, 4 }; + + std::vector<std::vector<float>> inputs(input_names.size()); + std::vector<float> actual_outputs = { }; + std::vector<float> expected_outputs = { }; + + // Read in inputs and expected outputs. + inputs[0] = readFromNpyFile<float>(input0_file.c_str(), input0_shape); + inputs[1] = readFromNpyFile<float>(input1_file.c_str(), input1_shape); + expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + runner.setInput(input_names[0], inputs[0]); + runner.setInput(input_names[1], inputs[1]); + + // Run the ModelRunner using test inputs. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); + + actual_outputs = runner.getOutput<float>(output_name); + CHECK(!actual_outputs.empty()); + + compareOutput(expected_outputs, actual_outputs, expected_outputs.size()); +} + +TEST_CASE("conv2d_f32_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa"); + std::string input_file(test_root + "placeholder_0.npy"); + std::string expected_output_file(test_root + "tflite_result.npy"); + + std::string input_name = "TosaInput_0"; + std::string output_name = "TosaOutput_0"; + + std::vector<int32_t> input_shape = { 1, 32, 32, 8 }; + std::vector<int32_t> output_shape = { 1, 32, 32, 16 }; + + // Read in inputs and expected outputs. + std::vector<float> inputs = readFromNpyFile<float>(input_file.c_str(), input_shape); + std::vector<float> expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + + // Initialize the ModelRunner with configurations. + IModelRunner runner; + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + runner.setInput(input_name, inputs); + + // Run the ModelRunner using test inputs. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); + + std::vector<float> actual_outputs = runner.getOutput<float>(output_name); + CHECK(!actual_outputs.empty()); + + compareOutput(expected_outputs, actual_outputs, expected_outputs.size()); +} + +TEST_CASE("conv2d_f32_validate_only_test") +{ + std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/"); + std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa"); + + TosaSerializationHandler handler; + tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str()); + CHECK((error == tosa::TOSA_OK)); + + GraphStatus status; + func_debug_t funcDebug; + + func_config_t funcConfig; + funcConfig.validate_only = 1; + + // Initialize the ModelRunner with configurations. + IModelRunner runner = IModelRunner(funcConfig, funcDebug); + runner.setFuncConfig(funcConfig); + status = runner.initialize(handler); + CHECK((status == GraphStatus::TOSA_VALID)); + + // Run the ModelRunner using no inputs, as validate_only is specified run() should still work. + status = runner.run(); + CHECK((status == GraphStatus::TOSA_VALID)); +} + +} diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt index abcc52c..aa6c43d 100644 --- a/thirdparty/CMakeLists.txt +++ b/thirdparty/CMakeLists.txt @@ -12,6 +12,13 @@ set(CMAKE_INSTALL_PREFIX "./thirdparty" CACHE PATH "..." FORCE) project(thirdparty LANGUAGES CXX) -add_subdirectory(cxxopts) add_subdirectory(serialization_lib EXCLUDE_FROM_ALL) -add_subdirectory(json EXCLUDE_FROM_ALL) + +if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE) + add_subdirectory(cxxopts) + add_subdirectory(json EXCLUDE_FROM_ALL) +endif() + +if(BUILD_TOSA_REFERENCE_MODEL_TESTS) + add_subdirectory(doctest EXCLUDE_FROM_ALL) +endif() diff --git a/thirdparty/doctest b/thirdparty/doctest new file mode 160000 +Subproject 86892fc480f80fb57d9a3926cb506c0e974489d |