aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/supported_operators.py
diff options
context:
space:
mode:
authorJacob Bohlin <jacob.bohlin@arm.com>2020-08-19 14:36:46 +0200
committerJacob Bohlin <jacob.bohlin@arm.com>2020-08-28 08:56:00 +0200
commit49d9212edf94ad71a00208b893d2181a33ce8648 (patch)
treef53613f96670e4a54bfe66112598e7b7d182a419 /ethosu/vela/supported_operators.py
parent99fcb89f6a6c09b5677a5323abdee5714ac5117a (diff)
downloadethos-u-vela-49d9212edf94ad71a00208b893d2181a33ce8648.tar.gz
MLBEDSW-2804: Added bias data type check
Allows int64 data type to be used as long as all values can be packed into a int40 value. Signed-off-by: Jacob Bohlin <jacob.bohlin@arm.com> Change-Id: I0e25ec482e3ea765a5fd00bcf7e212a9e65a1461
Diffstat (limited to 'ethosu/vela/supported_operators.py')
-rw-r--r--ethosu/vela/supported_operators.py23
1 files changed, 21 insertions, 2 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index f57cbee2..8ec77207 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -201,10 +201,13 @@ class SupportedOperators:
return False
# check data type
- ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+ ifm_tensor, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
if weight_tensor.element_size() > 1:
return False
+ if not self.check_bias_restrictions(bias_tensor):
+ return False
+
# check kernel size [HWIO]
dilated_weight_w = weight_tensor.shape[1] + (weight_tensor.shape[1] - 1) * (dilation_w_factor - 1)
dilated_weight_h = weight_tensor.shape[0] + (weight_tensor.shape[0] - 1) * (dilation_h_factor - 1)
@@ -307,10 +310,13 @@ class SupportedOperators:
def check_vector_product_restrictions(self, op):
# check data type
- ifm_tensor, _, weight_tensor, _ = op.get_ifm_ifm2_weights_ofm()
+ _, _, weight_tensor, bias_tensor, _ = op.get_ifm_ifm2_weights_biases_ofm()
if weight_tensor.element_size() > 1:
return False
+ if not self.check_bias_restrictions(bias_tensor):
+ return False
+
return True
def check_element_wise_restrictions(self, op):
@@ -407,3 +413,16 @@ class SupportedOperators:
return False
return True
+
+ def check_bias_restrictions(self, bias_tensor):
+ # check data type
+ if bias_tensor.dtype not in (DataType.int32, DataType.int64):
+ return False
+
+ # check if values fits in 40-bit
+ if bias_tensor.dtype == DataType.int64:
+ for value in bias_tensor.values:
+ if not (-(1 << 39) <= value < (1 << 39)):
+ return False
+
+ return True