From 3584a9cfdf0bcf0e75d38b78ec39e5b083947e19 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Thu, 18 Nov 2021 22:05:17 +0000 Subject: MLBEDSW-3602: Output mismatch on some mobilenet_v1 int8 and int16 - The failing tests contain operations with dynamic tensors which are not supported and therefore they should be placed on the CPU. However, a bug in the removal of RESHAPEs which contain a dynamic shape prevented this happening. - This change adds a check to make sure that RESHAPE ops with a dynamic shape tensor are not removed and instead are placed on the CPU. Signed-off-by: Tim Hall Change-Id: I2d7481f7f80f99a0f01df100d956933777e6875a --- ethosu/vela/tflite_model_semantic.py | 2 +- ethosu/vela/tflite_supported_operators.py | 20 ++++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 51d1f072..b783bb74 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -160,7 +160,7 @@ class TFLiteSemantic: valid, extra = constraint(op) if not valid: print( - f"Warning: unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead" + f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead" ) print(f" - {constraint.__doc__}") if extra: diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index a3c0dd8d..5c7fd517 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -210,6 +210,9 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8) + # Reshape specific checks: + self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -682,3 +685,20 @@ class TFLiteSupportedOperators: h, w = shape[hi : hi + 2] max_prod = cls.mean_kernel_product_int8 return h * w <= max_prod, f"Product of height and width is {h * w}" + + @staticmethod + def constraint_reshape_shape_constant(op): + "Shape must be constant" + valid = True + extra = [] + + reshape_tens = op.inputs[1] + if reshape_tens is not None: + # constant inputs have either no driving operator or a const one + # create a list of non-constant inputs + if not (len(reshape_tens.ops) == 0 or reshape_tens.ops[0].type == Op.Const): + valid = False + extra.append(reshape_tens.name) + extra = ", ".join(extra) + + return valid, f"Op has non-const input(s): {extra}" -- cgit v1.2.1