diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 31 |
1 files changed, 26 insertions, 5 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index 48813fe..ad61fca 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -106,10 +106,17 @@ class TFLiteSupportedOperators: ) ) binary_elem_wise_main_ops = binary_elem_wise_min_max_ops | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops + elem_wise_main_ops = binary_elem_wise_main_ops | unary_elem_wise_main_ops | set((Op.SquaredDifference,)) - pad_ops = set((Op.Pad,)) + pad_ops = set( + ( + Op.Pad, + Op.MirrorPad, + ) + ) + supported_int32_tensor_ops = ( - set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax, Op.Transpose)) + set((Op.ReduceSum, Op.CLZ, Op.Shape, Op.ArgMax, Op.Transpose, Op.MirrorPad)) | binary_elem_wise_add_mul_sub | binary_elem_wise_shift_ops ) @@ -312,9 +319,13 @@ class TFLiteSupportedOperators: self.specific_constraints[Op.StridedSlice].append(TFLiteSupportedOperators.constraint_stridedslice_offset_false) # Pad specific checks: - self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_shape) - self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_padding_dimensions) - self.specific_constraints[Op.Pad].append(TFLiteSupportedOperators.constraint_pad_type) + for op_type in TFLiteSupportedOperators.pad_ops: + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_pad_shape) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_padding_dimensions) + self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_pad_type) + + # Mirror pad specific checks: + self.specific_constraints[Op.MirrorPad].append(TFLiteSupportedOperators.constraint_mirror_pad_padding_values) # Mean specific checks: self.specific_constraints[Op.Mean].append(TFLiteSupportedOperators.constraint_mean_height_width_product) @@ -818,6 +829,16 @@ class TFLiteSupportedOperators: return valid, f"First dimension padding: {pad_tensor[0,:]}, last dimension padding: {pad_tensor[-1,:]}" @staticmethod + def constraint_mirror_pad_padding_values(op): + "The number of pad values for each direction must not be larger than the ifm size in that dimension" + pad_tensor = op.inputs[1].values + ifm_shape = op.inputs[0].shape + for dim_padding, ifm_dim_shape in enumerate(pad_tensor, ifm_shape): + if any(dim_padding > ifm_dim_shape): + valid = False + return valid, f"IFM shape: {ifm_shape}, number of padding values per dimension: {pad_tensor}" + + @staticmethod def constraint_stridedslice_stride_values(op): "All Strides values must be 1" strides = op.inputs[3] |