diff options
author | Fredrik Svedberg <fredrik.svedberg@arm.com> | 2020-05-25 16:32:00 +0200 |
---|---|---|
committer | Tim Hall <tim.hall@arm.com> | 2020-06-18 17:53:52 +0100 |
commit | 388e9c230898385df59e6175aa45012e5864c09a (patch) | |
tree | ed7091ec07ba57275b09b306ad0e7dd6c6e7bef3 /ethosu/vela | |
parent | 86841e7dfb2df70c7959f0fccdd2fe1b878a98e2 (diff) | |
download | ethos-u-vela-388e9c230898385df59e6175aa45012e5864c09a.tar.gz |
[MLBEDSW-1996] Update supported operator checks
Updated supported operator checks according to latest requirements.
Change-Id: I79708d8039e464e39818d3c09e61f3f533e96f3d
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
Diffstat (limited to 'ethosu/vela')
-rw-r--r-- | ethosu/vela/supported_operators.py | 42 |
1 files changed, 29 insertions, 13 deletions
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py index 574b3a49..ce3fa609 100644 --- a/ethosu/vela/supported_operators.py +++ b/ethosu/vela/supported_operators.py @@ -15,7 +15,7 @@ # limitations under the License. # Description: # The SupportedOperators class which is a collection of all supported operators and parameter checks. -from .data_type import BaseType +from .data_type import BaseType, DataType class SupportedOperators: @@ -45,9 +45,9 @@ class SupportedOperators: | set(("ResizeBilinear",)) ) self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs")) - self.binary_elem_wise_main_ops = set( + self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum")) + self.binary_elem_wise_add_mul_sub = set( ( - # binary element-wise "AddAct", "MulAct", "SubAct", @@ -57,10 +57,9 @@ class SupportedOperators: "Mul", "Add", "Sub", - "Minimum", - "Maximum", ) ) + self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops self.activation_ops = set( ("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh") @@ -124,7 +123,7 @@ class SupportedOperators: for t in tensors: if not (t.dtype.type & BaseType.Int): return False - if t.element_size() > 2 and op.type != "Requantize": + if t.element_size() > 2 and op.type not in ("Requantize") | self.binary_elem_wise_add_mul_sub: return False # check size if any(dim > 65536 for dim in t.shape): @@ -197,15 +196,13 @@ class SupportedOperators: # check kernel size if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8): return False - if op.attrs["padding"] == b"VALID" and (op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256): + if (op.attrs["padding"] == b"VALID" and + (op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256)): return False if op.type in self.max_pooling_ops: - # check data type - if not ifm_tensor.dtype == ofm_tensor.dtype: - return False - # check kernel size - if op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256: # any padding + # check kernel size (any padding) + if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256: return False return True @@ -220,8 +217,27 @@ class SupportedOperators: def check_element_wise_restrictions(self, op): # check data type ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm() - if op.type in ("Minimum", "Maximum") and ifm_tensor.dtype != ofm_tensor.dtype: + # input and output datatype must match for these operators + if (op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops and + ifm_tensor.dtype != ofm_tensor.dtype): return False + if (op.type in self.binary_elem_wise_add_mul_sub): + # both inputs must have same type + if (ifm_tensor.dtype != ifm2_tensor.dtype): + return False + # signed input check + if (ifm_tensor.dtype.type & BaseType.Signed): + # output must be signed + if (ofm_tensor.dtype.type & BaseType.Unsigned): + return False + # and 8, 16 or 32-bit + if (ofm_tensor.element_size() not in (1, 2, 4)): + return False + # unsigned input check, output must be same type or int32 + if (ifm_tensor.dtype.type & BaseType.Unsigned and not + (ifm_tensor.dtype == ofm_tensor.dtype or + ofm_tensor.dtype == DataType.int32)): + return False # check batch size if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1: |