summaryrefslogtreecommitdiff
path: root/source/application/api/use_case/inference_runner
diff options
context:
space:
mode:
Diffstat (limited to 'source/application/api/use_case/inference_runner')
-rw-r--r--source/application/api/use_case/inference_runner/CMakeLists.txt4
-rw-r--r--source/application/api/use_case/inference_runner/include/MicroMutableAllOpsResolver.hpp119
-rw-r--r--source/application/api/use_case/inference_runner/include/TestModel.hpp21
-rw-r--r--source/application/api/use_case/inference_runner/src/MicroMutableAllOpsResolver.cc128
-rw-r--r--source/application/api/use_case/inference_runner/src/TestModel.cc11
5 files changed, 160 insertions, 123 deletions
diff --git a/source/application/api/use_case/inference_runner/CMakeLists.txt b/source/application/api/use_case/inference_runner/CMakeLists.txt
index a27ce63..e4754e6 100644
--- a/source/application/api/use_case/inference_runner/CMakeLists.txt
+++ b/source/application/api/use_case/inference_runner/CMakeLists.txt
@@ -1,5 +1,5 @@
#----------------------------------------------------------------------------
-# SPDX-FileCopyrightText: Copyright 2022 Arm Limited and/or its affiliates <open-source-office@arm.com>
+# SPDX-FileCopyrightText: Copyright 2022-2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -25,7 +25,7 @@ project(${INFERENCE_RUNNER_API_TARGET}
LANGUAGES C CXX)
# Create static library
-add_library(${INFERENCE_RUNNER_API_TARGET} STATIC src/TestModel.cc)
+add_library(${INFERENCE_RUNNER_API_TARGET} STATIC src/TestModel.cc src/MicroMutableAllOpsResolver.cc)
target_include_directories(${INFERENCE_RUNNER_API_TARGET} PUBLIC include)
diff --git a/source/application/api/use_case/inference_runner/include/MicroMutableAllOpsResolver.hpp b/source/application/api/use_case/inference_runner/include/MicroMutableAllOpsResolver.hpp
index 96ac28d..67a7c9e 100644
--- a/source/application/api/use_case/inference_runner/include/MicroMutableAllOpsResolver.hpp
+++ b/source/application/api/use_case/inference_runner/include/MicroMutableAllOpsResolver.hpp
@@ -14,122 +14,25 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-#ifndef INF_RUNNER_MICRO_MUTABLE_ALLOPS_RESOLVER_HPP
-#define INF_RUNNER_MICRO_MUTABLE_ALLOPS_RESOLVER_HPP
+#ifndef INF_RUNNER_MICRO_MUTABLE_ALL_OPS_RESOLVER_HPP
+#define INF_RUNNER_MICRO_MUTABLE_ALL_OPS_RESOLVER_HPP
#include <tensorflow/lite/micro/micro_mutable_op_resolver.h>
-constexpr int kNumberOperators = 97;
-
namespace arm {
namespace app {
- /* Create our own AllOpsResolver by adding all Ops to MicroMutableOpResolver. */
- inline tflite::MicroMutableOpResolver<kNumberOperators> CreateAllOpsResolver() {
- tflite::MicroMutableOpResolver<kNumberOperators> mutableAllOpResolver;
-
- mutableAllOpResolver.AddAbs();
- mutableAllOpResolver.AddAdd();
- mutableAllOpResolver.AddAddN();
- mutableAllOpResolver.AddArgMax();
- mutableAllOpResolver.AddArgMin();
- mutableAllOpResolver.AddAssignVariable();
- mutableAllOpResolver.AddAveragePool2D();
- mutableAllOpResolver.AddBatchToSpaceNd();
- mutableAllOpResolver.AddBroadcastArgs();
- mutableAllOpResolver.AddBroadcastTo();
- mutableAllOpResolver.AddCallOnce();
- mutableAllOpResolver.AddCast();
- mutableAllOpResolver.AddCeil();
- mutableAllOpResolver.AddCircularBuffer();
- mutableAllOpResolver.AddConcatenation();
- mutableAllOpResolver.AddConv2D();
- mutableAllOpResolver.AddCos();
- mutableAllOpResolver.AddCumSum();
- mutableAllOpResolver.AddDepthToSpace();
- mutableAllOpResolver.AddDepthwiseConv2D();
- mutableAllOpResolver.AddDequantize();
- mutableAllOpResolver.AddDetectionPostprocess();
- mutableAllOpResolver.AddDiv();
- mutableAllOpResolver.AddElu();
- mutableAllOpResolver.AddEqual();
- mutableAllOpResolver.AddEthosU();
- mutableAllOpResolver.AddExp();
- mutableAllOpResolver.AddExpandDims();
- mutableAllOpResolver.AddFill();
- mutableAllOpResolver.AddFloor();
- mutableAllOpResolver.AddFloorDiv();
- mutableAllOpResolver.AddFloorMod();
- mutableAllOpResolver.AddFullyConnected();
- mutableAllOpResolver.AddGather();
- mutableAllOpResolver.AddGatherNd();
- mutableAllOpResolver.AddGreater();
- mutableAllOpResolver.AddGreaterEqual();
- mutableAllOpResolver.AddHardSwish();
- mutableAllOpResolver.AddIf();
- mutableAllOpResolver.AddL2Normalization();
- mutableAllOpResolver.AddL2Pool2D();
- mutableAllOpResolver.AddLeakyRelu();
- mutableAllOpResolver.AddLess();
- mutableAllOpResolver.AddLessEqual();
- mutableAllOpResolver.AddLog();
- mutableAllOpResolver.AddLogicalAnd();
- mutableAllOpResolver.AddLogicalNot();
- mutableAllOpResolver.AddLogicalOr();
- mutableAllOpResolver.AddLogistic();
- mutableAllOpResolver.AddLogSoftmax();
- mutableAllOpResolver.AddMaxPool2D();
- mutableAllOpResolver.AddMaximum();
- mutableAllOpResolver.AddMean();
- mutableAllOpResolver.AddMinimum();
- mutableAllOpResolver.AddMirrorPad();
- mutableAllOpResolver.AddMul();
- mutableAllOpResolver.AddNeg();
- mutableAllOpResolver.AddNotEqual();
- mutableAllOpResolver.AddPack();
- mutableAllOpResolver.AddPad();
- mutableAllOpResolver.AddPadV2();
- mutableAllOpResolver.AddPrelu();
- mutableAllOpResolver.AddQuantize();
- mutableAllOpResolver.AddReadVariable();
- mutableAllOpResolver.AddReduceMax();
- mutableAllOpResolver.AddRelu();
- mutableAllOpResolver.AddRelu6();
- mutableAllOpResolver.AddReshape();
- mutableAllOpResolver.AddResizeBilinear();
- mutableAllOpResolver.AddResizeNearestNeighbor();
- mutableAllOpResolver.AddRound();
- mutableAllOpResolver.AddRsqrt();
- mutableAllOpResolver.AddSelectV2();
- mutableAllOpResolver.AddShape();
- mutableAllOpResolver.AddSin();
- mutableAllOpResolver.AddSlice();
- mutableAllOpResolver.AddSoftmax();
- mutableAllOpResolver.AddSpaceToBatchNd();
- mutableAllOpResolver.AddSpaceToDepth();
- mutableAllOpResolver.AddSplit();
- mutableAllOpResolver.AddSplitV();
- mutableAllOpResolver.AddSqrt();
- mutableAllOpResolver.AddSquare();
- mutableAllOpResolver.AddSquaredDifference();
- mutableAllOpResolver.AddSqueeze();
- mutableAllOpResolver.AddStridedSlice();
- mutableAllOpResolver.AddSub();
- mutableAllOpResolver.AddSum();
- mutableAllOpResolver.AddSvdf();
- mutableAllOpResolver.AddTanh();
- mutableAllOpResolver.AddTranspose();
- mutableAllOpResolver.AddTransposeConv();
- mutableAllOpResolver.AddUnidirectionalSequenceLSTM();
- mutableAllOpResolver.AddUnpack();
- mutableAllOpResolver.AddVarHandle();
- mutableAllOpResolver.AddWhile();
- mutableAllOpResolver.AddZerosLike();
+ /* Maximum number of individual operations that can be enlisted. */
+ constexpr int kNumberOperators = 97;
- return mutableAllOpResolver;
- }
+ /** An Op resolver containing all ops is no longer supplied with TFLite Micro
+ * so we create our own instead for the generic inference runner.
+ *
+ * @return MicroMutableOpResolver containing all TFLite Micro Ops registered.
+ */
+ tflite::MicroMutableOpResolver<kNumberOperators> CreateAllOpsResolver();
} /* namespace app */
} /* namespace arm */
-#endif /* INF_RUNNER_MICRO_MUTABLE_ALLOPS_RESOLVER_HPP */
+#endif /* INF_RUNNER_MICRO_MUTABLE_ALL_OPS_RESOLVER_HPP */
diff --git a/source/application/api/use_case/inference_runner/include/TestModel.hpp b/source/application/api/use_case/inference_runner/include/TestModel.hpp
index 455e244..4fbbfc0 100644
--- a/source/application/api/use_case/inference_runner/include/TestModel.hpp
+++ b/source/application/api/use_case/inference_runner/include/TestModel.hpp
@@ -23,20 +23,19 @@
namespace arm {
namespace app {
- class TestModel : public Model {
+ class TestModel : public Model {
- protected:
- /** @brief Gets the reference to op resolver interface class. */
- const tflite::MicroMutableOpResolver<kNumberOperators>& GetOpResolver() override;
+ protected:
+ /** @brief Gets the reference to op resolver interface class. */
+ const tflite::MicroOpResolver& GetOpResolver() override;
- /** @brief Adds operations to the op resolver instance, not needed as using AllOpsResolver. */
- bool EnlistOperations() override {return false;}
+ /** @brief Adds operations to the op resolver instance. */
+ bool EnlistOperations() override;
- private:
-
- /* No need to define individual ops at the cost of extra memory. */
- tflite::MicroMutableOpResolver<kNumberOperators> m_opResolver = CreateAllOpsResolver();
- };
+ private:
+ /* A mutable op resolver instance including every operation for Inference runner. */
+ tflite::MicroMutableOpResolver<kNumberOperators> m_opResolver;
+ };
} /* namespace app */
} /* namespace arm */
diff --git a/source/application/api/use_case/inference_runner/src/MicroMutableAllOpsResolver.cc b/source/application/api/use_case/inference_runner/src/MicroMutableAllOpsResolver.cc
new file mode 100644
index 0000000..ed50912
--- /dev/null
+++ b/source/application/api/use_case/inference_runner/src/MicroMutableAllOpsResolver.cc
@@ -0,0 +1,128 @@
+/*
+ * SPDX-FileCopyrightText: Copyright 2023 Arm Limited and/or its affiliates <open-source-office@arm.com>
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "MicroMutableAllOpsResolver.hpp"
+
+namespace arm {
+namespace app {
+ /* Create our own AllOpsResolver by adding all Ops to MicroMutableOpResolver. */
+ tflite::MicroMutableOpResolver<kNumberOperators> CreateAllOpsResolver()
+ {
+ tflite::MicroMutableOpResolver<kNumberOperators> mutableAllOpResolver;
+
+ mutableAllOpResolver.AddAbs();
+ mutableAllOpResolver.AddAdd();
+ mutableAllOpResolver.AddAddN();
+ mutableAllOpResolver.AddArgMax();
+ mutableAllOpResolver.AddArgMin();
+ mutableAllOpResolver.AddAssignVariable();
+ mutableAllOpResolver.AddAveragePool2D();
+ mutableAllOpResolver.AddBatchToSpaceNd();
+ mutableAllOpResolver.AddBroadcastArgs();
+ mutableAllOpResolver.AddBroadcastTo();
+ mutableAllOpResolver.AddCallOnce();
+ mutableAllOpResolver.AddCast();
+ mutableAllOpResolver.AddCeil();
+ mutableAllOpResolver.AddCircularBuffer();
+ mutableAllOpResolver.AddConcatenation();
+ mutableAllOpResolver.AddConv2D();
+ mutableAllOpResolver.AddCos();
+ mutableAllOpResolver.AddCumSum();
+ mutableAllOpResolver.AddDepthToSpace();
+ mutableAllOpResolver.AddDepthwiseConv2D();
+ mutableAllOpResolver.AddDequantize();
+ mutableAllOpResolver.AddDetectionPostprocess();
+ mutableAllOpResolver.AddDiv();
+ mutableAllOpResolver.AddElu();
+ mutableAllOpResolver.AddEqual();
+ mutableAllOpResolver.AddEthosU();
+ mutableAllOpResolver.AddExp();
+ mutableAllOpResolver.AddExpandDims();
+ mutableAllOpResolver.AddFill();
+ mutableAllOpResolver.AddFloor();
+ mutableAllOpResolver.AddFloorDiv();
+ mutableAllOpResolver.AddFloorMod();
+ mutableAllOpResolver.AddFullyConnected();
+ mutableAllOpResolver.AddGather();
+ mutableAllOpResolver.AddGatherNd();
+ mutableAllOpResolver.AddGreater();
+ mutableAllOpResolver.AddGreaterEqual();
+ mutableAllOpResolver.AddHardSwish();
+ mutableAllOpResolver.AddIf();
+ mutableAllOpResolver.AddL2Normalization();
+ mutableAllOpResolver.AddL2Pool2D();
+ mutableAllOpResolver.AddLeakyRelu();
+ mutableAllOpResolver.AddLess();
+ mutableAllOpResolver.AddLessEqual();
+ mutableAllOpResolver.AddLog();
+ mutableAllOpResolver.AddLogicalAnd();
+ mutableAllOpResolver.AddLogicalNot();
+ mutableAllOpResolver.AddLogicalOr();
+ mutableAllOpResolver.AddLogistic();
+ mutableAllOpResolver.AddLogSoftmax();
+ mutableAllOpResolver.AddMaxPool2D();
+ mutableAllOpResolver.AddMaximum();
+ mutableAllOpResolver.AddMean();
+ mutableAllOpResolver.AddMinimum();
+ mutableAllOpResolver.AddMirrorPad();
+ mutableAllOpResolver.AddMul();
+ mutableAllOpResolver.AddNeg();
+ mutableAllOpResolver.AddNotEqual();
+ mutableAllOpResolver.AddPack();
+ mutableAllOpResolver.AddPad();
+ mutableAllOpResolver.AddPadV2();
+ mutableAllOpResolver.AddPrelu();
+ mutableAllOpResolver.AddQuantize();
+ mutableAllOpResolver.AddReadVariable();
+ mutableAllOpResolver.AddReduceMax();
+ mutableAllOpResolver.AddRelu();
+ mutableAllOpResolver.AddRelu6();
+ mutableAllOpResolver.AddReshape();
+ mutableAllOpResolver.AddResizeBilinear();
+ mutableAllOpResolver.AddResizeNearestNeighbor();
+ mutableAllOpResolver.AddRound();
+ mutableAllOpResolver.AddRsqrt();
+ mutableAllOpResolver.AddSelectV2();
+ mutableAllOpResolver.AddShape();
+ mutableAllOpResolver.AddSin();
+ mutableAllOpResolver.AddSlice();
+ mutableAllOpResolver.AddSoftmax();
+ mutableAllOpResolver.AddSpaceToBatchNd();
+ mutableAllOpResolver.AddSpaceToDepth();
+ mutableAllOpResolver.AddSplit();
+ mutableAllOpResolver.AddSplitV();
+ mutableAllOpResolver.AddSqrt();
+ mutableAllOpResolver.AddSquare();
+ mutableAllOpResolver.AddSquaredDifference();
+ mutableAllOpResolver.AddSqueeze();
+ mutableAllOpResolver.AddStridedSlice();
+ mutableAllOpResolver.AddSub();
+ mutableAllOpResolver.AddSum();
+ mutableAllOpResolver.AddSvdf();
+ mutableAllOpResolver.AddTanh();
+ mutableAllOpResolver.AddTranspose();
+ mutableAllOpResolver.AddTransposeConv();
+ mutableAllOpResolver.AddUnidirectionalSequenceLSTM();
+ mutableAllOpResolver.AddUnpack();
+ mutableAllOpResolver.AddVarHandle();
+ mutableAllOpResolver.AddWhile();
+ mutableAllOpResolver.AddZerosLike();
+ return mutableAllOpResolver;
+ }
+
+} /* namespace app */
+} /* namespace arm */
diff --git a/source/application/api/use_case/inference_runner/src/TestModel.cc b/source/application/api/use_case/inference_runner/src/TestModel.cc
index c69a98d..94f17ef 100644
--- a/source/application/api/use_case/inference_runner/src/TestModel.cc
+++ b/source/application/api/use_case/inference_runner/src/TestModel.cc
@@ -16,8 +16,15 @@
*/
#include "TestModel.hpp"
#include "log_macros.h"
+#include "MicroMutableAllOpsResolver.hpp"
-const tflite::MicroMutableOpResolver<kNumberOperators>& arm::app::TestModel::GetOpResolver()
+const tflite::MicroOpResolver& arm::app::TestModel::GetOpResolver()
{
- return this->m_opResolver;
+ return this->m_opResolver;
+}
+
+bool arm::app::TestModel::EnlistOperations()
+{
+ this->m_opResolver = CreateAllOpsResolver();
+ return true;
}