about | summary | refs | log | tree | commit | diff
path: root/ethosu/vela/tosa_graph_optimiser.py
diff options
context:
space:
mode:
Diffstat (limited to 'ethosu/vela/tosa_graph_optimiser.py')
-rw-r--r-- ethosu/vela/tosa_graph_optimiser.py | 93
1 files changed, 92 insertions, 1 deletions
diff --git a/ethosu/vela/tosa_graph_optimiser.py b/ethosu/vela/tosa_graph_optimiser.py
index 2d1245b0..49fc997d 100644
--- a/ethosu/vela/tosa_graph_optimiser.py
+++ b/ethosu/vela/tosa_graph_optimiser.py
@@ -35,6 +35,7 @@ from .operation_util import create_add_nop
from .operation_util import create_avgpool_nop
from .shape4d import Shape4D
from .tensor import create_const_tensor
+from .tensor import create_equivalence_id
def replace_rescale_with_avg_pool(rescale_op):
@@ -417,6 +418,96 @@ def rewrite_rescale(op, arch, nng):
return op
+# TODO modified copy of TFLite, solution for TOSA PAD will change so reuse has not been considered
+def convert_pad(op, arch, nng):
+ """
+ Rewrites PAD operator to an add that copies the IFM to the OFM
+ + up to 4 add operators that fill the OFM with zeros at the borders.
+
+ Note that the border tensors are filled with ifm.quantization.zero_point,
+ which is only the literal value 0 when that zero point is 0.
+
+ Parameters:
+ op: the operation being visited; returned untouched unless op.type is Op.Pad.
+ arch, nng: not used by this function.
+
+ Returns:
+ op itself for non-PAD operations, otherwise the add operation that copies
+ the IFM into the OFM. The original op is retagged as Op.ConcatTFLite.
+
+ Raises:
+ AssertionError (via assert) if the IFM or OFM is missing, if the last
+ (depth) axis has non-zero padding, or if rank == 4 and the first axis
+ has non-zero padding.
+ """
+
+ if op.type != Op.Pad:
+ return op
+
+ # TODO assuming rank <= 4 and N = 1 for rank ==4
+ # This is checked in tosa_supported_operators
+ ifm = op.ifm
+ assert ifm is not None
+ ifm_shape = Shape4D(ifm.shape)
+ ofm = op.ofm
+ assert ofm is not None
+ # Clear the OFM's ops list before the replacement adds are created.
+ # NOTE(review): presumably create_add_for_concat registers each new add on the OFM - confirm.
+ ofm.ops = []
+ ofm_shape = op.ofm_shapes[0]
+
+ rank = len(ifm.shape)
+ # The second input carries the padding amounts, indexed below as padding[axis][0] and padding[axis][1]
+ padding = op.inputs[1].values
+ # The last axis is treated as depth; any non-zero padding there is rejected
+ pad_depth = padding[-1]
+ if not (pad_depth == 0).all():
+ print("Warning: For PAD, padding in depth not supported yet")
+ # NOTE(review): assert statements are removed under "python -O", leaving only the warning;
+ # consider raising an explicit exception instead.
+ assert False
+
+ # Border sizes default to 0 for axes the IFM does not have
+ top, bottom = 0, 0
+ left, right = 0, 0
+ if rank > 1:
+ # Second-to-last axis is used as width
+ left, right = padding[-2][0], padding[-2][1]
+ if rank > 2:
+ # Third-to-last axis is used as height
+ top, bottom = padding[-3][0], padding[-3][1]
+ if rank == 4 and not (padding[-4] == 0).all():
+ print("Warning: For PAD, padding not supported in first dimension when rank == 4 yet")
+ # NOTE(review): same caveat as above regarding assert under "python -O"
+ assert False
+
+ # Add op that copies IFM to the right place inside the OFM
+ # shp0 is the zero offset; shp_top is the offset moved down by the top border
+ shp0 = Shape4D(0, 0, 0, 0)
+ shp_top = shp0.with_height(top)
+ add_op = create_add_for_concat(op, op.name + "_main", ifm, ifm_shape, shp_top.with_width(left))
+ add_op.activation = op.activation
+
+ # NOTE(review): the fill value comes from the IFM zero point while the const tensors
+ # use the OFM quantization - confirm that both are expected to match for PAD.
+ quant = ofm.quantization
+ pad_value = ifm.quantization.zero_point
+ # Add operations that fill the borders of the OFM
+ # The top and bottom strips span the full OFM width, so they also cover the corners
+ if top > 0:
+ shape = Shape4D(1, top, ofm_shape.width, ofm_shape.depth)
+ # NOTE(review): np.uint8 is hard-coded as the value dtype; confirm this is safe when
+ # pad_value is negative or ofm.dtype is a signed/wider type.
+ zero_tens = create_const_tensor(
+ op.name + "_top",
+ shape.as_list(),
+ ofm.dtype,
+ shape.elements() * [pad_value],
+ np.uint8,
+ quantization=quant, # TODO
+ )
+ # If top/bottom or left/right are equal, the const tensors can be allocated to the same address
+ zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+ create_add_for_concat(op, op.name + "_top", zero_tens, shape, shp0)
+ if bottom > 0:
+ shape = Shape4D(1, bottom, ofm_shape.width, ofm_shape.depth)
+ zero_tens = create_const_tensor(
+ op.name + "_bottom",
+ shape.as_list(),
+ ofm.dtype,
+ shape.elements() * [pad_value],
+ np.uint8,
+ quantization=quant,
+ )
+ zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+ # Placed so that the strip ends at the last OFM row
+ create_add_for_concat(op, op.name + "_bottom", zero_tens, shape, shp0.with_height(ofm_shape.height - bottom))
+ # The left and right strips only span the IFM height and start below the top border
+ if left > 0:
+ shape = Shape4D(1, ifm_shape.height, left, ofm_shape.depth)
+ zero_tens = create_const_tensor(
+ op.name + "_left", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ )
+ zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+ create_add_for_concat(op, op.name + "_left", zero_tens, shape, shp_top)
+ if right > 0:
+ shape = Shape4D(1, ifm_shape.height, right, ofm_shape.depth)
+ zero_tens = create_const_tensor(
+ op.name + "_right", shape.as_list(), ofm.dtype, shape.elements() * [pad_value], np.uint8, quantization=quant
+ )
+ zero_tens.equivalence_id = create_equivalence_id(tuple(zero_tens.values))
+ # Placed so that the strip ends at the last OFM column
+ create_add_for_concat(op, op.name + "_right", zero_tens, shape, shp_top.with_width(ofm_shape.width - right))
+
+ # The PAD op itself is retagged as a concat and the IFM-copy add is handed back to the caller.
+ # NOTE(review): presumably later passes handle Op.ConcatTFLite together with the adds created here - confirm.
+ op.type = Op.ConcatTFLite
+ return add_op
+
+
def fixup_quantization(op, arch, nng):
if op.ifm and op.ifm.quantization.zero_point is None:
op.ifm.quantization.zero_point = 0
@@ -484,7 +575,7 @@ def tosa_optimise_graph(nng, arch):
# Post-processing step 1
for idx, sg in enumerate(nng.subgraphs):
nng.subgraphs[idx] = rewrite_graph.rewrite_graph_pre_order(
- nng, sg, arch, [], [rewrite_activation, add_padding_fields],
+ nng, sg, arch, [], [rewrite_activation, convert_pad, add_padding_fields],
)
# Removal of Slice, need to be done after optimisation has been performed,