aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/tflite_model_semantic.py2
-rw-r--r--ethosu/vela/tflite_supported_operators.py20
2 files changed, 21 insertions, 1 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 51d1f072..b783bb74 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -160,7 +160,7 @@ class TFLiteSemantic:
valid, extra = constraint(op)
if not valid:
print(
- f"Warning: unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
+ f"Warning: Unsupported TensorFlow Lite semantics for {ext_type} '{op.name}'. Placing on CPU instead"
)
print(f" - {constraint.__doc__}")
if extra:
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index a3c0dd8d..5c7fd517 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -210,6 +210,9 @@ class TFLiteSupportedOperators:
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product)
self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product_int8)
+ # Reshape specific checks:
+ self.specific_constraints[Op.Reshape].append(TFLiteSupportedOperators.constraint_reshape_shape_constant)
+
def is_operator_supported(self, op):
ext_type = optype_to_builtintype(op.type)
if op.type not in TFLiteSupportedOperators.supported_operators:
@@ -682,3 +685,20 @@ class TFLiteSupportedOperators:
h, w = shape[hi : hi + 2]
max_prod = cls.mean_kernel_product_int8
return h * w <= max_prod, f"Product of height and width is {h * w}"
+
+ @staticmethod
+ def constraint_reshape_shape_constant(op):
+ "Shape must be constant"
+ valid = True
+ extra = []
+
+ reshape_tens = op.inputs[1]
+ if reshape_tens is not None:
+ # constant inputs have either no driving operator or a const one
+ # create a list of non-constant inputs
+ if not (len(reshape_tens.ops) == 0 or reshape_tens.ops[0].type == Op.Const):
+ valid = False
+ extra.append(reshape_tens.name)
+ extra = ", ".join(extra)
+
+ return valid, f"Op has non-const input(s): {extra}"