diff options
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r-- | scripts/operator_api/generate_api.py | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py index 499eadb..f1cb6e0 100644 --- a/scripts/operator_api/generate_api.py +++ b/scripts/operator_api/generate_api.py @@ -1,5 +1,5 @@ """Generate extended reference model API with eager operator execution entrypoints""" -# Copyright (c) 2021-2022, ARM Limited. +# Copyright (c) 2021-2023, ARM Limited. # SPDX-License-Identifier: Apache-2.0 import copy import os @@ -190,7 +190,7 @@ def getOperators(tosaXml): Return a list of TOSA operators as defined by tosa.xml. """ operators = [] - ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"] + ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"] opsXml = tosaXml.getElementsByTagName("operator") allSerializeArgs = getSerializeArgs() for opXml in opsXml: @@ -241,6 +241,9 @@ def getTosaArgs(opXml): argType = xmlArg.getAttribute("type") argShape = xmlArg.getAttribute("shape") argCategory = xmlArg.getAttribute("category") + # FullyConnected workaround + if (argName == "weight" or argName == "bias") and (argCategory == "attribute"): + argCategory = "input" # Update argument type if argType[-1:] == "*": argType = argType[:-1] |