diff options
-rw-r--r-- | ethosu/vela/tflite_model_semantic.py | 2 | ||||
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 20 |
2 files changed, 21 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py index 51d1f072..b783bb74 100644 --- a/ethosu/vela/tflite_model_semantic.py +++ b/ethosu/vela/tflite_model_semantic.py @@ -160,7 +160,7 @@ class TFLiteSemantic: valid, extra = constraint(op) if not valid: print( - f"Warning: unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead" + f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead" ) print(f" - {constraint.__doc__}") if extra: diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index a3c0dd8d..5c7fd517 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -210,6 +210,9 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8) + # Reshape specific checks: + self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant) + def is_operator_supported(self, op): ext_type = optype_to_builtintype(op.type) if op.type not in TFLiteSupportedOperators.supported_operators: @@ -682,3 +685,20 @@ class TFLiteSupportedOperators: h, w = shape[hi : hi + 2] max_prod = cls.mean_kernel_product_int8 return h * w <= max_prod, f"Product of height and width is {h * w}" + + @staticmethod + def constraint_reshape_shape_constant(op): + "Shape must be constant" + valid = True + extra = [] + + reshape_tens = op.inputs[1] + if reshape_tens is not None: + # constant inputs have either no driving operator or a const one + # create a list of non-constant inputs + if not (len(reshape_tens.ops) == 0 or reshape_tens.ops[0].type == Op.Const): + valid = False + extra.append(reshape_tens.name) + extra = ", ".join(extra) + + return valid, f"Op has non-const input(s): {extra}" |