diff options
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r-- | ethosu/vela/tflite_supported_operators.py | 38 |
1 files changed, 31 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py index be86e9a3..9aa174de 100644 --- a/ethosu/vela/tflite_supported_operators.py +++ b/ethosu/vela/tflite_supported_operators.py @@ -1,4 +1,4 @@ -# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved. +# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved. # # SPDX-License-Identifier: Apache-2.0 # @@ -255,6 +255,11 @@ class TFLiteSupportedOperators: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs) self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers) + # Resize Bilinear specific checks: + self.specific_constraints[Op.ResizeBilinear].append( + TFLiteSupportedOperators.constraint_resizebi_half_pixel_centers_dims + ) + # Vector Product specific checks: for op_type in TFLiteSupportedOperators.fc_vector_products: self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type) @@ -602,8 +607,8 @@ class TFLiteSupportedOperators: """The width and height of the IFM and OFM must match one of the following criteria: IFM W and H must both be 1 IFM must match OFM - OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True - OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False""" + W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True + W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False""" # Easier to start with False condition as very few cases result in a supported resize valid = False ifm_shape = op.ifm.shape @@ -661,11 +666,30 @@ class TFLiteSupportedOperators: @staticmethod def constraint_resize_half_pixel_centers(op): - "half_pixel_centers are not supported" - valid = True - if op.attrs.get("half_pixel_centers", False): + """Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8""" + valid = op.ifm.dtype in (DataType.int8, DataType.uint8) + half_pixel_centers = op.attrs.get("half_pixel_centers", False) + if half_pixel_centers and op.type != Op.ResizeBilinear: + valid = False + return valid, f"Op type={op.type}, ifm dtype={op.ifm.dtype} and half_pixel_centers={half_pixel_centers}" + + @staticmethod + def constraint_resizebi_half_pixel_centers_dims(op): + """Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H""" + half_pixel_centers = op.attrs.get("half_pixel_centers", False) + if not half_pixel_centers: + valid = True + elif len(op.ifm.shape) >= 3: + ifm_h, ifm_w = op.ifm.shape[-3:-1] + ofm_h, ofm_w = op.ofm.shape[-3:-1] + valid = ofm_h / ifm_h == 2 and ofm_w / ifm_w == 2 + else: + # Unexpected IFM shape valid = False - return valid, f"Op has half_pixel_centers set to {not valid}." + return ( + valid, + f"Op has ifm_shape={op.ifm.shape}, ofm_shape={op.ofm.shape} and half_pixel_centers={half_pixel_centers}", + ) @staticmethod def constraint_pad_shape(op): |