From 530992a3943eb21e12f6d0e638940d7df27a9f51 Mon Sep 17 00:00:00 2001 From: Patrik Gustavsson Date: Wed, 30 Sep 2020 13:26:59 +0200 Subject: MLBEDSW-3001 Fix Min Max OPs not properly checked Min and max operations was not passed through the checking of elementwize OPs in the supported operator checking. Changed so they are passed through this check as well. Signed-off-by: Patrik Gustavsson Change-Id: I358a121de33882802415d97d9ed5dbee53233f77 --- ethosu/vela/supported_operators.py | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) (limited to 'ethosu') diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index eec1b900..73a4f282 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -121,9 +121,6 @@ class SupportedOperators: self.supported_operator_restrictions.update( {op: self.check_memory_only_restrictions for op in self.memory_only_ops} ) - self.supported_operator_restrictions.update( - {op: self.check_quantization_restrictions_binary_elem_wise for op in self.binary_elem_wise_min_max_ops} - ) self.supported_operator_restrictions.update({op: self.check_activation_ops for op in self.activation_ops}) def is_operator_supported(self, op): @@ -201,10 +198,12 @@ class SupportedOperators: # check inf values for tens in op.get_ifm_ifm2_weights_ofm(): - if (tens is not None) and ( - tens.quantization is not None) and ( - tens.quantization.scale_f32 is not None) and ( - np.isinf(tens.quantization.scale_f32).any()): + if ( + (tens is not None) + and (tens.quantization is not None) + and (tens.quantization.scale_f32 is not None) + and (np.isinf(tens.quantization.scale_f32).any()) + ): print("Warning:", op.type, "has inf valued tensor(s), placing on CPU") return False @@ -398,6 +397,11 @@ class SupportedOperators: if ifm_tensor.shape != ofm_tensor.shape and ifm2_tensor.shape != ofm_tensor.shape: return False + if op.type in self.binary_elem_wise_min_max_ops and not self.check_quantization_restrictions_binary_elem_wise( + op + ): + return False + return True def check_memory_only_restrictions(self, op): @@ -470,7 +474,16 @@ class SupportedOperators: return False for i in range(ofm_dims): if i != axis and ifm.shape[i] != ofm.shape[i]: - print("Warning:", op.type, "invalid ifm:", ifm.name, ifm.shape, "mismatch in dimension", i, ", placing on CPU") + print( + "Warning:", + op.type, + "invalid ifm:", + ifm.name, + ifm.shape, + "mismatch in dimension", + i, + ", placing on CPU", + ) return False return True -- cgit v1.2.1