aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_model_semantic.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_model_semantic.py')
-rw-r--r--ethosu/vela/tflite_model_semantic.py11
1 files changed, 11 insertions, 0 deletions
diff --git a/ethosu/vela/tflite_model_semantic.py b/ethosu/vela/tflite_model_semantic.py
index 62648914..ea7ef4a3 100644
--- a/ethosu/vela/tflite_model_semantic.py
+++ b/ethosu/vela/tflite_model_semantic.py
@@ -184,6 +184,7 @@ class TFLiteSemantic:
# Pad specific checks:
self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_input_count)
self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_constant)
+ self.specific_constraints[Op.Pad].append(TFLiteSemantic.constraint_pad_output_shape)
# HardSwish specific checks:
self.specific_constraints[Op.HardSwish].append(TFLiteSemantic.constraint_input_8bit)
@@ -586,6 +587,16 @@ class TFLiteSemantic:
return valid, f"Op has non-constant padding tensor: {op.inputs[1].values}"
@staticmethod
+ def constraint_pad_output_shape(op):
+ "Shape of output tensor must equal to size of input tensor plus padding"
+ input_shape = op.inputs[0].shape
+ expected_output_shape = op.outputs[0].shape
+ pad_tensor = op.inputs[1].values
+ actual_output_shape = input_shape + pad_tensor.T[0] + pad_tensor.T[1]
+ valid = np.array_equal(actual_output_shape, expected_output_shape)
+ return valid, f"Op has wrong output tensor shape: {expected_output_shape}, has shape: {actual_output_shape}"
+
+ @staticmethod
def constraint_stridedslice_inputs_const(op):
"Begin, End and Stride Input tensors must be constant"
valid = True