diff options
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r-- | ethosu/vela/operation.py | 514 |
1 files changed, 347 insertions, 167 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py index 14818870..a2b67dfb 100644 --- a/ethosu/vela/operation.py +++ b/ethosu/vela/operation.py @@ -15,10 +15,11 @@ # limitations under the License. # Description: # Internal representation of a Neural Network Operation. -import enum +from collections import namedtuple +from enum import Enum -class NpuBlockType(enum.Enum): +class NpuBlockType(Enum): Default = 0 ConvolutionMxN = 1 VectorProduct = 2 @@ -28,10 +29,266 @@ class NpuBlockType(enum.Enum): ReduceSum = 6 +# Classifies operators of type Custom +class CustomType(Enum): + ThirdPartyOp = 0 # Third party custom op + NpuOp = 1 # NPU op + ExistingNpuOp = 2 # NPU op that was part of the input network + + +TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"]) + +NO_INDICES = TensorIndices([], [], []) +IFM_INDICES = TensorIndices([0], [], []) +IFM_WEIGHTS_INDICES = TensorIndices([0], [1], []) +IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2]) +IFM_IFM2_INDICES = TensorIndices([0, 1], [], []) +CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3]) +TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3]) +CONCAT_INDICES = TensorIndices([1, 2], [], []) +SPLIT_IFM_INDICES = TensorIndices([1], [], []) +BLOCK_LSTM_INDICES = TensorIndices([3], [4], []) + + +# Static information related to operation codes +class OperatorInfo: + __slots__ = ("id", "block_type", "indices", "is_unary") + _id = 0 + + def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False): + OperatorInfo._id += 1 + self.id = OperatorInfo._id + self.block_type = block_type + self.indices = indices # Indices of the different tensor purposes + self.is_unary = is_unary # Classifies elementwise operators + + +# Internally used operation codes +class Op(Enum): + Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True) + Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + AddN = OperatorInfo() + Any = OperatorInfo() + ArgMax = OperatorInfo() + ArgMin = OperatorInfo() + AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + BatchMatMul = OperatorInfo() + BatchToSpaceND = OperatorInfo() + BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES) + + CLZ = OperatorInfo( + block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True + ) # NPU specific operation + Call = OperatorInfo() + Cast = OperatorInfo() + Ceil = OperatorInfo() + Concat = OperatorInfo(indices=CONCAT_INDICES) + ConcatEmbeddings = OperatorInfo() + ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES) + ConcatTFLite = OperatorInfo() + Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs + Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) + Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES) + Conv2DBackpropInputSwitchedBias = OperatorInfo( + block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES + ) + Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES) + Cos = OperatorInfo() + Custom = OperatorInfo() # Custom 3rd party operator, only used in CPU subgraphs + CustomNpuOp = OperatorInfo() # NPU custom operator, only used in CPU subgraphs + DMA = OperatorInfo() + Delegate = OperatorInfo() + Densify = OperatorInfo() + DepthToSpace = OperatorInfo() + DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES) + Dequantize = OperatorInfo() + Div = OperatorInfo() + Elu = OperatorInfo() + EmbeddingLookup = OperatorInfo() + EmbeddingLookupSparse = OperatorInfo() + Equal = OperatorInfo() + Exp = OperatorInfo() + ExpandDims = OperatorInfo(indices=IFM_INDICES) + FakeQuantWithMinMaxArgs = OperatorInfo() + Fill = OperatorInfo() + Floor = OperatorInfo() + FloorDiv = OperatorInfo() + FloorMod = OperatorInfo() + FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES) + GatherNd = OperatorInfo() + GatherV2 = OperatorInfo() + Greater = OperatorInfo() + GreaterEqual = OperatorInfo() + HardSwish = OperatorInfo() + HashtableLookup = OperatorInfo() + Identity = OperatorInfo() + If = OperatorInfo() + L2Norm = OperatorInfo() + L2Pool2D = OperatorInfo() + LRN = OperatorInfo() + LSHProjection = OperatorInfo() + LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True) + Less = OperatorInfo() + LessEqual = OperatorInfo() + Log = OperatorInfo() + LogSoftmax = OperatorInfo() + LogicalAnd = OperatorInfo() + LogicalNot = OperatorInfo() + LogicalOr = OperatorInfo() + Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + LUT = OperatorInfo() # NPU specific, operator has LUT, only used in fused activation functions + MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + MatrixDiag = OperatorInfo() + MatrixSetDiag = OperatorInfo() + Max = OperatorInfo() + MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Mean = OperatorInfo() + Min = OperatorInfo() + Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + MirrorPad = OperatorInfo() + Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + Neg = OperatorInfo() + NonMaxSuppressionV4 = OperatorInfo() + NonMaxSuppressionV5 = OperatorInfo() + NotEqual = OperatorInfo() + OneHot = OperatorInfo() + Pack = OperatorInfo() + PackReshaped = OperatorInfo(indices=IFM_INDICES) + Pad = OperatorInfo() + PadV2 = OperatorInfo() + Placeholder = OperatorInfo() # Only used in CPU subgraphs + Pow = OperatorInfo() + Prelu = OperatorInfo() + Prod = OperatorInfo() + Quantize = OperatorInfo() + QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES) + QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + QuantizedReshape = OperatorInfo(indices=IFM_INDICES) + Range = OperatorInfo() + Rank = OperatorInfo() + ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES) + Relu = OperatorInfo(indices=IFM_INDICES) + Relu6 = OperatorInfo(indices=IFM_INDICES) + ReluN1To1 = OperatorInfo(indices=IFM_INDICES) + Reshape = OperatorInfo(indices=IFM_INDICES) + ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES) + ResizeNearestNeighbor = OperatorInfo() + ReverseSequence = OperatorInfo() + ReverseV2 = OperatorInfo() + Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + Round = OperatorInfo() + Rsqrt = OperatorInfo() + SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation + SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation + ScatterNd = OperatorInfo() + SegmentSum = OperatorInfo() + Select = OperatorInfo() + SelectV2 = OperatorInfo() + Shape = OperatorInfo() + Sigmoid = OperatorInfo(indices=IFM_INDICES) + SignBit = OperatorInfo() + Sin = OperatorInfo() + SkipGram = OperatorInfo() + Slice = OperatorInfo(indices=IFM_INDICES) + Softmax = OperatorInfo() + SpaceToBatchND = OperatorInfo() + SpaceToDepth = OperatorInfo() + SparseToDense = OperatorInfo() + Split = OperatorInfo(indices=SPLIT_IFM_INDICES) + SplitSliceRead = OperatorInfo(indices=IFM_INDICES) + SplitV = OperatorInfo(indices=IFM_INDICES) + Sqrt = OperatorInfo() + Square = OperatorInfo() + SquaredDifference = OperatorInfo() + Squeeze = OperatorInfo(indices=IFM_INDICES) + StridedSlice = OperatorInfo(indices=IFM_INDICES) + StridedSliceOptions = OperatorInfo() + Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) + SubgraphInput = OperatorInfo() # Only used in CPU subgraphs + Sum = OperatorInfo() + Svdf = OperatorInfo() + Tanh = OperatorInfo(indices=IFM_INDICES) + Tile = OperatorInfo() + TopKV2 = OperatorInfo() + Transpose = OperatorInfo() + UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES) + Unique = OperatorInfo() + Unpack = OperatorInfo() + UnpackReshaped = OperatorInfo(indices=IFM_INDICES) + Where = OperatorInfo() + While = OperatorInfo() + ZerosLike = OperatorInfo() + + @property + def info(self): + return self.value + + @property + def npu_block_type(self): + return self.info.block_type + + def is_conv2d_op(self): + return self.info.block_type == NpuBlockType.ConvolutionMxN + + def is_depthwise_conv2d_op(self): + return self.info.block_type == NpuBlockType.ConvolutionDepthWise + + def is_pool_op(self): + return self.info.block_type == NpuBlockType.Pooling + + def is_maxpool_op(self): + return self in (Op.MaxPool, Op.QuantizedMaxPool) + + def is_avgpool_op(self): + return self in (Op.QuantizedAvgPool, Op.AvgPool) + + def is_elementwise_op(self): + return self.info.block_type == NpuBlockType.ElementWise + + def is_unary_elementwise_op(self): + return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary + + def is_binary_elementwise_op(self): + return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary + + def is_relu_op(self): + return self in (Op.Relu, Op.Relu6, Op.ReluN1To1) + + def is_activation_op(self): + return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT) + + def is_split_op(self): + return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped) + + def is_concat_op(self): + return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped) + + def needs_bias(self): + return bool(self.info.indices.biases) + + @classmethod + def op_set(cls, predicate): + # Returns the set of all operator codes that fulfill the given predicate + return {op_type for op_type in Op if predicate(op_type)} + + def __str__(self): + return self.name + + __repr__ = __str__ + + def __lt__(self, other): + return self.value.id < other.value.id + + def create_avgpool_nop(name): - op = Operation("AvgPool", name) + op = Operation(Op.AvgPool, name) op.attrs["padding"] = b"VALID" - op.attrs["npu_block_type"] = NpuBlockType.Pooling op.attrs["stride_w"] = 1 op.attrs["stride_h"] = 1 op.attrs["filter_width"] = 1 @@ -70,6 +327,9 @@ input and output tensors, as well as an attribute dictionary.""" "flops", "scheduled_pass", "run_on_npu", + "activation", + "memory_function", + "forced_output_quantization", "activation_lut", ) @@ -81,6 +341,13 @@ input and output tensors, as well as an attribute dictionary.""" self.outputs = [] self.flops = 0 self.run_on_npu = True + # Fused activation function. If not none: operator code. + self.activation = None + # Fused memory function, if not None: operator code + self.memory_function = None + # If not none: contains QuantizationParameters to be used as output quantization + # (which overrides the ofm tensor's quantization), used in LUT + self.forced_output_quantization = None self.scheduled_pass = None self.op_index = None # input network operator index self.activation_lut = None @@ -92,173 +359,95 @@ input and output tensors, as well as an attribute dictionary.""" res.inputs = list(self.inputs) res.outputs = list(self.outputs) res.flops = self.flops + res.run_on_npu = self.run_on_npu + res.activation = self.activation + res.memory_function = self.memory_function + res.forced_output_quantization = self.forced_output_quantization res.scheduled_pass = self.scheduled_pass res.op_index = None # not relevant as not part of input network return res def __str__(self): - return "<nng.Operation '%s' type=%s>" % (self.name, self.type) + return "<nng.Operation '{}' type={}>".format(self.name, self.type) __repr__ = __str__ - def get_ifm_ifm2_weight_bias_ofm_indices(self): - ifm_idx = -1 - ifm2_idx = -1 - weight_idx = -1 - bias_idx = -1 - ofm_idx = -1 - npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default) - if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise): - ifm_idx = 0 - weight_idx = 1 - ofm_idx = 0 - - if self.type in ("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct"): - if len(self.inputs) >= 3: - bias_idx = 2 - - elif self.type == "Conv2DBackpropInputSwitchedBias": - bias_idx = 3 - - elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum): - ifm_idx = 0 - ofm_idx = 0 - elif npu_block_type == NpuBlockType.VectorProduct: - ifm_idx = 0 - weight_idx = 1 - ofm_idx = 0 - - if self.type == "FullyConnectedAct": - if len(self.inputs) >= 3: - bias_idx = 2 - - if self.type == "BlockLSTM": - ifm_idx = 3 - weight_idx = 4 - ofm_idx = 6 - - elif npu_block_type == NpuBlockType.ElementWise: - ifm_idx = 0 - ifm2_idx = 1 - ofm_idx = 0 - - # LeakyRelu, Abs and CLZ have a single IFM - if self.type in ("LeakyRelu", "Abs", "CLZ"): - ifm2_idx = -1 - - elif self.type == "Conv2DBackpropInput": - ifm_idx = 2 - weight_idx = 1 - ofm_idx = 0 - - elif self.type in ("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims", "Sigmoid", "Tanh"): - ifm_idx = 0 - ofm_idx = 0 - - elif self.is_split_op(): - ifm_idx = 0 - ofm_idx = 0 - if self.type == "Split": - ifm_idx = 1 - - elif self.is_concat_op(): - ifms, _ = self.get_concat_inputs_axis() - ifm_idx = self.inputs.index(ifms[0]) - if len(ifms) > 1: - ifm2_idx = self.inputs.index(ifms[1]) - ofm_idx = 0 - - return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx - def get_ifm_ifm2_weights_ofm(self): - ifm_tensor = None - ifm2_tensor = None - weight_tensor = None - ofm_tensor = None - - ifm_idx, ifm2_idx, weight_idx, _, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices() - if ifm_idx != -1: - ifm_tensor = self.inputs[ifm_idx] - if ifm2_idx != -1: - ifm2_tensor = self.inputs[ifm2_idx] - if weight_idx != -1: - weight_tensor = self.inputs[weight_idx] - if ofm_idx != -1: - ofm_tensor = self.outputs[ofm_idx] - - return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor + return self.ifm, self.ifm2, self.weights, self.ofm def get_ifm_weights_biases_ofm(self): - ifm_tensor = None - weight_tensor = None - bias_tensor = None - ofm_tensor = None - - ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices() - if ifm_idx != -1: - ifm_tensor = self.inputs[ifm_idx] - if weight_idx != -1: - weight_tensor = self.inputs[weight_idx] - if bias_idx != -1: - bias_tensor = self.inputs[bias_idx] - if ofm_idx != -1: - ofm_tensor = self.outputs[ofm_idx] - - return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor + return self.ifm, self.weights, self.bias, self.ofm def get_ifm_ifm2_weights_biases_ofm(self): - ifm_tensor = None - ifm2_tensor = None - weight_tensor = None - bias_tensor = None - ofm_tensor = None - - ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices() - if ifm_idx != -1: - ifm_tensor = self.inputs[ifm_idx] - if ifm2_idx != -1: - ifm2_tensor = self.inputs[ifm2_idx] - if weight_idx != -1: - weight_tensor = self.inputs[weight_idx] - if bias_idx != -1: - bias_tensor = self.inputs[bias_idx] - if ofm_idx != -1: - ofm_tensor = self.outputs[ofm_idx] - - return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor - - def get_ofm(self): - _, _, _, ofm = self.get_ifm_ifm2_weights_ofm() - return ofm - - def is_concat_op(self): - return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped") + return self.ifm, self.ifm2, self.weights, self.bias, self.ofm + + def get_ifm_ofm(self): + return self.ifm, self.ofm + + @property + def ifm(self): + # Gets the IFM tensor, or None if not applicable + return self.get_input(self.type.info.indices.ifms, 0) + + @property + def ifm2(self): + # Gets the IFM2 tensor, or None if not applicable + return self.get_input(self.type.info.indices.ifms, 1) + + @property + def bias(self): + # Gets the bias tensor, or None if not applicable + return self.get_input(self.type.info.indices.biases, 0) + + @property + def weights(self): + # Gets the weight tensor, or None if not applicable + return self.get_input(self.type.info.indices.weights, 0) + + def get_ifm_tensors(self): + # Gets the IFM tensors, or empty list if not applicable + return self._index_list_to_tensors(self.type.info.indices.ifms) + + def get_weight_tensors(self): + # Gets the weight tensors, or empty list if not applicable + return self._index_list_to_tensors(self.type.info.indices.weights) + + def get_bias_tensors(self): + # Gets the bias tensors, or empty list if not applicable + return self._index_list_to_tensors(self.type.info.indices.biases) + + def _index_list_to_tensors(self, index_list): + return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)] + + def get_input(self, index_list, ix): + if ix >= len(index_list): + return None + if index_list[ix] >= len(self.inputs): + return None + return self.inputs[index_list[ix]] + + @property + def ofm(self): + # Gets the OFM tensor, or None if not applicable + return self.outputs[0] if self.outputs else None def get_concat_inputs_axis(self): - assert self.is_concat_op() + assert self.type.is_concat_op() - if self.type == "ConcatV2": - axis_tensor = self.inputs[-1] - inputs = self.inputs[:-1] - elif self.type == "Concat": - axis_tensor = self.inputs[0] - inputs = self.inputs[1:] - elif self.type == "QuantizedConcat": + if self.type == Op.Concat: axis_tensor = self.inputs[0] inputs = self.inputs[1:] - inputs = inputs[: len(inputs) // 3] # Skip min/max - - if self.type == "ConcatTFLite": + elif self.type == Op.ConcatTFLite: inputs = self.inputs axis = self.attrs["axis"] - elif self.type == "PackReshaped": + elif self.type == Op.PackReshaped: # Requires fixup_pack_input to be called before this point inputs = self.inputs axis = self.attrs["axis"] assert len(self.inputs) == self.attrs["values_count"] else: - assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const" + assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const axis = int(axis_tensor.values) return inputs, axis @@ -267,33 +456,30 @@ input and output tensors, as well as an attribute dictionary.""" _, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1)) return dilation_h, dilation_w - def is_split_op(self): - return self.type in ("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped") - def get_split_inputs_axis(self): - assert self.is_split_op() + assert self.type.is_split_op() offset_start = None offset_end = None axis = None - if self.type == "Split": + if self.type == Op.Split: num_splits = self.attrs.get("num_splits") axis_tens = self.inputs[0] - assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" + assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const axis = int(axis_tens.values) input_tens = self.inputs[1] outputs = self.outputs assert num_splits == len(outputs) - elif self.type == "SplitV": + elif self.type == Op.SplitV: num_splits = self.attrs.get("num_splits") input_tens = self.inputs[0] size_tens = self.inputs[1] - assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const" + assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const sizes = size_tens.values axis_tens = self.inputs[2] - assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const" + assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const axis = int(axis_tens.values) for idx, size in enumerate(sizes): @@ -306,7 +492,7 @@ input and output tensors, as well as an attribute dictionary.""" assert num_splits == len(outputs) assert sum(sizes) == input_tens.shape[axis] - elif self.type == "Slice": + elif self.type == Op.Slice: input_tens, begin_tens, size_tens = self.inputs outputs = self.outputs offset_start = [0] * len(input_tens.shape) @@ -318,7 +504,7 @@ input and output tensors, as well as an attribute dictionary.""" offset_start[idx] = begin_tens.values[idx] offset_end[idx] = size_tens.values[idx] + offset_start[idx] - elif self.type == "StridedSlice": + elif self.type == Op.StridedSlice: input_tens, begin_tens, end_tens, strides_tens = self.inputs outputs = self.outputs out_tens = outputs[0] @@ -336,7 +522,7 @@ input and output tensors, as well as an attribute dictionary.""" assert len(input_tens.shape) == len(out_tens.shape) offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True) offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False) - elif self.type == "UnpackReshaped": + elif self.type == Op.UnpackReshaped: # Requires fixup_unpack_output to be called before this point input_tens = self.inputs[0] outputs = self.outputs @@ -350,7 +536,7 @@ input and output tensors, as well as an attribute dictionary.""" return input_tens, outputs, axis, offset_start, offset_end def set_activation_lut(self, lut_tensor): - self.attrs["fused_activation_function"] = "LUT" + self.activation = Op.LUT self.activation_lut = lut_tensor self.add_input_tensor(lut_tensor) @@ -372,13 +558,7 @@ input and output tensors, as well as an attribute dictionary.""" tens.ops = [self] self.outputs = [tens] - def needs_bias(self): - return self.type in ( - "Conv2DBiasAct", - "DepthwiseConv2dBiasAct", - "Conv2DBackpropInputSwitchedBias", - "FullyConnectedAct", - ) - def get_output_quantization(self): - return self.attrs.get("forced_output_quantization", self.get_ofm().quantization) + if self.forced_output_quantization is not None: + return self.forced_output_quantization + return self.ofm.quantization |