aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMatthew Sloyan <matthew.sloyan@arm.com>2022-09-26 13:31:43 +0100
committerMatthew Sloyan <matthew.sloyan@arm.com>2022-10-07 16:02:07 +0100
commitba5fad356a926d5e1c6e0fe6b546a310230cc5a8 (patch)
tree10e3e127c091da90d591253fd55e8566c0b61e7a
parenta0848c6edbf37034e280a670bdd2f990fdf796da (diff)
downloadreference_model-ba5fad356a926d5e1c6e0fe6b546a310230cc5a8.tar.gz
Add IModelRunner interface to TOSA Reference Model
* Added IModelRunner interface using pimpl idiom, which allows a user to initialize, configure and run the model. * Added unit tests for IModelRunner. * Added doctest as third-party submodule. * Added user options to specify paths for dependencies. * Moved general func_config functions to separate utility, which removes cxxopts dependency. Signed-off-by: Matthew Sloyan <matthew.sloyan@arm.com> Change-Id: If42f1f82cd6dadf18911a48dcd5fa579b719aff2
-rw-r--r--.gitmodules3
-rw-r--r--CMakeLists.txt11
-rw-r--r--README.md35
-rw-r--r--reference_model/CMakeLists.txt221
-rw-r--r--reference_model/include/debug_modes.def (renamed from reference_model/src/debug_modes.def)0
-rw-r--r--reference_model/include/debug_types.h (renamed from reference_model/src/debug_types.h)0
-rw-r--r--reference_model/include/func_config.h (renamed from reference_model/src/func_config.h)11
-rw-r--r--reference_model/include/func_debug.h (renamed from reference_model/src/func_debug.h)1
-rw-r--r--reference_model/include/graph_status.h25
-rw-r--r--reference_model/include/model_common.h (renamed from reference_model/src/model_common.h)0
-rw-r--r--reference_model/include/model_runner.h87
-rw-r--r--reference_model/include/version.h23
-rw-r--r--reference_model/samples/model_runner_simple_sample.cpp97
-rw-r--r--reference_model/src/command_line_utils.h (renamed from reference_model/src/func_config.cc)19
-rw-r--r--reference_model/src/general_utils.h68
-rw-r--r--reference_model/src/main.cpp25
-rw-r--r--reference_model/src/model_runner.cc76
-rw-r--r--reference_model/src/model_runner_impl.cc277
-rw-r--r--reference_model/src/model_runner_impl.h66
-rw-r--r--reference_model/src/subgraph_traverser.h8
-rw-r--r--reference_model/src/tensor.cc203
-rw-r--r--reference_model/src/tensor.h10
-rw-r--r--reference_model/test/model_runner_tests.cpp154
-rw-r--r--thirdparty/CMakeLists.txt11
m---------thirdparty/doctest0
25 files changed, 1343 insertions, 88 deletions
diff --git a/.gitmodules b/.gitmodules
index e06268d..87ce1ef 100644
--- a/.gitmodules
+++ b/.gitmodules
@@ -7,3 +7,6 @@
[submodule "thirdparty/json"]
path = thirdparty/json
url = https://github.com/nlohmann/json.git
+[submodule "thirdparty/doctest"]
+ path = thirdparty/doctest
+ url = https://github.com/doctest/doctest.git
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 04141aa..d3281c6 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3,7 +3,16 @@ cmake_minimum_required (VERSION 3.4)
set(CMAKE_INSTALL_PREFIX ".")
project(tosa_tools LANGUAGES CXX)
-option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of Tosa Reference Model" ON)
+option(TOSA_TOOLS_BUILD_REFERENCE_MODEL "Enable building of TOSA Reference Model" ON)
+option(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE "Enable building of TOSA Reference Model executable" ON)
+option(BUILD_TOSA_REFERENCE_MODEL_TESTS "Enable building of TOSA Reference Model unit tests" ON)
+option(BUILD_MODEL_RUNNER_SAMPLE "Enable building of ModelRunner sample executable" OFF)
+
+# Custom path options for third party dependencies
+option(SERIALIZATION_DIR "Location where the TOSA Serialization Library 'include' folder is found" Off)
+option(FLATBUFFERS_DIR "Location where the FlatBuffers 'include' and 'lib' folders is found" Off)
+option(EIGEN_DIR "Location where the Eigen folder is found" Off)
+option(DOCTEST_DIR "Location where the doctest folder is found (If building unit tests)" Off)
add_subdirectory(thirdparty)
diff --git a/README.md b/README.md
index 88aeaf2..0e97d10 100644
--- a/README.md
+++ b/README.md
@@ -33,6 +33,7 @@ The model includes the following git submodules:
* TOSA Serialization Library
* JSON for Modern C++ - 3.8.0
* Eigen 3.3.7
+* doctest 2.4.9 (When building unit tests)
The model is written using
C++17 and has been primarily tested on Ubuntu x86_64 18.04 LTS Linux
@@ -84,12 +85,18 @@ make
```
The resulting executable will be named:
-`reference_model/tosa_reference_model`. CMake only needs to be re-run
-if the build environment changes (e.g., new dependencies or source
-files). Code changes that do not affect these build rules can be
+`reference_model/tosa_reference_model`. This executable can be disabled with
+`-DBUILD_TOSA_REFERENCE_MODEL_EXECUTABLE=NO`.
+A static library will also be generated by default as follows:
+`reference_model/libtosa_reference_model_lib.a`.
+To make this a shared library (.so) add the following option
+`-DBUILD_SHARED_LIBS=YES`.
+
+CMake only needs to be re-run if the build environment changes (e.g., new dependencies or source
+files). Code changes that do not affect these build rules can be
rebuilt simply using `make`.
-## Usage
+## Executable Usage
The inputs to the *TOSA Reference Model* consist of a FlatBuffers file
containing the serialized subgraph, a JSON test descriptor that describes
@@ -145,7 +152,7 @@ format into output tensors specified by "ofm_file".
For example, you can generate new output .npy by:
``` bash
./build/reference_model/tosa_reference_model \
- --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json
+ --test_desc=examples/test_add_1x4x4x4_f32/flatbuffer-tflite/desc.json \
--ofm_file=out.npy
```
@@ -170,6 +177,22 @@ may cause small differences in output for floating-point tests and
differences in quantized scaling between TensorFlow Lite and the TOSA
Specification may cause differences in quantized integer tests.
+## ModelRunner API
+
+As an alternative to the executable described above,
+the model_runner.h is provided which can be used to invoke the
+TOSA Reference Model easily within C++.
+A sample of this class and how it can used can be found in
+[model_runner_simple_sample.cpp](reference_model/samples/model_runner_simple_sample.cpp).
+This sample can be compiled by adding `-DBUILD_MODEL_RUNNER_SAMPLE=YES` to the CMake command
+and executed by running `./build/reference_model/model_runner_sample`.
+
+### ModelRunner API Unit Tests
+Unit test are generated by default for the ModelRunner.
+This executable can be disabled by adding`-DBUILD_TOSA_REFERENCE_MODEL_TESTS=NO` to the CMake command.
+This executable can be run using
+`./build/reference_model/unit_tests` and requires the submodule doctest.
+
## Debugging
The debugging facility can be enabled by setting a debug scope and
@@ -219,7 +242,7 @@ programatically compared with output of SUTs to validate them.
The test infrastructure needs installing before being used. It is recommended
to create a [python virtual environment](https://docs.python.org/3/library/venv.html)
-and then install the TOSA Unit Test infrastruture from the root of the
+and then install the TOSA Unit Test infrastructure from the root of the
reference model:
``` bash
diff --git a/reference_model/CMakeLists.txt b/reference_model/CMakeLists.txt
index 6fdaa1c..a790968 100644
--- a/reference_model/CMakeLists.txt
+++ b/reference_model/CMakeLists.txt
@@ -1,6 +1,6 @@
cmake_minimum_required (VERSION 3.4)
-# Copyright (c) 2020, ARM Limited.
+# Copyright (c) 2020-2022, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -14,7 +14,6 @@ cmake_minimum_required (VERSION 3.4)
# See the License for the specific language governing permissions and
# limitations under the License.
-
project(tosa_reference_model LANGUAGES CXX)
set(CMAKE_CXX_STANDARD 17)
@@ -26,52 +25,202 @@ else()
set(CMAKE_CXX_FLAGS "-Wall -Wno-ignored-attributes")
endif()
-set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/")
-set(SERIALIZATION_DIR "../thirdparty/serialization_lib/")
-
-set (CXX_SOURCE
- src/main.cpp
- src/tensor.cc
- src/graph_node.cc
- src/subgraph_traverser.cc
- src/func_debug.cc
- src/func_config.cc
- src/ops/op_factory.cc
- src/ops/tensor_ops.cc
- src/ops/activation_funcs.cc
- src/ops/ewise_binary.cc
- src/ops/ewise_unary.cc
- src/ops/ewise_ternary.cc
- src/ops/comparison.cc
- src/ops/reduction.cc
- src/ops/data_layout.cc
- src/ops/scatter_gather.cc
- src/ops/image.cc
- src/ops/type_conversion.cc
- src/ops/data_nodes.cc
- src/ops/custom.cc
- src/ops/control_flow.cc
+# If Serialization Library path is specified, look for library so it doesn't have to be built again.
+# Otherwise, set the Serialization Library related paths to thirdparty directory.
+if(SERIALIZATION_DIR)
+ find_library(SERIALIZATION_LIB
+ NAMES libtosa_serialization_lib.a tosa_serialization_lib
+ NO_DEFAULT_PATH
+ HINTS ${SERIALIZATION_DIR}
+ PATH_SUFFIXES lib)
+
+ if(NOT SERIALIZATION_LIB)
+ message(FATAL_ERROR "TOSA Serialization Library location was specified but not found at: ${SERIALIZATION_LIB_DIR}")
+ endif()
+else()
+ # Build from third party directory if not found.
+ set(SERIALIZATION_LIB tosa_serialization_lib)
+ set(SERIALIZATION_DIR "../thirdparty/serialization_lib/")
+endif()
+
+# If Flatbuffers or Eigen path isn't specified, set to thirdparty directory.
+if(NOT FLATBUFFERS_DIR)
+ set(FLATBUFFERS_DIR "../thirdparty/serialization_lib/third_party/flatbuffers/")
+endif()
+
+if(NOT EIGEN_DIR)
+ set(EIGEN_DIR "../thirdparty/eigen/")
+endif()
+
+include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
+
+# Common sources required for TOSA Reference Model library, executable and unit tests
+set(CXX_SOURCE
+ src/model_runner.cc
+ src/model_runner_impl.cc
+ src/tensor.cc
+ src/graph_node.cc
+ src/subgraph_traverser.cc
+ src/func_debug.cc
+ src/ops/op_factory.cc
+ src/ops/tensor_ops.cc
+ src/ops/activation_funcs.cc
+ src/ops/ewise_binary.cc
+ src/ops/ewise_unary.cc
+ src/ops/ewise_ternary.cc
+ src/ops/comparison.cc
+ src/ops/reduction.cc
+ src/ops/data_layout.cc
+ src/ops/scatter_gather.cc
+ src/ops/image.cc
+ src/ops/type_conversion.cc
+ src/ops/data_nodes.cc
+ src/ops/custom.cc
+ src/ops/control_flow.cc
)
-add_executable(tosa_reference_model ${CXX_SOURCE})
+# Build TOSA Reference Model library
+add_library(tosa_reference_model_lib ${CXX_SOURCE})
-target_include_directories(tosa_reference_model
+target_include_directories(tosa_reference_model_lib
PUBLIC
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include>
PRIVATE
${CMAKE_CURRENT_SOURCE_DIR}/src
${FLATBUFFERS_DIR}/include
- ../thirdparty/eigen/
- ../thirdparty/eigen/unsupported/
+ ${EIGEN_DIR}
+ ${EIGEN_DIR}/unsupported/
${SERIALIZATION_DIR}/include
)
-target_link_libraries(tosa_reference_model
+target_link_libraries(tosa_reference_model_lib
PRIVATE
- tosa_serialization_lib
- nlohmann_json::nlohmann_json
- cxxopts
+ ${SERIALIZATION_LIB}
+)
+
+set(PUBLIC_HEADERS)
+list(APPEND PUBLIC_HEADERS
+ include/debug_modes.def
+ include/debug_types.h
+ include/func_config.h
+ include/func_debug.h
+ include/graph_status.h
+ include/model_common.h
+ include/model_runner.h
+ include/version.h
+)
+
+set_target_properties(tosa_reference_model_lib PROPERTIES PUBLIC_HEADER "${PUBLIC_HEADERS}")
+
+# Build TOSA Refererence Model executable
+if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
+ set(CXX_SOURCE_EX src/main.cpp)
+ list(APPEND CXX_SOURCE_EX ${CXX_SOURCE})
+
+ add_executable(tosa_reference_model ${CXX_SOURCE_EX})
+
+ target_include_directories(tosa_reference_model
+ PUBLIC
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include>
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ ${FLATBUFFERS_DIR}/include
+ ${EIGEN_DIR}
+ ${EIGEN_DIR}/unsupported/
+ ${SERIALIZATION_DIR}/include
+ )
+
+ target_link_libraries(tosa_reference_model
+ PRIVATE
+ ${SERIALIZATION_LIB}
+ nlohmann_json::nlohmann_json
+ cxxopts
+ )
+
+ install(TARGETS tosa_reference_model DESTINATION bin)
+endif()
+
+if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
+ # Set definition so unit tests can find examples directory.
+ add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\")
+
+ # Set doctest location if not specified.
+ if(NOT DOCTEST_DIR)
+ set(DOCTEST_DIR "../thirdparty/doctest/doctest")
+ endif()
+
+ # Sources only required for unit tests.
+ set(CXX_SOURCE_TESTS
+ test/model_runner_tests.cpp
+ ${DOCTEST_DIR}/doctest.h
+ )
+
+ list(APPEND CXX_SOURCE_TESTS ${CXX_SOURCE})
+
+ add_executable(unit_tests ${CXX_SOURCE_TESTS})
+
+ target_include_directories(unit_tests
+ PUBLIC
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include>
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ ${FLATBUFFERS_DIR}/include
+ ${EIGEN_DIR}
+ ${EIGEN_DIR}/unsupported/
+ ${SERIALIZATION_DIR}/include
+ ${DOCTEST_DIR}
+ )
+
+ target_link_libraries(unit_tests
+ PRIVATE
+ ${SERIALIZATION_LIB}
+ )
+endif()
+
+if(BUILD_MODEL_RUNNER_SAMPLE)
+ # Set definition so sample executable can find examples directory.
+ add_definitions(-DPROJECT_ROOT=\"${CMAKE_CURRENT_SOURCE_DIR}/\")
+
+ # Sources only required for example executable.
+ set(CXX_SOURCE_SAMPLE
+ samples/model_runner_simple_sample.cpp
+ )
+
+ list(APPEND CXX_SOURCE_SAMPLE ${CXX_SOURCE})
+
+ add_executable(model_runner_sample ${CXX_SOURCE_SAMPLE})
+
+ target_include_directories(model_runner_sample
+ PUBLIC
+ $<INSTALL_INTERFACE:include>
+ $<BUILD_INTERFACE:${CMAKE_CURRENT_SRC_DIR}/include>
+ PRIVATE
+ ${CMAKE_CURRENT_SOURCE_DIR}/src
+ ${FLATBUFFERS_DIR}/include
+ ${EIGEN_DIR}
+ ${EIGEN_DIR}/unsupported/
+ ${SERIALIZATION_DIR}/include
+ )
+
+ target_link_libraries(model_runner_sample
+ PRIVATE
+ ${SERIALIZATION_LIB}
+ )
+endif()
+
+# Follow GNU packaging norms for installation directory structure.
+include(GNUInstallDirs)
+install(
+ TARGETS tosa_reference_model_lib EXPORT TosaReferenceModelLibTargets
+ PUBLIC_HEADER
+ ARCHIVE
)
-install (TARGETS tosa_reference_model DESTINATION bin)
+install(EXPORT TosaReferenceModelLibTargets
+ FILE TosaReferenceModelLibTargets.cmake
+ NAMESPACE TosaReference::
+ DESTINATION "${CMAKE_INSTALL_LIBDIR}/cmake/tosa_reference_model_lib"
+) \ No newline at end of file
diff --git a/reference_model/src/debug_modes.def b/reference_model/include/debug_modes.def
index 51b151d..51b151d 100644
--- a/reference_model/src/debug_modes.def
+++ b/reference_model/include/debug_modes.def
diff --git a/reference_model/src/debug_types.h b/reference_model/include/debug_types.h
index bd93f19..bd93f19 100644
--- a/reference_model/src/debug_types.h
+++ b/reference_model/include/debug_types.h
diff --git a/reference_model/src/func_config.h b/reference_model/include/func_config.h
index 49f03e9..41df135 100644
--- a/reference_model/src/func_config.h
+++ b/reference_model/include/func_config.h
@@ -16,6 +16,9 @@
#ifndef FUNC_CONFIG_H_
#define FUNC_CONFIG_H_
+#include <iostream>
+#include <stdio.h>
+
struct func_config_t
{
std::string operator_fbs = "tosa.fbs";
@@ -35,12 +38,4 @@ struct func_config_t
std::string fp_format = "0.5";
};
-// Forward declaration
-struct func_debug_t;
-
-int func_model_parse_cmd_line(
- func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version);
-int func_model_parse_flat_config_file(func_config_t*, const char* filename);
-void func_model_print_help();
-
#endif
diff --git a/reference_model/src/func_debug.h b/reference_model/include/func_debug.h
index ee89935..d762026 100644
--- a/reference_model/src/func_debug.h
+++ b/reference_model/include/func_debug.h
@@ -21,6 +21,7 @@
#include <cinttypes>
#include <signal.h>
#include <stdio.h>
+#include <string>
#include <vector>
void func_print_backtrace(FILE* out, int sig = SIGABRT);
diff --git a/reference_model/include/graph_status.h b/reference_model/include/graph_status.h
new file mode 100644
index 0000000..f3be004
--- /dev/null
+++ b/reference_model/include/graph_status.h
@@ -0,0 +1,25 @@
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GRAPH_STATUS_H
+#define GRAPH_STATUS_H
+
+enum class GraphStatus : int
+{
+ TOSA_VALID = 0,
+ TOSA_UNPREDICTABLE = 1,
+ TOSA_ERROR = 2,
+};
+
+#endif // GRAPH_STATUS_H
diff --git a/reference_model/src/model_common.h b/reference_model/include/model_common.h
index d6dab6d..d6dab6d 100644
--- a/reference_model/src/model_common.h
+++ b/reference_model/include/model_common.h
diff --git a/reference_model/include/model_runner.h b/reference_model/include/model_runner.h
new file mode 100644
index 0000000..4629467
--- /dev/null
+++ b/reference_model/include/model_runner.h
@@ -0,0 +1,87 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MODEL_RUNNER_H_
+#define MODEL_RUNNER_H_
+
+#include "model_common.h"
+#include "graph_status.h"
+
+#include "tosa_serialization_handler.h"
+
+namespace TosaReference
+{
+
+class ModelRunnerImpl;
+
+/*
+ * This interface allows a user to initialize, run and get the output from a model.
+ * See model_runner_simple_sample.cpp for example on how this interface can be used.
+ */
+class IModelRunner
+{
+public:
+ IModelRunner();
+ IModelRunner(const func_config_t& func_config, const func_debug_t& func_debug);
+
+ ~IModelRunner();
+
+ /*
+ * Functional and debug configurations can also be set.
+ * See func_config.h and func_debug.h for possible options.
+ */
+ void setFuncConfig(func_config_t& func_config);
+ void setFuncDebug(func_debug_t& func_debug);
+
+ /*
+ * Initialize the model.
+ * The TosaSerializationHandler is validated and then converted to a SubgraphTraverser internally.
+ * This SubgraphTraverser is initialized, allocated and validated.
+ * The status of the graph will be returned upon completion.
+ */
+ GraphStatus initialize(tosa::TosaSerializationHandler& serialization_handler);
+
+ /*
+ * Run the model using the internal SubgraphTraverser created during initialization.
+ * If validate_only is specified run() will simply return the graph status.
+ * Otherwise, the graph will be run and the output tensors will be generated if the graph is valid.
+ * The output tensors can then be retrieved with getOutput().
+ * NOTE: initialize() must be called before run(). Also, setInput() must be called for all inputs in the model.
+ */
+ GraphStatus run();
+
+ /*
+ * Set the input tensors for the model.
+ * The input_name much match the input tensor name in the model.
+ * NOTE: setInput() must be called for each input tensor before run() is called.
+ */
+ template <typename T>
+ int setInput(std::string input_name, std::vector<T> vals);
+
+ /*
+ * Retrieve the output tensors from the graph after running.
+ * The output_name much match the output tensor name in the model.
+ * NOTE: run() must be called before outputs are retrieved.
+ */
+ template <typename T>
+ std::vector<T> getOutput(std::string output_name);
+
+private:
+ std::unique_ptr<ModelRunnerImpl> model_runner_impl;
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/include/version.h b/reference_model/include/version.h
new file mode 100644
index 0000000..cd1d598
--- /dev/null
+++ b/reference_model/include/version.h
@@ -0,0 +1,23 @@
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef VERSION_H
+#define VERSION_H
+
+#define TOSA_REFERENCE_MODEL_VERSION_MAJOR 0
+#define TOSA_REFERENCE_MODEL_VERSION_MINOR 41
+#define TOSA_REFERENCE_MODEL_VERSION_PATCH 0
+#define TOSA_REFERENCE_MODEL_VERSION_DRAFT true
+
+#endif //VERSION_H
diff --git a/reference_model/samples/model_runner_simple_sample.cpp b/reference_model/samples/model_runner_simple_sample.cpp
new file mode 100644
index 0000000..2eebca6
--- /dev/null
+++ b/reference_model/samples/model_runner_simple_sample.cpp
@@ -0,0 +1,97 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "general_utils.h"
+#include "model_runner.h"
+
+int main()
+{
+ using namespace TosaReference;
+
+ std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/");
+ std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa");
+ std::string input0_file(test_root + "placeholder_0.npy");
+ std::string input1_file(test_root + "placeholder_1.npy");
+ std::string expected_output_file(test_root + "tflite_result.npy");
+
+ std::vector<std::string> input_names = { "TosaInput_0", "TosaInput_1" };
+ std::string output_name = "TosaOutput_0";
+
+ std::vector<int32_t> input0_shape = { 1, 4, 4, 1 };
+ std::vector<int32_t> input1_shape = { 1, 4, 4, 4 };
+ std::vector<int32_t> output_shape = { 1, 4, 4, 4 };
+
+ std::vector<std::vector<float>> inputs(input_names.size());
+ std::vector<float> expected_outputs = { };
+ std::vector<float> actual_outputs = { };
+
+ // Read in inputs and expected outputs for sample purposes.
+ inputs[0] = readFromNpyFile<float>(input0_file.c_str(), input0_shape);
+ inputs[1] = readFromNpyFile<float>(input1_file.c_str(), input1_shape);
+ expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape);
+
+ tosa::TosaSerializationHandler handler;
+ tosa::tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str());
+ if(error != tosa::TOSA_OK)
+ {
+ WARNING("An error occurred while loading the model from file.");
+ return 1;
+ }
+ GraphStatus status;
+
+ // Initialize the ModelRunner with configurations.
+ IModelRunner runner;
+ status = runner.initialize(handler);
+ if(status != GraphStatus::TOSA_VALID)
+ {
+ WARNING("An error occurred while initializing.");
+ return 1;
+ }
+
+ // Set the model inputs using the input names and input data.
+ runner.setInput(input_names[0], inputs[0]);
+ runner.setInput(input_names[1], inputs[1]);
+
+ // Run the ModelRunner using test inputs.
+ status = runner.run();
+ if(status != GraphStatus::TOSA_VALID)
+ {
+ WARNING("An error occurred when running the model.");
+ return 1;
+ }
+
+ // Get the outputs from the model.
+ actual_outputs = runner.getOutput<float>(output_name);
+
+ // Compare the actual output to the expected output.
+ bool if_accurate = true;
+ for (size_t i = 0; i < expected_outputs.size(); ++i)
+ {
+ if(actual_outputs[i] != expected_outputs[i])
+ {
+ WARNING("Actual output (%f) doesn't match expected output (%f).");
+ if_accurate = false;
+ }
+ }
+
+ if(!if_accurate)
+ {
+ WARNING("There were mismatches in actual vs expected output, see above output for more details.");
+ return 1;
+ }
+
+ printf("The model ran successfully without errors and matched the expected output.\n");
+ return 0;
+} \ No newline at end of file
diff --git a/reference_model/src/func_config.cc b/reference_model/src/command_line_utils.h
index 6bd809e..1bd1639 100644
--- a/reference_model/src/func_config.cc
+++ b/reference_model/src/command_line_utils.h
@@ -13,19 +13,15 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <cxxopts.hpp>
-#include <ctype.h>
-#include <signal.h>
-#include <stdarg.h>
-#include <stdint.h>
-#include <stdio.h>
-#include <stdlib.h>
-#include <string.h>
-#include <sys/types.h>
+#ifndef COMMAND_LINE_UTILS_H_
+#define COMMAND_LINE_UTILS_H_
#include "func_config.h"
#include "func_debug.h"
+#include <stdint.h>
+#include <cxxopts.hpp>
+
// Read the command line arguments
int func_model_parse_cmd_line(
func_config_t& func_config, func_debug_t& func_debug, int argc, char** argv, const char* version)
@@ -102,7 +98,4 @@ int func_model_parse_cmd_line(
return 0;
}
-void func_model_print_help()
-{
-
-}
+#endif
diff --git a/reference_model/src/general_utils.h b/reference_model/src/general_utils.h
new file mode 100644
index 0000000..12f831e
--- /dev/null
+++ b/reference_model/src/general_utils.h
@@ -0,0 +1,68 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef GENERAL_UTILS_H_
+#define GENERAL_UTILS_H_
+
+#include "func_debug.h"
+
+#include "numpy_utils.h"
+
+namespace TosaReference
+{
+
+const uint32_t getElementCount(std::vector<int32_t>& shape)
+{
+ uint32_t elements = 1;
+ for (size_t i = 0; i < shape.size(); i++)
+ {
+ elements *= shape[i];
+ }
+
+ return elements;
+}
+
+template <typename T>
+std::vector<T> readFromNpyFile(const char* filename, std::vector<int32_t>& shape)
+{
+ uint32_t elements = getElementCount(shape);
+ std::vector<T> data(elements, 0);
+
+ NumpyUtilities::NPError nperror = NumpyUtilities::readFromNpyFile(filename, elements, data.data());
+
+ switch (nperror)
+ {
+ case NumpyUtilities::NO_ERROR:
+ break;
+ case NumpyUtilities::FILE_NOT_FOUND:
+ FATAL_ERROR("readFromNpyFile: Cannot open file %s", filename);
+ case NumpyUtilities::FILE_IO_ERROR:
+ FATAL_ERROR("readFromNpyFile: IO error reading file: %s", filename);
+ case NumpyUtilities::FILE_TYPE_MISMATCH:
+ FATAL_ERROR("readFromNpyFile: Tensor type and Numpy file type mismatch for filename %s", filename);
+ case NumpyUtilities::HEADER_PARSE_ERROR:
+ FATAL_ERROR("Numpy header parsing error for file: %s", filename);
+ case NumpyUtilities::BUFFER_SIZE_MISMATCH:
+ FATAL_ERROR("Buffer size does not match numpy file size for filename %s", filename);
+ default:
+ FATAL_ERROR("Unknown error parsing Numpy file: %s", filename);
+ }
+
+ return data;
+}
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/main.cpp b/reference_model/src/main.cpp
index 643a351..776fbf3 100644
--- a/reference_model/src/main.cpp
+++ b/reference_model/src/main.cpp
@@ -13,31 +13,24 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#include <stdio.h>
+#include "model_runner.h"
+#include "version.h"
-#include "model_common.h"
+#include "command_line_utils.h"
#include "ops/op_factory.h"
#include "subgraph_traverser.h"
#include "tosa_serialization_handler.h"
-#include <Eigen/CXX11/Tensor>
-#include <iostream>
#include <fstream>
+#include <iostream>
+#include <stdio.h>
+#include <Eigen/CXX11/Tensor>
#include <nlohmann/json.hpp>
-#define MODEL_VERSION_MAJOR 0
-#define MODEL_VERSION_MINOR 41
-#define MODEL_VERSION_PATCH 0
-#define MODEL_VERSION_DRAFT true
-
using namespace TosaReference;
using namespace tosa;
using json = nlohmann::json;
-// Global instantiation of configuration and debug objects
-func_config_t g_func_config;
-func_debug_t g_func_debug;
-
int initTestDesc(json& test_desc);
int readInputTensors(SubgraphTraverser& gt, json test_desc);
int writeFinalTensors(SubgraphTraverser& gt, json test_desc);
@@ -45,7 +38,10 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc);
int main(int argc, char** argv)
{
- TosaVersion model_version(MODEL_VERSION_MAJOR, MODEL_VERSION_MINOR, MODEL_VERSION_PATCH, MODEL_VERSION_DRAFT);
+ TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR,
+ TOSA_REFERENCE_MODEL_VERSION_MINOR,
+ TOSA_REFERENCE_MODEL_VERSION_PATCH,
+ TOSA_REFERENCE_MODEL_VERSION_DRAFT);
// Initialize configuration and debug subsystems
g_func_debug.init_debug(0);
@@ -203,7 +199,6 @@ int loadGraph(TosaSerializationHandler& tsh, json test_desc)
if (strlen(graph_fullname) <= 2)
{
- func_model_print_help();
FATAL_ERROR("Missing required argument: Check \"tosa_file\" in .json specified by -Ctosa_desc=");
}
diff --git a/reference_model/src/model_runner.cc b/reference_model/src/model_runner.cc
new file mode 100644
index 0000000..2395a85
--- /dev/null
+++ b/reference_model/src/model_runner.cc
@@ -0,0 +1,76 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "model_runner_impl.h"
+
+using namespace TosaReference;
+
+// Global instantiation of configuration and debug objects
+func_config_t g_func_config;
+func_debug_t g_func_debug;
+
+IModelRunner::IModelRunner() : model_runner_impl(new ModelRunnerImpl())
+{}
+
+IModelRunner::IModelRunner(const func_config_t& func_config,
+ const func_debug_t& func_debug)
+ : model_runner_impl(new ModelRunnerImpl(func_config, func_debug))
+{}
+
+IModelRunner::~IModelRunner()
+{}
+
+void IModelRunner::setFuncConfig(func_config_t& func_config)
+{
+ model_runner_impl->setFuncConfig(func_config);
+}
+
+void IModelRunner::setFuncDebug(func_debug_t& func_debug)
+{
+ model_runner_impl->setFuncDebug(func_debug);
+}
+
+GraphStatus IModelRunner::initialize(tosa::TosaSerializationHandler& serialization_handler)
+{
+ return model_runner_impl->initialize(serialization_handler);
+}
+
+GraphStatus IModelRunner::run()
+{
+ return model_runner_impl->run();
+}
+
+template <typename T>
+int IModelRunner::setInput(std::string input_name, std::vector<T> vals)
+{
+ return model_runner_impl->setInput<T>(input_name, vals);
+}
+
+template <typename T>
+std::vector<T> IModelRunner::getOutput(std::string output_name)
+{
+ return model_runner_impl->getOutput<T>(output_name);
+}
+
+// Template explicit specialization
+template int IModelRunner::setInput<float>(std::string input_name, std::vector<float> vals);
+template int IModelRunner::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals);
+template int IModelRunner::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals);
+template int IModelRunner::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals);
+
+template std::vector<float> IModelRunner::getOutput<float>(std::string output_name);
+template std::vector<int32_t> IModelRunner::getOutput<int32_t>(std::string output_name);
+template std::vector<int64_t> IModelRunner::getOutput<int64_t>(std::string output_name);
+template std::vector<unsigned char> IModelRunner::getOutput<unsigned char>(std::string output_name); \ No newline at end of file
diff --git a/reference_model/src/model_runner_impl.cc b/reference_model/src/model_runner_impl.cc
new file mode 100644
index 0000000..e0fdc49
--- /dev/null
+++ b/reference_model/src/model_runner_impl.cc
@@ -0,0 +1,277 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "model_runner_impl.h"
+
+using namespace TosaReference;
+
+ModelRunnerImpl::ModelRunnerImpl()
+{}
+
+ModelRunnerImpl::ModelRunnerImpl(const func_config_t& func_config,
+ const func_debug_t& func_debug)
+{
+ g_func_config = func_config;
+ g_func_debug = func_debug;
+}
+
+ModelRunnerImpl::~ModelRunnerImpl()
+{
+ g_func_debug.fini_debug();
+ delete _main_gt;
+};
+
+void ModelRunnerImpl::setFuncConfig(func_config_t& func_config)
+{
+ g_func_config = func_config;
+}
+void ModelRunnerImpl::setFuncDebug(func_debug_t& func_debug)
+{
+ g_func_debug = func_debug;
+}
+
+GraphStatus ModelRunnerImpl::initialize(TosaSerializationHandler& serialization_handler)
+{
+ validateTosaVersion(serialization_handler);
+
+ // Make nullptr in case ModelRunnerImpl is being initialized again with a different graph.
+ _main_gt = nullptr;
+ _main_gt = new SubgraphTraverser(serialization_handler.GetMainBlock(), &serialization_handler);
+
+ if(_main_gt == nullptr)
+ {
+ WARNING("An error occurred when generating main graph traverser.");
+ return GraphStatus::TOSA_ERROR;
+ }
+
+ if (_main_gt->initializeGraph())
+ {
+ WARNING("Unable to initialize main graph traverser.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->linkTensorsAndNodes())
+ {
+ WARNING("Failed to link tensors and nodes");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->validateGraph())
+ {
+ WARNING("Failed to validate graph.");
+ return _main_gt->getGraphStatus();
+ }
+
+ if (_main_gt->allocateTensor())
+ {
+ WARNING("Failed to allocate tensor.");
+ return _main_gt->getGraphStatus();
+ }
+
+ return _main_gt->getGraphStatus();
+}
+
+GraphStatus ModelRunnerImpl::run()
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunnerImpl hasn't been initialized, please invoke initialize() before run()");
+ }
+
+ if (g_func_config.validate_only)
+ {
+ goto done;
+ }
+
+ // Validate the number of inputs matches the
+ if (static_cast<uint32_t>(_main_gt->getNumInputTensors()) != n_input_tensors)
+ {
+ FATAL_ERROR("The number of inputs (%d) does not equal the number of inputs in the model (%d). "
+ "setInput() must be called for each input.",
+ n_input_tensors, _main_gt->getNumInputTensors());
+ }
+
+ if (g_func_config.eval)
+ {
+ // evaluateAll() returns 1 if graph evaluation is forced to be terminated earlier.
+ if (_main_gt->evaluateAll())
+ {
+ ASSERT_MSG(_main_gt->getGraphStatus() != GraphStatus::TOSA_VALID,
+ "Upon evaluateAll() returning 1, graph can not be VALID.");
+ }
+ else
+ {
+ ASSERT_MSG(_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID ||
+ _main_gt->getGraphStatus() == GraphStatus::TOSA_UNPREDICTABLE,
+ "Upon evaluateAll() returning 0, graph can only be VALID/UNPREDICTABLE.");
+ }
+
+ // Only generate output tensor if graph is valid.
+ if (_main_gt->getGraphStatus() == GraphStatus::TOSA_VALID)
+ {
+ // Make sure output tensor is evaluated and show its value
+ int num_output_tensors = _main_gt->getNumOutputTensors();
+ bool all_output_valid = true;
+ for (int i = 0; i < num_output_tensors; i++)
+ {
+ const Tensor* ct = _main_gt->getOutputTensor(i);
+ ASSERT_MEM(ct);
+ if (!ct->getIsValid())
+ {
+ ct->dumpTensorParams(g_func_debug.func_debug_file);
+ if (DEBUG_ENABLED(DEBUG_VERB_HIGH, GT))
+ {
+ ct->dumpTensor(g_func_debug.func_debug_file);
+ }
+ all_output_valid = false;
+ }
+ }
+ if (!all_output_valid)
+ {
+ _main_gt->dumpGraph(g_func_debug.func_debug_file);
+ FATAL_ERROR(
+ "SubgraphTraverser \"main\" error: Output tensors are not all valid at the end of evaluation.");
+ }
+ }
+ }
+
+done:
+ // Print status if not valid and do cleanup.
+ checkGraphStatus(*_main_gt);
+ g_func_debug.fini_debug();
+
+ return _main_gt->getGraphStatus();
+}
+
+template <typename T>
+int ModelRunnerImpl::setInput(std::string input_name, std::vector<T> vals)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() before setInput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getInputTensorByName(input_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find input tensor %s", input_name.c_str());
+ return 1;
+ }
+
+ if (!tensor->is_allocated())
+ {
+ WARNING("Tensor %s is not allocated before being initialized", tensor->getName().c_str());
+ return 1;
+ }
+
+ if (tensor->readfromVector(vals))
+ {
+ WARNING("Unable to convert input tensor %s to Tensor", tensor->getName().c_str());
+ return 1;
+ }
+
+ // Push ready consumers to the next node list
+ for (auto gn : tensor->getConsumers())
+ {
+ if (gn->hasAllInputsReady() && !gn->getOnNextNodeList())
+ {
+ _main_gt->addToNextNodeList(gn);
+ }
+ }
+
+ n_input_tensors++;
+ return 0;
+}
+
+template <typename T>
+std::vector<T> ModelRunnerImpl::getOutput(std::string output_name)
+{
+ if (_main_gt == nullptr)
+ {
+ FATAL_ERROR("ModelRunner hasn't been initialized, please invoke initialize() and run() before getOutput()");
+ }
+
+ Tensor* tensor;
+ tensor = _main_gt->getOutputTensorByName(output_name);
+
+ if (!tensor)
+ {
+ WARNING("Unable to find output tensor %s", output_name.c_str());
+ return std::vector<T>();
+ }
+
+ std::vector<T> outputs(tensor->getElementCount(), 0);
+
+ if (tensor->writeToVector(outputs))
+ {
+ WARNING("Unable to convert output tensor %s to vector", tensor->getName().c_str());
+ return std::vector<T>();
+ }
+
+ return outputs;
+}
+
+void ModelRunnerImpl::validateTosaVersion(TosaSerializationHandler& serialization_handler)
+{
+ TosaVersion model_version(TOSA_REFERENCE_MODEL_VERSION_MAJOR,
+ TOSA_REFERENCE_MODEL_VERSION_MINOR,
+ TOSA_REFERENCE_MODEL_VERSION_PATCH,
+ TOSA_REFERENCE_MODEL_VERSION_DRAFT);
+
+ TosaVersion::compat_t is_compat = model_version.is_compatible(serialization_handler.GetVersion());
+ switch (is_compat)
+ {
+ case TosaVersion::compat_t::COMPLETELY_COMPATIBLE:
+ break;
+ case TosaVersion::compat_t::PARTIALLY_COMPATIBLE:
+ WARNING("Reference model version %s is partially compatible with serializer version %s.",
+ model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
+ break;
+ case TosaVersion::compat_t::NOT_COMPATIBLE:
+ FATAL_ERROR("Reference model version %s is not compatible with serializer version %s.",
+ model_version.to_string().c_str(), serialization_handler.GetVersion().to_string().c_str());
+ }
+}
+
+void ModelRunnerImpl::checkGraphStatus(SubgraphTraverser& main_gt)
+{
+ switch (main_gt.getGraphStatus())
+ {
+ case GraphStatus::TOSA_VALID:
+ // Result is valid.
+ break;
+ case GraphStatus::TOSA_UNPREDICTABLE:
+ WARNING("Graph result: UNPREDICTABLE.");
+ break;
+ case GraphStatus::TOSA_ERROR:
+ WARNING("Graph result: ERROR.");
+ break;
+ default:
+ WARNING("Unknown graph status code=%d.", (int)main_gt.getGraphStatus());
+ }
+}
+
+// Template explicit specialization
+template int ModelRunnerImpl::setInput<float>(std::string input_name, std::vector<float> vals);
+template int ModelRunnerImpl::setInput<int32_t>(std::string input_name, std::vector<int32_t> vals);
+template int ModelRunnerImpl::setInput<int64_t>(std::string input_name, std::vector<int64_t> vals);
+template int ModelRunnerImpl::setInput<unsigned char>(std::string input_name, std::vector<unsigned char> vals);
+
+template std::vector<float> ModelRunnerImpl::getOutput<float>(std::string output_name);
+template std::vector<int32_t> ModelRunnerImpl::getOutput<int32_t>(std::string output_name);
+template std::vector<int64_t> ModelRunnerImpl::getOutput<int64_t>(std::string output_name);
+template std::vector<unsigned char> ModelRunnerImpl::getOutput<unsigned char>(std::string output_name); \ No newline at end of file
diff --git a/reference_model/src/model_runner_impl.h b/reference_model/src/model_runner_impl.h
new file mode 100644
index 0000000..7a91bfe
--- /dev/null
+++ b/reference_model/src/model_runner_impl.h
@@ -0,0 +1,66 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef MODEL_RUNNER_IMPL_H_
+#define MODEL_RUNNER_IMPL_H_
+
+#include "model_runner.h"
+#include "graph_status.h"
+#include "version.h"
+
+#include "ops/op_factory.h"
+#include "subgraph_traverser.h"
+#include "tosa_serialization_handler.h"
+
+namespace TosaReference
+{
+
+/*
+ * This is a private implementation of the IModelRunner class.
+ * See documented IModelRunner for usage.
+ */
+class ModelRunnerImpl
+{
+public:
+ ModelRunnerImpl();
+ ModelRunnerImpl(const func_config_t& func_config, const func_debug_t& func_debug);
+
+ ~ModelRunnerImpl();
+
+ void setFuncConfig(func_config_t& func_config);
+ void setFuncDebug(func_debug_t& func_debug);
+
+ GraphStatus initialize(TosaSerializationHandler& serialization_handler);
+ GraphStatus run();
+
+ template <typename T>
+ int setInput(std::string input_name, std::vector<T> vals);
+
+ template <typename T>
+ std::vector<T> getOutput(std::string output_name);
+
+private:
+ SubgraphTraverser* _main_gt = nullptr;
+
+ // Used to determine if all input tensors have been set correctly.
+ uint32_t n_input_tensors = 0;
+
+ void validateTosaVersion(TosaSerializationHandler& serialization_handler);
+ void checkGraphStatus(SubgraphTraverser& main_gt);
+};
+
+}; // namespace TosaReference
+
+#endif
diff --git a/reference_model/src/subgraph_traverser.h b/reference_model/src/subgraph_traverser.h
index 8c66d73..7940ee4 100644
--- a/reference_model/src/subgraph_traverser.h
+++ b/reference_model/src/subgraph_traverser.h
@@ -16,6 +16,7 @@
#ifndef SUBGRAPH_TRAVERSER_H
#define SUBGRAPH_TRAVERSER_H
+#include "graph_status.h"
#include "graph_node.h"
#include "model_common.h"
#include "ops/op_factory.h"
@@ -26,13 +27,6 @@
namespace TosaReference
{
-enum class GraphStatus : int
-{
- TOSA_VALID = 0,
- TOSA_UNPREDICTABLE = 1,
- TOSA_ERROR = 2,
-};
-
class SubgraphTraverser
{
public:
diff --git a/reference_model/src/tensor.cc b/reference_model/src/tensor.cc
index 90aee05..7cbeb13 100644
--- a/reference_model/src/tensor.cc
+++ b/reference_model/src/tensor.cc
@@ -382,6 +382,209 @@ DEF_CTENSOR_COPY_VALUE_FROM(6, bool)
#undef DEF_CTENSOR_COPY_VALUE_FROM
+int TosaReference::Tensor::readfromVector(const std::vector<float>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueFloat(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type (float) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<int32_t>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_INT32:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ case DType_UINT16:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueInt32(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<int64_t>& vals)
+{
+ uint32_t elements = getElementCount();
+ switch (getDtype())
+ {
+ case DType_INT48:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueInt64(elements, vals.data());
+ break;
+ default:
+ WARNING("The input type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::readfromVector(const std::vector<unsigned char>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_BOOL:
+ if (vals.size() != elements)
+ {
+ WARNING("The input size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ setTensorValueBool(elements, reinterpret_cast<const bool*>(vals.data()));
+ break;
+ default:
+ WARNING("The input type (bool) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ setIsValid();
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<float>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_FLOAT:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueFloat(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type (float) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<int32_t>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case DType_INT32:
+ case DType_UINT8:
+ case DType_INT4:
+ case DType_INT8:
+ case DType_INT16:
+ case DType_UINT16:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueInt32(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<int64_t>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case tosa::DType_INT48:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueInt64(elements, vals.data());
+ break;
+ default:
+ WARNING("The output type doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
+int TosaReference::Tensor::writeToVector(std::vector<unsigned char>& vals)
+{
+ uint32_t elements = getElementCount();
+
+ switch (getDtype())
+ {
+ case tosa::DType_BOOL:
+ if (vals.size() != elements)
+ {
+ WARNING("The output size (%ld) doesn't match the number of elements (%d) assigned to the tensor.",
+ vals.size(), elements);
+ return -1;
+ }
+
+ getTensorValueBool(elements, reinterpret_cast<bool*>(vals.data()));
+ break;
+ default:
+ WARNING("The output type (bool) doesn't match the data type assigned to the tensor (%s).",
+ EnumNameDType(getDtype()));
+ return -2;
+ }
+ return 0;
+}
+
template <class T>
int TosaReference::TensorTemplate<T>::setTensorValueFloat(const size_t buflen, const float* vals)
{
diff --git a/reference_model/src/tensor.h b/reference_model/src/tensor.h
index ede42a9..6b7d5f1 100644
--- a/reference_model/src/tensor.h
+++ b/reference_model/src/tensor.h
@@ -228,6 +228,16 @@ public:
virtual int writeToNpyFile(const char* filename) const;
virtual int copyValueFrom(Tensor* tensor) = 0;
+ virtual int readfromVector(const std::vector<float>& vals);
+ virtual int readfromVector(const std::vector<int32_t>& vals);
+ virtual int readfromVector(const std::vector<int64_t>& vals);
+ virtual int readfromVector(const std::vector<unsigned char>& vals);
+
+ virtual int writeToVector(std::vector<float>& vals);
+ virtual int writeToVector(std::vector<int32_t>& vals);
+ virtual int writeToVector(std::vector<int64_t>& vals);
+ virtual int writeToVector(std::vector<unsigned char>& vals);
+
const char* bool_to_str(bool in) const
{
static const char* true_str = "true";
diff --git a/reference_model/test/model_runner_tests.cpp b/reference_model/test/model_runner_tests.cpp
new file mode 100644
index 0000000..cc295d9
--- /dev/null
+++ b/reference_model/test/model_runner_tests.cpp
@@ -0,0 +1,154 @@
+
+// Copyright (c) 2022, ARM Limited.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifndef DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
+#define DOCTEST_CONFIG_IMPLEMENT_WITH_MAIN
+#endif
+
+#include "model_runner.h"
+#include "general_utils.h"
+
+#include "doctest.h"
+
+using namespace TosaReference;
+using namespace tosa;
+
+template <typename T>
+void compareOutput(std::vector<T>& tensor1, std::vector<T>& tensor2, size_t size)
+{
+ for (size_t i = 0; i < size; ++i)
+ {
+ CHECK((tensor1[i] == tensor2[i]));
+ }
+}
+
+TEST_SUITE("model_runner")
+{
+
+TEST_CASE("simple_add_f32_test")
+{
+ std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_add_1x4x4x4_f32/");
+ std::string tosa_model_file(test_root + "flatbuffer-tflite/test_add_1x4x4x4_f32.tosa");
+ std::string input0_file(test_root + "placeholder_0.npy");
+ std::string input1_file(test_root + "placeholder_1.npy");
+ std::string expected_output_file(test_root + "tflite_result.npy");
+
+ std::vector<std::string> input_names = { "TosaInput_0", "TosaInput_1" };
+ std::string output_name = "TosaOutput_0";
+
+ std::vector<int32_t> input0_shape = { 1, 4, 4, 1 };
+ std::vector<int32_t> input1_shape = { 1, 4, 4, 4 };
+ std::vector<int32_t> output_shape = { 1, 4, 4, 4 };
+
+ std::vector<std::vector<float>> inputs(input_names.size());
+ std::vector<float> actual_outputs = { };
+ std::vector<float> expected_outputs = { };
+
+ // Read in inputs and expected outputs.
+ inputs[0] = readFromNpyFile<float>(input0_file.c_str(), input0_shape);
+ inputs[1] = readFromNpyFile<float>(input1_file.c_str(), input1_shape);
+ expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape);
+
+ TosaSerializationHandler handler;
+ tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str());
+ CHECK((error == tosa::TOSA_OK));
+
+ GraphStatus status;
+
+ // Initialize the ModelRunner with configurations.
+ IModelRunner runner;
+ status = runner.initialize(handler);
+ CHECK((status == GraphStatus::TOSA_VALID));
+
+ runner.setInput(input_names[0], inputs[0]);
+ runner.setInput(input_names[1], inputs[1]);
+
+ // Run the ModelRunner using test inputs.
+ status = runner.run();
+ CHECK((status == GraphStatus::TOSA_VALID));
+
+ actual_outputs = runner.getOutput<float>(output_name);
+ CHECK(!actual_outputs.empty());
+
+ compareOutput(expected_outputs, actual_outputs, expected_outputs.size());
+}
+
+TEST_CASE("conv2d_f32_test")
+{
+ std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/");
+ std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa");
+ std::string input_file(test_root + "placeholder_0.npy");
+ std::string expected_output_file(test_root + "tflite_result.npy");
+
+ std::string input_name = "TosaInput_0";
+ std::string output_name = "TosaOutput_0";
+
+ std::vector<int32_t> input_shape = { 1, 32, 32, 8 };
+ std::vector<int32_t> output_shape = { 1, 32, 32, 16 };
+
+ // Read in inputs and expected outputs.
+ std::vector<float> inputs = readFromNpyFile<float>(input_file.c_str(), input_shape);
+ std::vector<float> expected_outputs = readFromNpyFile<float>(expected_output_file.c_str(), output_shape);
+
+ TosaSerializationHandler handler;
+ tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str());
+ CHECK((error == tosa::TOSA_OK));
+
+ GraphStatus status;
+
+ // Initialize the ModelRunner with configurations.
+ IModelRunner runner;
+ status = runner.initialize(handler);
+ CHECK((status == GraphStatus::TOSA_VALID));
+
+ runner.setInput(input_name, inputs);
+
+ // Run the ModelRunner using test inputs.
+ status = runner.run();
+ CHECK((status == GraphStatus::TOSA_VALID));
+
+ std::vector<float> actual_outputs = runner.getOutput<float>(output_name);
+ CHECK(!actual_outputs.empty());
+
+ compareOutput(expected_outputs, actual_outputs, expected_outputs.size());
+}
+
+TEST_CASE("conv2d_f32_validate_only_test")
+{
+ std::string test_root(std::string(PROJECT_ROOT) + "../examples/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11/");
+ std::string tosa_model_file(test_root + "flatbuffer-tflite/test_conv2d_1x1_1x32x32x8_f32_st11_padSAME_dilat11.tosa");
+
+ TosaSerializationHandler handler;
+ tosa_err_t error = handler.LoadFileTosaFlatbuffer(tosa_model_file.c_str());
+ CHECK((error == tosa::TOSA_OK));
+
+ GraphStatus status;
+ func_debug_t funcDebug;
+
+ func_config_t funcConfig;
+ funcConfig.validate_only = 1;
+
+ // Initialize the ModelRunner with configurations.
+ IModelRunner runner = IModelRunner(funcConfig, funcDebug);
+ runner.setFuncConfig(funcConfig);
+ status = runner.initialize(handler);
+ CHECK((status == GraphStatus::TOSA_VALID));
+
+ // Run the ModelRunner using no inputs, as validate_only is specified run() should still work.
+ status = runner.run();
+ CHECK((status == GraphStatus::TOSA_VALID));
+}
+
+}
diff --git a/thirdparty/CMakeLists.txt b/thirdparty/CMakeLists.txt
index abcc52c..aa6c43d 100644
--- a/thirdparty/CMakeLists.txt
+++ b/thirdparty/CMakeLists.txt
@@ -12,6 +12,13 @@ set(CMAKE_INSTALL_PREFIX "./thirdparty" CACHE PATH "..." FORCE)
project(thirdparty LANGUAGES CXX)
-add_subdirectory(cxxopts)
add_subdirectory(serialization_lib EXCLUDE_FROM_ALL)
-add_subdirectory(json EXCLUDE_FROM_ALL)
+
+if(BUILD_TOSA_REFERENCE_MODEL_EXECUTABLE)
+ add_subdirectory(cxxopts)
+ add_subdirectory(json EXCLUDE_FROM_ALL)
+endif()
+
+if(BUILD_TOSA_REFERENCE_MODEL_TESTS)
+ add_subdirectory(doctest EXCLUDE_FROM_ALL)
+endif()
diff --git a/thirdparty/doctest b/thirdparty/doctest
new file mode 160000
+Subproject 86892fc480f80fb57d9a3926cb506c0e974489d