aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/operation.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-09-30 09:01:52 +0200
committerLouis Verhaard <louis.verhaard@arm.com>2020-10-08 16:29:29 +0200
commitaee5d7537ff81ffda5ba222721b72f914ce50fb8 (patch)
tree495b9dfff2a188c6916f8ca2e390ee88f7da8ccc /ethosu/vela/operation.py
parent36ad73a0fb46d3f844845c97c56d92de2a7a9b3d (diff)
downloadethos-u-vela-aee5d7537ff81ffda5ba222721b72f914ce50fb8.tar.gz
MLBEDSW-3148: Refactor Operation
- op.type is now an enum instead of a string - Removed unused operator codes - Refactored some attributes like npu_block_type, fused_activation_function - Refactored operator index calculation - Refactored a number of operator sets Change-Id: I641f65ee375794b7aec42abc0664251ae37d78e8 Signed-off-by: Louis Verhaard <louis.verhaard@arm.com>
Diffstat (limited to 'ethosu/vela/operation.py')
-rw-r--r--ethosu/vela/operation.py514
1 files changed, 347 insertions, 167 deletions
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 14818870..a2b67dfb 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -15,10 +15,11 @@
# limitations under the License.
# Description:
# Internal representation of a Neural Network Operation.
-import enum
+from collections import namedtuple
+from enum import Enum
-class NpuBlockType(enum.Enum):
+class NpuBlockType(Enum):
Default = 0
ConvolutionMxN = 1
VectorProduct = 2
@@ -28,10 +29,266 @@ class NpuBlockType(enum.Enum):
ReduceSum = 6
+# Classifies operators of type Custom
+class CustomType(Enum):
+ ThirdPartyOp = 0 # Third party custom op
+ NpuOp = 1 # NPU op
+ ExistingNpuOp = 2 # NPU op that was part of the input network
+
+
+TensorIndices = namedtuple("TensorIndices", ["ifms", "weights", "biases"])
+
+NO_INDICES = TensorIndices([], [], [])
+IFM_INDICES = TensorIndices([0], [], [])
+IFM_WEIGHTS_INDICES = TensorIndices([0], [1], [])
+IFM_WEIGHTS_BIAS_INDICES = TensorIndices([0], [1], [2])
+IFM_IFM2_INDICES = TensorIndices([0, 1], [], [])
+CONV2D_BACKPROP_INDICES = TensorIndices([2], [1], [3])
+TRANSPOSE_CONV_INDICES = TensorIndices([0], [1], [3])
+CONCAT_INDICES = TensorIndices([1, 2], [], [])
+SPLIT_IFM_INDICES = TensorIndices([1], [], [])
+BLOCK_LSTM_INDICES = TensorIndices([3], [4], [])
+
+
+# Static information related to operation codes
+class OperatorInfo:
+ __slots__ = ("id", "block_type", "indices", "is_unary")
+ _id = 0
+
+ def __init__(self, block_type=NpuBlockType.Default, indices=NO_INDICES, is_unary=False):
+ OperatorInfo._id += 1
+ self.id = OperatorInfo._id
+ self.block_type = block_type
+ self.indices = indices # Indices of the different tensor purposes
+ self.is_unary = is_unary # Classifies elementwise operators
+
+
+# Internally used operation codes
+class Op(Enum):
+ Abs = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+ Add = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ AddN = OperatorInfo()
+ Any = OperatorInfo()
+ ArgMax = OperatorInfo()
+ ArgMin = OperatorInfo()
+ AvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ BatchMatMul = OperatorInfo()
+ BatchToSpaceND = OperatorInfo()
+ BidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ BidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ BlockLSTM = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=BLOCK_LSTM_INDICES)
+
+ CLZ = OperatorInfo(
+ block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True
+ ) # NPU specific operation
+ Call = OperatorInfo()
+ Cast = OperatorInfo()
+ Ceil = OperatorInfo()
+ Concat = OperatorInfo(indices=CONCAT_INDICES)
+ ConcatEmbeddings = OperatorInfo()
+ ConcatSliceWrite = OperatorInfo(indices=IFM_INDICES)
+ ConcatTFLite = OperatorInfo()
+ Const = OperatorInfo() # Constant tensor, only used in CPU subgraphs
+ Conv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+ Conv2DBackpropInput = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=CONV2D_BACKPROP_INDICES)
+ Conv2DBackpropInputSwitchedBias = OperatorInfo(
+ block_type=NpuBlockType.ConvolutionMxN, indices=TRANSPOSE_CONV_INDICES
+ )
+ Conv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_BIAS_INDICES)
+ Cos = OperatorInfo()
+ Custom = OperatorInfo() # Custom 3rd party operator, only used in CPU subgraphs
+ CustomNpuOp = OperatorInfo() # NPU custom operator, only used in CPU subgraphs
+ DMA = OperatorInfo()
+ Delegate = OperatorInfo()
+ Densify = OperatorInfo()
+ DepthToSpace = OperatorInfo()
+ DepthwiseConv2DBias = OperatorInfo(block_type=NpuBlockType.ConvolutionDepthWise, indices=IFM_WEIGHTS_BIAS_INDICES)
+ Dequantize = OperatorInfo()
+ Div = OperatorInfo()
+ Elu = OperatorInfo()
+ EmbeddingLookup = OperatorInfo()
+ EmbeddingLookupSparse = OperatorInfo()
+ Equal = OperatorInfo()
+ Exp = OperatorInfo()
+ ExpandDims = OperatorInfo(indices=IFM_INDICES)
+ FakeQuantWithMinMaxArgs = OperatorInfo()
+ Fill = OperatorInfo()
+ Floor = OperatorInfo()
+ FloorDiv = OperatorInfo()
+ FloorMod = OperatorInfo()
+ FullyConnected = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_BIAS_INDICES)
+ GatherNd = OperatorInfo()
+ GatherV2 = OperatorInfo()
+ Greater = OperatorInfo()
+ GreaterEqual = OperatorInfo()
+ HardSwish = OperatorInfo()
+ HashtableLookup = OperatorInfo()
+ Identity = OperatorInfo()
+ If = OperatorInfo()
+ L2Norm = OperatorInfo()
+ L2Pool2D = OperatorInfo()
+ LRN = OperatorInfo()
+ LSHProjection = OperatorInfo()
+ LeakyRelu = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_INDICES, is_unary=True)
+ Less = OperatorInfo()
+ LessEqual = OperatorInfo()
+ Log = OperatorInfo()
+ LogSoftmax = OperatorInfo()
+ LogicalAnd = OperatorInfo()
+ LogicalNot = OperatorInfo()
+ LogicalOr = OperatorInfo()
+ Lstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ LUT = OperatorInfo() # NPU specific, operator has LUT, only used in fused activation functions
+ MatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ MatrixDiag = OperatorInfo()
+ MatrixSetDiag = OperatorInfo()
+ Max = OperatorInfo()
+ MaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ Maximum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Mean = OperatorInfo()
+ Min = OperatorInfo()
+ Minimum = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ MirrorPad = OperatorInfo()
+ Mul = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ Neg = OperatorInfo()
+ NonMaxSuppressionV4 = OperatorInfo()
+ NonMaxSuppressionV5 = OperatorInfo()
+ NotEqual = OperatorInfo()
+ OneHot = OperatorInfo()
+ Pack = OperatorInfo()
+ PackReshaped = OperatorInfo(indices=IFM_INDICES)
+ Pad = OperatorInfo()
+ PadV2 = OperatorInfo()
+ Placeholder = OperatorInfo() # Only used in CPU subgraphs
+ Pow = OperatorInfo()
+ Prelu = OperatorInfo()
+ Prod = OperatorInfo()
+ Quantize = OperatorInfo()
+ QuantizedAvgPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ QuantizedConv2D = OperatorInfo(block_type=NpuBlockType.ConvolutionMxN, indices=IFM_WEIGHTS_INDICES)
+ QuantizedMatMul = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ QuantizedMaxPool = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ QuantizedReshape = OperatorInfo(indices=IFM_INDICES)
+ Range = OperatorInfo()
+ Rank = OperatorInfo()
+ ReduceSum = OperatorInfo(block_type=NpuBlockType.ReduceSum, indices=IFM_INDICES)
+ Relu = OperatorInfo(indices=IFM_INDICES)
+ Relu6 = OperatorInfo(indices=IFM_INDICES)
+ ReluN1To1 = OperatorInfo(indices=IFM_INDICES)
+ Reshape = OperatorInfo(indices=IFM_INDICES)
+ ResizeBilinear = OperatorInfo(block_type=NpuBlockType.Pooling, indices=IFM_INDICES)
+ ResizeNearestNeighbor = OperatorInfo()
+ ReverseSequence = OperatorInfo()
+ ReverseV2 = OperatorInfo()
+ Rnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ Round = OperatorInfo()
+ Rsqrt = OperatorInfo()
+ SHL = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation
+ SHR = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES) # NPU specific operation
+ ScatterNd = OperatorInfo()
+ SegmentSum = OperatorInfo()
+ Select = OperatorInfo()
+ SelectV2 = OperatorInfo()
+ Shape = OperatorInfo()
+ Sigmoid = OperatorInfo(indices=IFM_INDICES)
+ SignBit = OperatorInfo()
+ Sin = OperatorInfo()
+ SkipGram = OperatorInfo()
+ Slice = OperatorInfo(indices=IFM_INDICES)
+ Softmax = OperatorInfo()
+ SpaceToBatchND = OperatorInfo()
+ SpaceToDepth = OperatorInfo()
+ SparseToDense = OperatorInfo()
+ Split = OperatorInfo(indices=SPLIT_IFM_INDICES)
+ SplitSliceRead = OperatorInfo(indices=IFM_INDICES)
+ SplitV = OperatorInfo(indices=IFM_INDICES)
+ Sqrt = OperatorInfo()
+ Square = OperatorInfo()
+ SquaredDifference = OperatorInfo()
+ Squeeze = OperatorInfo(indices=IFM_INDICES)
+ StridedSlice = OperatorInfo(indices=IFM_INDICES)
+ StridedSliceOptions = OperatorInfo()
+ Sub = OperatorInfo(block_type=NpuBlockType.ElementWise, indices=IFM_IFM2_INDICES)
+ SubgraphInput = OperatorInfo() # Only used in CPU subgraphs
+ Sum = OperatorInfo()
+ Svdf = OperatorInfo()
+ Tanh = OperatorInfo(indices=IFM_INDICES)
+ Tile = OperatorInfo()
+ TopKV2 = OperatorInfo()
+ Transpose = OperatorInfo()
+ UnidirectionalSequenceLstm = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ UnidirectionalSequenceRnn = OperatorInfo(block_type=NpuBlockType.VectorProduct, indices=IFM_WEIGHTS_INDICES)
+ Unique = OperatorInfo()
+ Unpack = OperatorInfo()
+ UnpackReshaped = OperatorInfo(indices=IFM_INDICES)
+ Where = OperatorInfo()
+ While = OperatorInfo()
+ ZerosLike = OperatorInfo()
+
+ @property
+ def info(self):
+ return self.value
+
+ @property
+ def npu_block_type(self):
+ return self.info.block_type
+
+ def is_conv2d_op(self):
+ return self.info.block_type == NpuBlockType.ConvolutionMxN
+
+ def is_depthwise_conv2d_op(self):
+ return self.info.block_type == NpuBlockType.ConvolutionDepthWise
+
+ def is_pool_op(self):
+ return self.info.block_type == NpuBlockType.Pooling
+
+ def is_maxpool_op(self):
+ return self in (Op.MaxPool, Op.QuantizedMaxPool)
+
+ def is_avgpool_op(self):
+ return self in (Op.QuantizedAvgPool, Op.AvgPool)
+
+ def is_elementwise_op(self):
+ return self.info.block_type == NpuBlockType.ElementWise
+
+ def is_unary_elementwise_op(self):
+ return self.info.block_type == NpuBlockType.ElementWise and self.info.is_unary
+
+ def is_binary_elementwise_op(self):
+ return self.info.block_type == NpuBlockType.ElementWise and not self.info.is_unary
+
+ def is_relu_op(self):
+ return self in (Op.Relu, Op.Relu6, Op.ReluN1To1)
+
+ def is_activation_op(self):
+ return self.is_relu_op() or self in (Op.Tanh, Op.Sigmoid, Op.Softmax, Op.LUT)
+
+ def is_split_op(self):
+ return self in (Op.Split, Op.SplitV, Op.StridedSlice, Op.Slice, Op.UnpackReshaped)
+
+ def is_concat_op(self):
+ return self in (Op.Concat, Op.ConcatTFLite, Op.PackReshaped)
+
+ def needs_bias(self):
+ return bool(self.info.indices.biases)
+
+ @classmethod
+ def op_set(cls, predicate):
+ # Returns the set of all operator codes that fulfill the given predicate
+ return {op_type for op_type in Op if predicate(op_type)}
+
+ def __str__(self):
+ return self.name
+
+ __repr__ = __str__
+
+ def __lt__(self, other):
+ return self.value.id < other.value.id
+
+
def create_avgpool_nop(name):
- op = Operation("AvgPool", name)
+ op = Operation(Op.AvgPool, name)
op.attrs["padding"] = b"VALID"
- op.attrs["npu_block_type"] = NpuBlockType.Pooling
op.attrs["stride_w"] = 1
op.attrs["stride_h"] = 1
op.attrs["filter_width"] = 1
@@ -70,6 +327,9 @@ input and output tensors, as well as an attribute dictionary."""
"flops",
"scheduled_pass",
"run_on_npu",
+ "activation",
+ "memory_function",
+ "forced_output_quantization",
"activation_lut",
)
@@ -81,6 +341,13 @@ input and output tensors, as well as an attribute dictionary."""
self.outputs = []
self.flops = 0
self.run_on_npu = True
+ # Fused activation function. If not none: operator code.
+ self.activation = None
+ # Fused memory function, if not None: operator code
+ self.memory_function = None
+ # If not none: contains QuantizationParameters to be used as output quantization
+ # (which overrides the ofm tensor's quantization), used in LUT
+ self.forced_output_quantization = None
self.scheduled_pass = None
self.op_index = None # input network operator index
self.activation_lut = None
@@ -92,173 +359,95 @@ input and output tensors, as well as an attribute dictionary."""
res.inputs = list(self.inputs)
res.outputs = list(self.outputs)
res.flops = self.flops
+ res.run_on_npu = self.run_on_npu
+ res.activation = self.activation
+ res.memory_function = self.memory_function
+ res.forced_output_quantization = self.forced_output_quantization
res.scheduled_pass = self.scheduled_pass
res.op_index = None # not relevant as not part of input network
return res
def __str__(self):
- return "<nng.Operation '%s' type=%s>" % (self.name, self.type)
+ return "<nng.Operation '{}' type={}>".format(self.name, self.type)
__repr__ = __str__
- def get_ifm_ifm2_weight_bias_ofm_indices(self):
- ifm_idx = -1
- ifm2_idx = -1
- weight_idx = -1
- bias_idx = -1
- ofm_idx = -1
- npu_block_type = self.attrs.get("npu_block_type", NpuBlockType.Default)
- if npu_block_type in (NpuBlockType.ConvolutionMxN, NpuBlockType.ConvolutionDepthWise):
- ifm_idx = 0
- weight_idx = 1
- ofm_idx = 0
-
- if self.type in ("Conv2DBiasAct", "DepthwiseConv2dBiasAct", "TransposeConvAct"):
- if len(self.inputs) >= 3:
- bias_idx = 2
-
- elif self.type == "Conv2DBackpropInputSwitchedBias":
- bias_idx = 3
-
- elif npu_block_type in (NpuBlockType.Pooling, NpuBlockType.ReduceSum):
- ifm_idx = 0
- ofm_idx = 0
- elif npu_block_type == NpuBlockType.VectorProduct:
- ifm_idx = 0
- weight_idx = 1
- ofm_idx = 0
-
- if self.type == "FullyConnectedAct":
- if len(self.inputs) >= 3:
- bias_idx = 2
-
- if self.type == "BlockLSTM":
- ifm_idx = 3
- weight_idx = 4
- ofm_idx = 6
-
- elif npu_block_type == NpuBlockType.ElementWise:
- ifm_idx = 0
- ifm2_idx = 1
- ofm_idx = 0
-
- # LeakyRelu, Abs and CLZ have a single IFM
- if self.type in ("LeakyRelu", "Abs", "CLZ"):
- ifm2_idx = -1
-
- elif self.type == "Conv2DBackpropInput":
- ifm_idx = 2
- weight_idx = 1
- ofm_idx = 0
-
- elif self.type in ("Squeeze", "Reshape", "QuantizedReshape", "ExpandDims", "Sigmoid", "Tanh"):
- ifm_idx = 0
- ofm_idx = 0
-
- elif self.is_split_op():
- ifm_idx = 0
- ofm_idx = 0
- if self.type == "Split":
- ifm_idx = 1
-
- elif self.is_concat_op():
- ifms, _ = self.get_concat_inputs_axis()
- ifm_idx = self.inputs.index(ifms[0])
- if len(ifms) > 1:
- ifm2_idx = self.inputs.index(ifms[1])
- ofm_idx = 0
-
- return ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx
-
def get_ifm_ifm2_weights_ofm(self):
- ifm_tensor = None
- ifm2_tensor = None
- weight_tensor = None
- ofm_tensor = None
-
- ifm_idx, ifm2_idx, weight_idx, _, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
- if ifm_idx != -1:
- ifm_tensor = self.inputs[ifm_idx]
- if ifm2_idx != -1:
- ifm2_tensor = self.inputs[ifm2_idx]
- if weight_idx != -1:
- weight_tensor = self.inputs[weight_idx]
- if ofm_idx != -1:
- ofm_tensor = self.outputs[ofm_idx]
-
- return ifm_tensor, ifm2_tensor, weight_tensor, ofm_tensor
+ return self.ifm, self.ifm2, self.weights, self.ofm
def get_ifm_weights_biases_ofm(self):
- ifm_tensor = None
- weight_tensor = None
- bias_tensor = None
- ofm_tensor = None
-
- ifm_idx, _, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
- if ifm_idx != -1:
- ifm_tensor = self.inputs[ifm_idx]
- if weight_idx != -1:
- weight_tensor = self.inputs[weight_idx]
- if bias_idx != -1:
- bias_tensor = self.inputs[bias_idx]
- if ofm_idx != -1:
- ofm_tensor = self.outputs[ofm_idx]
-
- return ifm_tensor, weight_tensor, bias_tensor, ofm_tensor
+ return self.ifm, self.weights, self.bias, self.ofm
def get_ifm_ifm2_weights_biases_ofm(self):
- ifm_tensor = None
- ifm2_tensor = None
- weight_tensor = None
- bias_tensor = None
- ofm_tensor = None
-
- ifm_idx, ifm2_idx, weight_idx, bias_idx, ofm_idx = self.get_ifm_ifm2_weight_bias_ofm_indices()
- if ifm_idx != -1:
- ifm_tensor = self.inputs[ifm_idx]
- if ifm2_idx != -1:
- ifm2_tensor = self.inputs[ifm2_idx]
- if weight_idx != -1:
- weight_tensor = self.inputs[weight_idx]
- if bias_idx != -1:
- bias_tensor = self.inputs[bias_idx]
- if ofm_idx != -1:
- ofm_tensor = self.outputs[ofm_idx]
-
- return ifm_tensor, ifm2_tensor, weight_tensor, bias_tensor, ofm_tensor
-
- def get_ofm(self):
- _, _, _, ofm = self.get_ifm_ifm2_weights_ofm()
- return ofm
-
- def is_concat_op(self):
- return self.type in ("Concat", "ConcatV2", "QuantizedConcat", "ConcatTFLite", "PackReshaped")
+ return self.ifm, self.ifm2, self.weights, self.bias, self.ofm
+
+ def get_ifm_ofm(self):
+ return self.ifm, self.ofm
+
+ @property
+ def ifm(self):
+ # Gets the IFM tensor, or None if not applicable
+ return self.get_input(self.type.info.indices.ifms, 0)
+
+ @property
+ def ifm2(self):
+ # Gets the IFM2 tensor, or None if not applicable
+ return self.get_input(self.type.info.indices.ifms, 1)
+
+ @property
+ def bias(self):
+ # Gets the bias tensor, or None if not applicable
+ return self.get_input(self.type.info.indices.biases, 0)
+
+ @property
+ def weights(self):
+ # Gets the weight tensor, or None if not applicable
+ return self.get_input(self.type.info.indices.weights, 0)
+
+ def get_ifm_tensors(self):
+ # Gets the IFM tensors, or empty list if not applicable
+ return self._index_list_to_tensors(self.type.info.indices.ifms)
+
+ def get_weight_tensors(self):
+ # Gets the weight tensors, or empty list if not applicable
+ return self._index_list_to_tensors(self.type.info.indices.weights)
+
+ def get_bias_tensors(self):
+ # Gets the bias tensors, or empty list if not applicable
+ return self._index_list_to_tensors(self.type.info.indices.biases)
+
+ def _index_list_to_tensors(self, index_list):
+ return [self.inputs[ix] for ix in index_list if ix < len(self.inputs)]
+
+ def get_input(self, index_list, ix):
+ if ix >= len(index_list):
+ return None
+ if index_list[ix] >= len(self.inputs):
+ return None
+ return self.inputs[index_list[ix]]
+
+ @property
+ def ofm(self):
+ # Gets the OFM tensor, or None if not applicable
+ return self.outputs[0] if self.outputs else None
def get_concat_inputs_axis(self):
- assert self.is_concat_op()
+ assert self.type.is_concat_op()
- if self.type == "ConcatV2":
- axis_tensor = self.inputs[-1]
- inputs = self.inputs[:-1]
- elif self.type == "Concat":
- axis_tensor = self.inputs[0]
- inputs = self.inputs[1:]
- elif self.type == "QuantizedConcat":
+ if self.type == Op.Concat:
axis_tensor = self.inputs[0]
inputs = self.inputs[1:]
- inputs = inputs[: len(inputs) // 3] # Skip min/max
-
- if self.type == "ConcatTFLite":
+ elif self.type == Op.ConcatTFLite:
inputs = self.inputs
axis = self.attrs["axis"]
- elif self.type == "PackReshaped":
+ elif self.type == Op.PackReshaped:
# Requires fixup_pack_input to be called before this point
inputs = self.inputs
axis = self.attrs["axis"]
assert len(self.inputs) == self.attrs["values_count"]
else:
- assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == "Const"
+ assert len(axis_tensor.ops) == 1 and axis_tensor.ops[0].type == Op.Const
axis = int(axis_tensor.values)
return inputs, axis
@@ -267,33 +456,30 @@ input and output tensors, as well as an attribute dictionary."""
_, dilation_h, dilation_w, _ = self.attrs.get("dilation", (1, 1, 1, 1))
return dilation_h, dilation_w
- def is_split_op(self):
- return self.type in ("Split", "SplitV", "StridedSlice", "Slice", "UnpackReshaped")
-
def get_split_inputs_axis(self):
- assert self.is_split_op()
+ assert self.type.is_split_op()
offset_start = None
offset_end = None
axis = None
- if self.type == "Split":
+ if self.type == Op.Split:
num_splits = self.attrs.get("num_splits")
axis_tens = self.inputs[0]
- assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+ assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
axis = int(axis_tens.values)
input_tens = self.inputs[1]
outputs = self.outputs
assert num_splits == len(outputs)
- elif self.type == "SplitV":
+ elif self.type == Op.SplitV:
num_splits = self.attrs.get("num_splits")
input_tens = self.inputs[0]
size_tens = self.inputs[1]
- assert len(size_tens.ops) == 1 and size_tens.ops[0].type == "Const"
+ assert len(size_tens.ops) == 1 and size_tens.ops[0].type == Op.Const
sizes = size_tens.values
axis_tens = self.inputs[2]
- assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == "Const"
+ assert len(axis_tens.ops) == 1 and axis_tens.ops[0].type == Op.Const
axis = int(axis_tens.values)
for idx, size in enumerate(sizes):
@@ -306,7 +492,7 @@ input and output tensors, as well as an attribute dictionary."""
assert num_splits == len(outputs)
assert sum(sizes) == input_tens.shape[axis]
- elif self.type == "Slice":
+ elif self.type == Op.Slice:
input_tens, begin_tens, size_tens = self.inputs
outputs = self.outputs
offset_start = [0] * len(input_tens.shape)
@@ -318,7 +504,7 @@ input and output tensors, as well as an attribute dictionary."""
offset_start[idx] = begin_tens.values[idx]
offset_end[idx] = size_tens.values[idx] + offset_start[idx]
- elif self.type == "StridedSlice":
+ elif self.type == Op.StridedSlice:
input_tens, begin_tens, end_tens, strides_tens = self.inputs
outputs = self.outputs
out_tens = outputs[0]
@@ -336,7 +522,7 @@ input and output tensors, as well as an attribute dictionary."""
assert len(input_tens.shape) == len(out_tens.shape)
offset_start = get_slice_offsets(input_tens.shape, begin_tens, begin_mask, is_begin=True)
offset_end = get_slice_offsets(input_tens.shape, end_tens, end_mask, is_begin=False)
- elif self.type == "UnpackReshaped":
+ elif self.type == Op.UnpackReshaped:
# Requires fixup_unpack_output to be called before this point
input_tens = self.inputs[0]
outputs = self.outputs
@@ -350,7 +536,7 @@ input and output tensors, as well as an attribute dictionary."""
return input_tens, outputs, axis, offset_start, offset_end
def set_activation_lut(self, lut_tensor):
- self.attrs["fused_activation_function"] = "LUT"
+ self.activation = Op.LUT
self.activation_lut = lut_tensor
self.add_input_tensor(lut_tensor)
@@ -372,13 +558,7 @@ input and output tensors, as well as an attribute dictionary."""
tens.ops = [self]
self.outputs = [tens]
- def needs_bias(self):
- return self.type in (
- "Conv2DBiasAct",
- "DepthwiseConv2dBiasAct",
- "Conv2DBackpropInputSwitchedBias",
- "FullyConnectedAct",
- )
-
def get_output_quantization(self):
- return self.attrs.get("forced_output_quantization", self.get_ofm().quantization)
+ if self.forced_output_quantization is not None:
+ return self.forced_output_quantization
+ return self.ofm.quantization