diff options
author | Grant Watson <grant.watson@arm.com> | 2023-08-28 16:34:28 +0100 |
---|---|---|
committer | Eric Kunze <eric.kunze@arm.com> | 2023-09-05 19:03:49 +0000 |
commit | e70d93175afb3bc053d264a02be2156a7354039c (patch) | |
tree | ce510859a642f617210d0190308a532fb7458df3 /scripts | |
parent | 7935972b1ae427b597ce2e817c5071c44d7ba56e (diff) | |
download | reference_model-e70d93175afb3bc053d264a02be2156a7354039c.tar.gz |
Pass func_config to individual operator API
Updates the generate_api.py script and associated
templates to allow func_config and debug_config
to be passed when running individual operators
on the API.
This will allow us, for example, to set precise_mode
and abs_mode when running individual operators.
Signed-off-by: Grant Watson <grant.watson@arm.com>
Change-Id: Ia3e7ffc146f876daa307558433177c68285843b7
Diffstat (limited to 'scripts')
-rw-r--r-- | scripts/operator_api/generate_api.py | 7 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_cc.j2 | 25 | ||||
-rw-r--r-- | scripts/operator_api/templates/operators_h.j2 | 8 |
3 files changed, 32 insertions, 8 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 499eadb..f1cb6e0 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -1,5 +1,5 @@ """Generate extended reference model API with eager operator execution entrypoints""" -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import copy import os @@ -190,7 +190,7 @@ def getOperators(tosaXml): Return a list of TOSA operators as defined by tosa.xml. """ operators = [] - ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"] + ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"] opsXml = tosaXml.getElementsByTagName("operator") allSerializeArgs = getSerializeArgs() for opXml in opsXml: @@ -241,6 +241,9 @@ def getTosaArgs(opXml): argType = xmlArg.getAttribute("type") argShape = xmlArg.getAttribute("shape") argCategory = xmlArg.getAttribute("category") + # FullyConnected workaround + if (argName == "weight" or argName == "bias") and (argCategory == "attribute"): + argCategory = "input" # Update argument type if argType[-1:] == "*": argType = argType[:-1] diff --git a/scripts/operator_api/templates/operators_cc.j2 b/scripts/operator_api/templates/operators_cc.j2 index 37a0af6..a8f1c24 100644 --- a/scripts/operator_api/templates/operators_cc.j2 +++ b/scripts/operator_api/templates/operators_cc.j2 @@ -45,12 +45,28 @@ tosa::DType translate_client_datatype(tosa_datatype_t type) { switch (type) { + case tosa_datatype_bf16_t: + return tosa::DType::DType_BF16; + case tosa_datatype_bool_t: + return tosa::DType::DType_BOOL; case tosa_datatype_fp16_t: return tosa::DType::DType_FP16; case tosa_datatype_fp32_t: return tosa::DType::DType_FP32; - case tosa_datatype_bool_t: - return tosa::DType::DType_BOOL; + case tosa_datatype_int16_t: + return tosa::DType::DType_INT16; + case tosa_datatype_int32_t: + return tosa::DType::DType_INT32; + case tosa_datatype_int48_t: + return tosa::DType::DType_INT48; + case tosa_datatype_int4_t: + return tosa::DType::DType_INT4; + case tosa_datatype_int8_t: + return tosa::DType::DType_INT8; + case tosa_datatype_uint16_t: + return tosa::DType::DType_UINT16; + case tosa_datatype_uint8_t: + return tosa::DType::DType_UINT8; default: return tosa::DType::DType_UNKNOWN; } @@ -83,7 +99,8 @@ extern "C" {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} {% if loop.index < operator.arguments|length %},{% endif %} - {%- endfor -%} + {%- endfor -%}, const func_config_t &func_config, + const func_debug_t &func_debug ) { // Create operator attributes @@ -157,7 +174,7 @@ extern "C" }); // Setup model - TosaReference::ModelRunnerImpl runner; + TosaReference::ModelRunnerImpl runner(func_config, func_debug); TOSA_RETURN_ON_GRAPH_STATUS_ERROR(runner.initialize(block)); {% for input in operator.inputs: -%} TOSA_RETURN_ON_ERROR(runner.setInput({{input}}->GetName(), client_{{input}}.data, client_{{input}}.size)); diff --git a/scripts/operator_api/templates/operators_h.j2 b/scripts/operator_api/templates/operators_h.j2 index 803b76a..042d7a5 100644 --- a/scripts/operator_api/templates/operators_h.j2 +++ b/scripts/operator_api/templates/operators_h.j2 @@ -1,5 +1,5 @@ -// Copyright (c) 2022, ARM Limited. +// Copyright (c) 2022-2023, ARM Limited. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -19,6 +19,9 @@ #ifndef OPERATORS_H_ #define OPERATORS_H_ +#include "func_config.h" +#include "func_debug.h" + #include <stddef.h> #include <stdint.h> @@ -64,7 +67,8 @@ extern "C" { {%- for arg in operator.arguments: -%} {% if arg.type != "tosa_tensor_t" -%}const {% endif -%}{{arg.type}} client_{{arg.name}}{{arg.shape}} {% if loop.index < operator.arguments|length %},{% endif %} - {%- endfor -%}); + {%- endfor -%}, const func_config_t &func_config = func_config_t{}, + const func_debug_t &func_debug = func_debug_t{}); {% endfor %} #ifdef __cplusplus |