aboutsummaryrefslogtreecommitdiff
path: root/scripts/operator_api/generate_api.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/operator_api/generate_api.py')
-rw-r--r--scripts/operator_api/generate_api.py7
1 files changed, 5 insertions, 2 deletions
diff --git a/scripts/operator_api/generate_api.py b/scripts/operator_api/generate_api.py
index 499eadb..f1cb6e0 100644
--- a/scripts/operator_api/generate_api.py
+++ b/scripts/operator_api/generate_api.py
@@ -1,5 +1,5 @@
"""Generate extended reference model API with eager operator execution entrypoints"""
-# Copyright (c) 2021-2022, ARM Limited.
+# Copyright (c) 2021-2023, ARM Limited.
# SPDX-License-Identifier: Apache-2.0
import copy
import os
@@ -190,7 +190,7 @@ def getOperators(tosaXml):
Return a list of TOSA operators as defined by tosa.xml.
"""
operators = []
- ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d", "erf"]
+ ignoreOps = ["while_loop", "cond_if", "const", "custom", "fft2d", "rfft2d"]
opsXml = tosaXml.getElementsByTagName("operator")
allSerializeArgs = getSerializeArgs()
for opXml in opsXml:
@@ -241,6 +241,9 @@ def getTosaArgs(opXml):
argType = xmlArg.getAttribute("type")
argShape = xmlArg.getAttribute("shape")
argCategory = xmlArg.getAttribute("category")
+ # FullyConnected workaround
+ if (argName == "weight" or argName == "bias") and (argCategory == "attribute"):
+ argCategory = "input"
# Update argument type
if argType[-1:] == "*":
argType = argType[:-1]