From 678645b7b7b788543e08fe0767cce784cb92a1f9 Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Mon, 15 Jun 2020 15:22:47 +0200 Subject: MLBEDSW-2436: Support for HardSwish operator - Added support for HardSwish (placed on CPU) - Improved error reporting for unknown operator codes in input file Signed-off-by: Louis Verhaard Change-Id: I1d1c7b9d786288d7098450cdad2b67fc0759378b --- ethosu/vela/tflite_mapping.py | 4 ++++ ethosu/vela/tflite_reader.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/ethosu/vela/tflite_mapping.py b/ethosu/vela/tflite_mapping.py index 4873ecc2..06097cd0 100644 --- a/ethosu/vela/tflite_mapping.py +++ b/ethosu/vela/tflite_mapping.py @@ -54,6 +54,7 @@ from .tflite import GatherNdOptions from .tflite import GatherOptions from .tflite import GreaterEqualOptions from .tflite import GreaterOptions +from .tflite import HardSwishOptions from .tflite import IfOptions from .tflite import L2NormOptions from .tflite import LeakyReluOptions @@ -258,6 +259,8 @@ builtin_options_map = { BuiltinOptions.MatrixSetDiagOptions: MatrixSetDiagOptions.MatrixSetDiagOptions, BuiltinOptions.DensifyOptions: DensifyOptions.DensifyOptions, BuiltinOptions.DepthToSpaceOptions: DepthToSpaceOptions.DepthToSpaceOptions, + BuiltinOptions.HardSwishOptions: HardSwishOptions.HardSwishOptions, + BuiltinOptions.IfOptions: IfOptions.IfOptions, BuiltinOptions.NonMaxSuppressionV4Options: NonMaxSuppressionV4Options.NonMaxSuppressionV4Options, BuiltinOptions.NonMaxSuppressionV5Options: NonMaxSuppressionV5Options.NonMaxSuppressionV5Options, @@ -622,6 +625,7 @@ builtin_operator_map = { BuiltinOperator.MATRIX_DIAG: ("MatrixDiag", None), BuiltinOperator.QUANTIZE: ("Quantize", None), BuiltinOperator.MATRIX_SET_DIAG: ("MatrixSetDiag", None), + BuiltinOperator.HARD_SWISH: ("HardSwish", OptionsSerializer("HardSwishOptions")), BuiltinOperator.IF: ("If", OptionsSerializer("IfOptions", ("then_subgraph_index", "else_subgraph_index"))), BuiltinOperator.WHILE: ("While", OptionsSerializer("WhileOptions", ("cond_subgraph_index", "body_subgraph_index"))), BuiltinOperator.NON_MAX_SUPPRESSION_V4: ("NonMaxSuppressionV4", OptionsSerializer("NonMaxSuppressionV4Options")), diff --git a/ethosu/vela/tflite_reader.py b/ethosu/vela/tflite_reader.py index 5ab90f04..84c4c3c2 100644 --- a/ethosu/vela/tflite_reader.py +++ b/ethosu/vela/tflite_reader.py @@ -19,6 +19,7 @@ import os.path import numpy as np +from .errors import InputFileError from .errors import UnsupportedFeatureError from .nn_graph import Graph from .nn_graph import Subgraph @@ -228,6 +229,9 @@ class TFLiteGraph: def parse_operator_code(self, code): c = code.BuiltinCode() + if not c in builtin_operator_map: + msg = "The input file contains operator code {} which is currently not supported".format(c) + raise InputFileError(self.name, msg) op_type, ser = builtin_operator_map[c] if c == BuiltinOperator.CUSTOM: op_type += decode_str(code.CustomCode()) -- cgit v1.2.1