author     Fredrik Svedberg <fredrik.svedberg@arm.com>    2020-05-25 16:32:00 +0200
committer  Tim Hall <tim.hall@arm.com>                    2020-06-18 17:53:52 +0100
commit     388e9c230898385df59e6175aa45012e5864c09a (patch)
tree       ed7091ec07ba57275b09b306ad0e7dd6c6e7bef3
parent     86841e7dfb2df70c7959f0fccdd2fe1b878a98e2 (diff)
download   ethos-u-vela-388e9c230898385df59e6175aa45012e5864c09a.tar.gz
[MLBEDSW-1996] Update supported operator checks
Updated supported operator checks according to latest requirements.

Change-Id: I79708d8039e464e39818d3c09e61f3f533e96f3d
Signed-off-by: Fredrik Svedberg <fredrik.svedberg@arm.com>
-rw-r--r--  ethosu/vela/supported_operators.py  42
1 file changed, 29 insertions(+), 13 deletions(-)
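
The main functional change: 32-bit tensors are now accepted for the Add/Mul/Sub
family of binary element-wise operators, where previously only Requantize was
exempt from the 16-bit limit. A minimal runnable sketch of the updated
element-size gate (names simplified from the patch below; element sizes are in
bytes, so > 2 means 32-bit or wider):

    # Sketch only: this set mirrors binary_elem_wise_add_mul_sub from the patch.
    binary_elem_wise_add_mul_sub = {
        "AddAct", "MulAct", "SubAct",
        "QuantizedAdd", "QuantizedMul", "QuantizedSub",
        "Add", "Mul", "Sub",
    }

    def element_size_supported(op_type, element_size):
        # 32-bit data used to pass only for Requantize; the Add/Mul/Sub
        # family is now exempt as well.
        if element_size > 2 and op_type not in {"Requantize"} | binary_elem_wise_add_mul_sub:
            return False
        return True

    assert element_size_supported("Add", 4)          # int32 Add now accepted
    assert not element_size_supported("Conv2D", 4)   # 32-bit Conv2D still rejected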
diff --git a/ethosu/vela/supported_operators.py b/ethosu/vela/supported_operators.py
index 574b3a49..ce3fa609 100644
--- a/ethosu/vela/supported_operators.py
+++ b/ethosu/vela/supported_operators.py
@@ -15,7 +15,7 @@
# limitations under the License.
# Description:
# The SupportedOperators class which is a collection of all supported operators and parameter checks.
-from .data_type import BaseType
+from .data_type import BaseType, DataType
class SupportedOperators:
@@ -45,9 +45,9 @@ class SupportedOperators:
| set(("ResizeBilinear",))
)
self.unary_elem_wise_main_ops = set(("LeakyRelu", "Abs"))
- self.binary_elem_wise_main_ops = set(
+ self.binary_elem_wise_min_max_ops = set(("Minimum", "Maximum"))
+ self.binary_elem_wise_add_mul_sub = set(
(
- # binary element-wise
"AddAct",
"MulAct",
"SubAct",
@@ -57,10 +57,9 @@ class SupportedOperators:
"Mul",
"Add",
"Sub",
- "Minimum",
- "Maximum",
)
)
+ self.binary_elem_wise_main_ops = self.binary_elem_wise_min_max_ops | self.binary_elem_wise_add_mul_sub
self.elem_wise_main_ops = self.binary_elem_wise_main_ops | self.unary_elem_wise_main_ops
self.activation_ops = set(
("QuantizedRelu", "QuantizedRelu1", "QuantizedRelu6", "Relu", "Relu6", "ReluN1To1", "Sigmoid", "Tanh")
@@ -124,7 +123,7 @@ class SupportedOperators:
for t in tensors:
if not (t.dtype.type & BaseType.Int):
return False
- if t.element_size() > 2 and op.type != "Requantize":
+ if t.element_size() > 2 and op.type not in set(("Requantize",)) | self.binary_elem_wise_add_mul_sub:
return False
# check size
if any(dim > 65536 for dim in t.shape):
@@ -197,15 +196,13 @@ class SupportedOperators:
# check kernel size
if op.attrs["padding"] == b"SAME" and (op.attrs["filter_width"] > 8 or op.attrs["filter_height"] > 8):
return False
- if op.attrs["padding"] == b"VALID" and (op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256):
+ if (op.attrs["padding"] == b"VALID" and
+ (op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256)):
return False
if op.type in self.max_pooling_ops:
- # check data type
- if not ifm_tensor.dtype == ofm_tensor.dtype:
- return False
- # check kernel size
- if op.attrs["filter_width"] > 256 or op.attrs["filter_height"] > 256: # any padding
+ # check kernel size (any padding)
+ if op.attrs["filter_width"] * op.attrs["filter_height"] > 256 * 256 or op.attrs["filter_height"] > 256:
return False
return True
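
The pooling hunks above replace the old per-dimension cap of 256 with a cap on
the kernel area (256 * 256) plus a cap of 256 on the height alone, so wide but
short kernels now pass. A rough standalone illustration (hypothetical helper,
not part of the patch; the real checks read op.attrs):

    def pooling_kernel_supported(filter_width, filter_height):
        # area capped at 256 * 256, height capped at 256; width alone may exceed 256
        return filter_width * filter_height <= 256 * 256 and filter_height <= 256

    assert pooling_kernel_supported(1024, 8)       # wide, short kernel: now allowed
    assert not pooling_kernel_supported(300, 256)  # area 76800 > 65536: rejected
    assert not pooling_kernel_supported(8, 300)    # height over 256: still rejected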
@@ -220,8 +217,27 @@ class SupportedOperators:
def check_element_wise_restrictions(self, op):
# check data type
ifm_tensor, ifm2_tensor, _, ofm_tensor = op.get_ifm_ifm2_weights_ofm()
- if op.type in ("Minimum", "Maximum") and ifm_tensor.dtype != ofm_tensor.dtype:
+ # input and output datatype must match for these operators
+ if (op.type in self.binary_elem_wise_min_max_ops | self.unary_elem_wise_main_ops and
+ ifm_tensor.dtype != ofm_tensor.dtype):
return False
+ if (op.type in self.binary_elem_wise_add_mul_sub):
+ # both inputs must have same type
+ if (ifm_tensor.dtype != ifm2_tensor.dtype):
+ return False
+ # signed input check
+ if (ifm_tensor.dtype.type & BaseType.Signed):
+ # output must be signed
+ if (ofm_tensor.dtype.type & BaseType.Unsigned):
+ return False
+ # and 8, 16 or 32-bit
+ if (ofm_tensor.element_size() not in (1, 2, 4)):
+ return False
+ # unsigned input check, output must be same type or int32
+ if (ifm_tensor.dtype.type & BaseType.Unsigned and not
+ (ifm_tensor.dtype == ofm_tensor.dtype or
+ ofm_tensor.dtype == DataType.int32)):
+ return False
# check batch size
if len(ifm_tensor.shape) > 2 and ifm_tensor.shape[0] != 1:
return False
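
Condensed, the new data-type rules for the Add/Mul/Sub family are: both inputs
must share one type; a signed input requires a signed 8-, 16- or 32-bit output;
an unsigned input requires an output of the same type or int32. A hedged sketch
(hypothetical helper; dtypes modelled as plain (signed, bits) tuples instead of
Vela's DataType objects):

    INT32 = (True, 32)

    def add_mul_sub_dtypes_supported(ifm, ifm2, ofm):
        if ifm != ifm2:  # both inputs must have the same type
            return False
        signed, _ = ifm
        ofm_signed, ofm_bits = ofm
        if signed:
            # signed input: output must be signed and 8, 16 or 32-bit
            return ofm_signed and ofm_bits in (8, 16, 32)
        # unsigned input: output must be the same type or int32
        return ofm == ifm or ofm == INT32

    int8, uint8 = (True, 8), (False, 8)
    assert add_mul_sub_dtypes_supported(int8, int8, INT32)       # widening a signed add
    assert add_mul_sub_dtypes_supported(uint8, uint8, INT32)     # uint8 -> int32 allowed
    assert not add_mul_sub_dtypes_supported(uint8, uint8, int8)  # uint8 -> int8 rejected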