# Copyright (C) 2021 Arm Limited or its affiliates. All rights reserved.
#
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the License); you may
# not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Description:
# The TosaSupportedOperators class which is a collection of all supported operators and parameter checks.
from collections import defaultdict

from .data_type import DataType
from .operation import Op
from .supported_operators_util import docstring_format_args
from .supported_operators_util import list_formatter
from .tosa_mapping import optype_to_tosa_op_type


class TosaSupportedOperators:
    """Collection of the operator types accepted for TOSA input, together with
    the parameter checks an individual operation must pass to be supported."""

    # TODO currently sparsely populated
    # Categorised lists of supported operators. Within each chain below the
    # names are aliases of one and the same set object.
    convolution_ops = {Op.Conv2DBias}
    convolution_like_ops = convolution_ops
    mac_main_ops = convolution_like_ops

    type_conversion_ops = {Op.Rescale}

    relu_ops = {Op.Clamp, Op.ReluN}
    activation_ops = relu_ops
    npu_post_ops = activation_ops

    # Union of all categories: the op types is_operator_supported will consider.
    supported_operators = mac_main_ops | type_conversion_ops | npu_post_ops

    # Supported data types
    # TODO will differ compared to TensorFlow Lite, currently set to the same
    supported_op_dtypes = {DataType.uint8, DataType.int8, DataType.int16, DataType.int32}

    def __init__(self):
        # Constraints run for every supported op type, in list order.
        # Note: the order matters
        self.generic_constraints = [TosaSupportedOperators.constraint_tens_dtype]
        # Extra constraints keyed by op type, run after the generic ones.
        # Left empty here. Note: the order matters
        self.specific_constraints = defaultdict(list)

    def is_operator_supported(self, op):
        """Return True if ``op`` passes all checks, otherwise False.

        ``op`` is supported when its type is in ``supported_operators`` and it
        satisfies every generic constraint followed by every constraint
        registered for its type. A message is printed for each rejection,
        except that Placeholder, SubgraphInput and Const ops of an unsupported
        type are rejected without output.
        """
        # Computed before any check because every diagnostic below uses it.
        ext_type = optype_to_tosa_op_type(op.type)

        if op.type not in TosaSupportedOperators.supported_operators:
            # These op types are rejected without printing anything.
            quiet_types = (Op.Placeholder, Op.SubgraphInput, Op.Const)
            if op.type not in quiet_types:
                print(f"Info: {ext_type} '{op.name}' is not a NPU op")
            return False

        checks = self.generic_constraints + self.specific_constraints[op.type]
        for check in checks:
            ok, detail = check(op)
            if ok:
                continue
            # Only the first failing constraint is reported: its docstring is
            # the rule that was broken, ``detail`` names the offending parts.
            print(f"Warning: {ext_type} '{op.name}' is not supported on the NPU")
            print(f" - {check.__doc__}")
            if detail:
                print(f" {detail}")
            return False
        return True

    # TODO this function is the same for TensorFlow Lite, but input might differ
    @classmethod
    @docstring_format_args([list_formatter(supported_op_dtypes)])
    def constraint_tens_dtype(cls, op):
        "Tensors must be of type: {}"
        # NOTE: the docstring above is printed by is_operator_supported when
        # this check fails (presumably after docstring_format_args fills in the
        # dtype list), so it is kept terse and the details live here instead.
        #
        # Checks the truthy tensors from op.get_ifm_ifm2_weights_ofm()
        # (presumably ifm, ifm2, weights, ofm — TODO confirm); when none are
        # present, falls back to the truthy entries of op.inputs.
        # Returns (valid, extra): valid is True when every checked tensor's
        # dtype is in supported_op_dtypes; extra is a ", "-joined description
        # of each tensor whose dtype is not ("" when valid).
        candidates = [t for t in op.get_ifm_ifm2_weights_ofm() if t] or [t for t in op.inputs if t]
        offenders = [
            f"Tensor '{t.name}' has data type: {t.dtype}"
            for t in candidates
            if t.dtype not in cls.supported_op_dtypes
        ]
        return not offenders, ", ".join(offenders)