aboutsummaryrefslogtreecommitdiff
path: root/scripts
diff options
context:
space:
mode:
authorGrant Watson <grant.watson@arm.com>2023-08-28 16:34:28 +0100
committerEric Kunze <eric.kunze@arm.com>2023-09-05 19:03:49 +0000
commite70d93175afb3bc053d264a02be2156a7354039c (patch)
treece510859a642f617210d0190308a532fb7458df3 /scripts
parent7935972b1ae427b597ce2e817c5071c44d7ba56e (diff)
downloadreference_model-e70d93175afb3bc053d264a02be2156a7354039c.tar.gz
Pass func_config to individual operator API
Updates the generate_api.py script and associated templates to allow func_config and debug_config to be passed when running individual operators on the API. This will allow us, for example, to set precise_mode and abs_mode when running individual operators. Signed-off-by: Grant Watson <grant.watson@arm.com> Change-Id: Ia3e7ffc146f876daa307558433177c68285843b7
Diffstat (limited to 'scripts')
-rw-r--r--scripts/operator_api/generate_api.py7
-rw-r--r--scripts/operator_api/templates/operators_cc.j225
-rw-r--r--scripts/operator_api/templates/operators_h.j28
3 files changed, 32 insertions, 8 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 499eadb..f1cb6e0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -1,5 +1,5 @@
"""Generate extended reference model API with eager operator execution entrypoints"""
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
import os
@@ -190,7 +190,7 @@ def getOperators(tosaXml):
Return a list of TOSA operators as defined by tosa.xml.
"""
operators = []
- ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"]
+ ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
opsXml = tosaXml.getElementsByTagName("operator")
allSerializeArgs = getSerializeArgs()
for opXml in opsXml:
@@ -241,6 +241,9 @@ def getTosaArgs(opXml):
argType = xmlArg.getAttribute("type")
argShape = xmlArg.getAttribute("shape")
argCategory = xmlArg.getAttribute("category")
+ # FullyConnected workaround
+ if (argName == "weight" or argName == "bias") and (argCategory == "attribute"):
+ argCategory = "input"
# Update argument type
if argType[-1:] == "*":
argType = argType[:-1]
diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2
index 37a0af6..a8f1c24 100644
--- a/scripts/operator_api/templates/operators_cc.j2
+++ b/scripts/operator_api/templates/operators_cc.j2
@@ -45,12 +45,28 @@ tosa::DType translate_client_datatype(tosa_datatype_t type)
{
switch (type)
{
+ case tosa_datatype_bf16_t:
+ return tosa::DType::DType_BF16;
+ case tosa_datatype_bool_t:
+ return tosa::DType::DType_BOOL;
case tosa_datatype_fp16_t:
return tosa::DType::DType_FP16;
case tosa_datatype_fp32_t:
return tosa::DType::DType_FP32;
- case tosa_datatype_bool_t:
- return tosa::DType::DType_BOOL;
+ case tosa_datatype_int16_t:
+ return tosa::DType::DType_INT16;
+ case tosa_datatype_int32_t:
+ return tosa::DType::DType_INT32;
+ case tosa_datatype_int48_t:
+ return tosa::DType::DType_INT48;
+ case tosa_datatype_int4_t:
+ return tosa::DType::DType_INT4;
+ case tosa_datatype_int8_t:
+ return tosa::DType::DType_INT8;
+ case tosa_datatype_uint16_t:
+ return tosa::DType::DType_UINT16;
+ case tosa_datatype_uint8_t:
+ return tosa::DType::DType_UINT8;
default:
return tosa::DType::DType_UNKNOWN;
}
@@ -83,7 +99,8 @@ extern "C"
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%}
+ {%- endfor -%}, const func_config_t &func_config,
+ const func_debug_t &func_debug
)
{
// Create operator attributes
@@ -157,7 +174,7 @@ extern "C"
});
// Setup model
- TosaReference::ModelRunnerImpl runner;
+ TosaReference::ModelRunnerImpl runner(func_config, func_debug);
TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block));
{% for input in operator.inputs: -%}
TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size));
diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2
index 803b76a..042d7a5 100644
--- a/scripts/operator_api/templates/operators_h.j2
+++ b/scripts/operator_api/templates/operators_h.j2
@@ -1,5 +1,5 @@
-// Copyright (c) 2022, ARM Limited.
+// Copyright (c) 2022-2023, ARM Limited.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
@@ -19,6 +19,9 @@
#ifndef OPERATORS_H_
#define OPERATORS_H_
+#include "func_config.h"
+#include "func_debug.h"
+
#include <stddef.h>
#include <stdint.h>
@@ -64,7 +67,8 @@ extern "C" {
{%- for arg in operator.arguments: -%}
{% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}}
{% if loop.index < operator.arguments|length %},{% endif %}
- {%- endfor -%});
+ {%- endfor -%}, const func_config_t &func_config = func_config_t{},
+ const func_debug_t &func_debug = func_debug_t{});
{% endfor %}
#ifdef __cplusplus