aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/tflite_supported_operators.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tflite_supported_operators.py')
-rw-r--r--ethosu/vela/tflite_supported_operators.py38
1 files changed, 31 insertions, 7 deletions
diff --git a/ethosu/vela/tflite_supported_operators.py b/ethosu/vela/tflite_supported_operators.py
index be86e9a3..9aa174de 100644
--- a/ethosu/vela/tflite_supported_operators.py
+++ b/ethosu/vela/tflite_supported_operators.py
@@ -1,4 +1,4 @@
-# Copyright (C) 2020-2021 Arm Limited or its affiliates. All rights reserved.
+# Copyright (C) 2020-2022 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
@@ -255,6 +255,11 @@ class TFLiteSupportedOperators:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_attrs)
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_resize_half_pixel_centers)
+ # Resize Bilinear specific checks:
+ self.specific_constraints[Op.ResizeBilinear].append(
+ TFLiteSupportedOperators.constraint_resizebi_half_pixel_centers_dims
+ )
+
# Vector Product specific checks:
for op_type in TFLiteSupportedOperators.fc_vector_products:
self.specific_constraints[op_type].append(TFLiteSupportedOperators.constraint_weights_type)
@@ -602,8 +607,8 @@ class TFLiteSupportedOperators:
"""The width and height of the IFM and OFM must match one of the following criteria:
IFM W and H must both be 1
IFM must match OFM
- OFM W and H must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
- OFM W and H must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
+ W and H scaling must be equal and OFM W-1 and H-1 must be 2x/4x/8x IFM W-1 and H-1, if align_corners is True
+ W and H scaling must be equal and OFM W and H must be 2x/4x/8x IFM W and H, if align_corners is False"""
# Easier to start with False condition as very few cases result in a supported resize
valid = False
ifm_shape = op.ifm.shape
@@ -661,11 +666,30 @@ class TFLiteSupportedOperators:
@staticmethod
def constraint_resize_half_pixel_centers(op):
- "half_pixel_centers are not supported"
- valid = True
- if op.attrs.get("half_pixel_centers", False):
+ """Half_pixel_centers are only supported for resize bilinear with IFM dtype int8 or uint8"""
+ valid = op.ifm.dtype in (DataType.int8, DataType.uint8)
+ half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+ if half_pixel_centers and op.type != Op.ResizeBilinear:
+ valid = False
+ return valid, f"Op type={op.type}, ifm dtype={op.ifm.dtype} and half_pixel_centers={half_pixel_centers}"
+
+ @staticmethod
+ def constraint_resizebi_half_pixel_centers_dims(op):
+ """Half_pixel_centers for resize bilinear requires that OFM W and H is 2x IFM W and H"""
+ half_pixel_centers = op.attrs.get("half_pixel_centers", False)
+ if not half_pixel_centers:
+ valid = True
+ elif len(op.ifm.shape) >= 3:
+ ifm_h, ifm_w = op.ifm.shape[-3:-1]
+ ofm_h, ofm_w = op.ofm.shape[-3:-1]
+ valid = ofm_h / ifm_h == 2 and ofm_w / ifm_w == 2
+ else:
+ # Unexpected IFM shape
valid = False
- return valid, f"Op has half_pixel_centers set to {not valid}."
+ return (
+ valid,
+ f"Op has ifm_shape={op.ifm.shape}, ofm_shape={op.ofm.shape} and half_pixel_centers={half_pixel_centers}",
+ )
@staticmethod
def constraint_pad_shape(op):