diff options
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/tflite_mapping.py | 4 | ||||
-rw-r--r-- | ethosu/vela/tflite_reader.py | 4 |
2 files changed, 8 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index 4873ecc2..06097cd0 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -54,6 +54,7 @@ from .tflite import GatherNdOptions from .tflite import GatherOptions from .tflite import GreaterEqualOptions from .tflite import GreaterOptions +from .tflite import HardSwishOptions from .tflite import IfOptions from .tflite import L2NormOptions from .tflite import LeakyReluOptions @@ -258,6 +259,8 @@ builtin_options_map = { BuiltinOptions.MatrixSetDiagOptions: MatrixSetDiagOptions.MatrixSetDiagOptions, BuiltinOptions.DensifyOptions: DensifyOptions.DensifyOptions, BuiltinOptions.DepthToSpaceOptions: DepthToSpaceOptions.DepthToSpaceOptions, + BuiltinOptions.HardSwishOptions: HardSwishOptions.HardSwishOptions, + BuiltinOptions.IfOptions: IfOptions.IfOptions, BuiltinOptions.NonMaxSuppressionV4Options: NonMaxSuppressionV4Options.NonMaxSuppressionV4Options, BuiltinOptions.NonMaxSuppressionV5Options: NonMaxSuppressionV5Options.NonMaxSuppressionV5Options, @@ -622,6 +625,7 @@ builtin_operator_map = { BuiltinOperator.MATRIX_DIAG: ("MatrixDiag", None), BuiltinOperator.QUANTIZE: ("Quantize", None), BuiltinOperator.MATRIX_SET_DIAG: ("MatrixSetDiag", None), + BuiltinOperator.HARD_SWISH: ("HardSwish", OptionsSerializer("HardSwishOptions")), BuiltinOperator.IF: ("If", OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))), BuiltinOperator.WHILE: ("While", OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index"))), BuiltinOperator.NON_MAX_SUPPRESSION_V4: ("NonMaxSuppressionV4", OptionsSerializer("NonMaxSuppressionV4Options")), diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 5ab90f04..84c4c3c2 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -19,6 +19,7 @@ import os.path import numpy as np +from .errors import InputFileError from .errors import UnsupportedFeatureError from .nn_graph import Graph from .nn_graph import Subgraph @@ -228,6 +229,9 @@ class TFLiteGraph: def parse_operator_code(self, code): c = code.BuiltinCode() + if not c in builtin_operator_map: + msg = "The input file contains operator code {} which is currently not supported".format(c) + raise InputFileError(self.name, msg) op_type, ser = builtin_operator_map[c] if c == BuiltinOperator.CUSTOM: op_type += decode_str(code.CustomCode()) |