From e6ccd87a2f40877cacdd9721a5116a6853dfe573 Mon Sep 17 00:00:00 2001 From: Tim Hall Date: Mon, 9 Nov 2020 16:46:37 +0000 Subject: MLBEDSW-3019: Add profiling debug database - Added mechanism to track input to output graph transforms for debugging the resultant command stream. - Provides base implementation for MLBEDSW-2661 Signed-off-by: Tim Hall Change-Id: I2dfe8a409fbde7ad0282bfab5acb11ba1c8b82d8 --- OPTIONS.md | 13 +++ ethosu/vela/compiler_driver.py | 13 +++ ethosu/vela/debug_database.py | 121 +++++++++++++++++++++++ ethosu/vela/graph_optimiser.py | 31 ++++++ ethosu/vela/pass_packing.py | 3 +- ethosu/vela/register_command_stream_generator.py | 14 ++- ethosu/vela/softmax.py | 40 ++++++++ ethosu/vela/vela.py | 35 +++++-- setup.py | 2 +- 9 files changed, 259 insertions(+), 13 deletions(-) create mode 100644 ethosu/vela/debug_database.py diff --git a/OPTIONS.md b/OPTIONS.md index 92201518..9aaf67b7 100644 --- a/OPTIONS.md +++ b/OPTIONS.md @@ -204,6 +204,19 @@ a RecursionError exception. vela network.tflite --recursion-limit 50000 ``` +### Enable Debug DB + +The neural network debug database allows tracking of optimisations from the +input network graph to the output command stream. Set this option to enable the +calculation and writing of an XML file that contains the network debug database +tables to the output directory. +**Type: Boolean** +**Default: Disabled** + +```bash +vela network.tflite --enable-debug-db +``` + ### Max Block Dependency Set the maximum value that can be used for the block dependency delay between diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py index 9263305a..e089b708 100644 --- a/ethosu/vela/compiler_driver.py +++ b/ethosu/vela/compiler_driver.py @@ -31,10 +31,13 @@ from . import register_command_stream_generator from . import scheduler from . import tensor_allocation from . import weight_compressor +from .debug_database import DebugDatabase from .errors import VelaError from .nn_graph import PassPlacement from .nn_graph import TensorAllocator +from .operation import Op from .rewrite_graph import verify_graph_health +from .rewrite_graph import visit_graph_post_order from .tensor import MemType from .tensor import Tensor @@ -127,8 +130,18 @@ def next_sram_factor(alloc_results): return ((lower + upper) / 2, True) +def _record_operator(op, arch): + if op.type != Op.Const: + DebugDatabase.add_source(op) + + def compiler_driver(nng, arch, options, scheduler_options): assert verify_graph_health(nng) + + # Pre-optimisation operator tracking + for sg in nng.subgraphs: + visit_graph_post_order(sg.output_tensors, arch, [], [_record_operator]) + nng = graph_optimiser.optimise_graph_a(nng, arch, options.verbose_graph) assert verify_graph_health(nng) diff --git a/ethosu/vela/debug_database.py b/ethosu/vela/debug_database.py new file mode 100644 index 00000000..b5852cdc --- /dev/null +++ b/ethosu/vela/debug_database.py @@ -0,0 +1,121 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +import csv +import io + +import lxml.etree as xml + +from . import numeric_util +from .operation import Operation + + +class DebugDatabase: + NULLREF = -1 + show_warnings = False + + SOURCE_TABLE = "source" + _sourceUID = {} + _sourceHeaders = ["id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"] + _sourceTable = [] + + OPTIMISED_TABLE = "optimised" + _optimisedUID = {} + _optimisedHeaders = ["id", "source_id", "operator", "kernel_w", "kernel_h", "ofm_w", "ofm_h", "ofm_d"] + _optimisedTable = [] + + QUEUE_TABLE = "queue" + _queueHeaders = ["offset", "cmdstream_id", "optimised_id"] + _queueTable = [] + + STREAM_TABLE = "cmdstream" + _streamUID = {} + _streamHeaders = ["id", "file_offset"] + _streamTable = [] + + @classmethod + def add_source(cls, op: Operation): + assert isinstance(op, Operation) + uid = len(cls._sourceUID) + cls._sourceUID[op] = uid + ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1) + cls._sourceTable.append( + [uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]] + ) + + @classmethod + def add_optimised(cls, parent: Operation, op: Operation): + assert isinstance(parent, Operation) and isinstance(op, Operation) + if op not in cls._optimisedUID: + if parent not in cls._sourceUID: + # The the parent wasn't in the source network try to look it + # up in the optimised network and use that op's source parent. + if parent in cls._optimisedUID: + src_uid = cls._optimisedUID[parent][1] + else: + if DebugDatabase.show_warnings: + print("Debug Database: Associated parent '{0}' not in network".format(parent.type)) + src_uid = DebugDatabase.NULLREF + else: + src_uid = cls._sourceUID[parent] + uid = len(cls._optimisedUID) + cls._optimisedUID[op] = (uid, src_uid) + ofm_shape = numeric_util.full_shape(3, op.outputs[0].shape, 1) + cls._optimisedTable.append( + [uid, src_uid, op.type, op.kernel.width, op.kernel.height, ofm_shape[-2], ofm_shape[-3], ofm_shape[-1]] + ) + + @classmethod + def add_stream(cls, key): + if key not in cls._streamUID: + uid = len(cls._streamUID) + cls._streamUID[key] = uid + return uid + + @classmethod + def set_stream_offset(cls, key, file_offset): + assert key in cls._streamUID + uid = cls._streamUID[key] + cls._streamTable.append([uid, file_offset]) + + @classmethod + def add_command(cls, stream_id, offset, op: Operation): + assert stream_id < len(cls._streamUID) + assert op in cls._optimisedUID, "Optimised operator must exist before code generation" + optimised_id = cls._optimisedUID[op][0] + cls._queueTable.append([offset, stream_id, optimised_id]) + + @classmethod + def _write_table(cls, root, name, headers, table): + # Convert table to CSV + out = io.StringIO() + writer = csv.writer(out, quoting=csv.QUOTE_NONNUMERIC) + writer.writerow(headers) + writer.writerows(table) + + # Package table into XML output + table = xml.SubElement(root, "table", {"name": name}) + table.text = xml.CDATA(out.getvalue()) + + @classmethod + def write(cls, file_path, input_file, output_file): + root = xml.Element("debug", {"source": input_file, "optimised": output_file}) + + cls._write_table(root, cls.SOURCE_TABLE, cls._sourceHeaders, cls._sourceTable) + cls._write_table(root, cls.OPTIMISED_TABLE, cls._optimisedHeaders, cls._optimisedTable) + cls._write_table(root, cls.QUEUE_TABLE, cls._queueHeaders, cls._queueTable) + cls._write_table(root, cls.STREAM_TABLE, cls._streamHeaders, cls._streamTable) + + 
xml.ElementTree(root).write(file_path, encoding="utf-8", xml_declaration=True, pretty_print=True) diff --git a/ethosu/vela/graph_optimiser.py b/ethosu/vela/graph_optimiser.py index e31348b5..73046302 100644 --- a/ethosu/vela/graph_optimiser.py +++ b/ethosu/vela/graph_optimiser.py @@ -25,6 +25,7 @@ from . import lut from . import rewrite_graph from . import scaling from .data_type import DataType +from .debug_database import DebugDatabase from .errors import UnsupportedFeatureError from .ethos_u55_regs.ethos_u55_regs import resampling_mode from .numeric_util import clamp_sigmoid @@ -77,6 +78,7 @@ def rewrite_concat(tens, arch, nng): new_op.attrs["concat_end"] = offset new_op.run_on_npu = True tens.ops.append(new_op) + DebugDatabase.add_optimised(concat_op, new_op) assert tens.shape[axis] == offset # If axis corresponds to C-dimension, NHCWB16 can only be used in the output if all the concat_start's are a @@ -128,6 +130,7 @@ def rewrite_split(tens, arch, nng): new_op.attrs["split_end"] = offset_end new_op.run_on_npu = True new_op.set_output_tensor(tens) + DebugDatabase.add_optimised(split_op, new_op) return tens @@ -399,6 +402,7 @@ def fixup_pack_input(op, arch, nng): reshape_op.attrs["new_shape"] = desired_shape reshape_op.inputs = [inp, new_shape_tens] reshape_op.set_output_tensor(reshape_out) + DebugDatabase.add_optimised(op, reshape_op) op.inputs[idx] = reshape_out @@ -492,6 +496,7 @@ def fixup_unpack_output(tens, arch, nng): reshape_op.attrs["new_shape"] = reshape_input_shape reshape_op.inputs = [reshape_in, new_shape_tens] reshape_op.set_output_tensor(out_tens) + DebugDatabase.add_optimised(op, reshape_op) op.outputs[idx] = reshape_in @@ -568,6 +573,7 @@ def convert_depthwise_to_conv(op, arch, nng): op.attrs["depth_multiplier"], ifm_tensor.shape[3], ofm_tensor.shape[3] ) ) + DebugDatabase.add_optimised(op, op) return op @@ -616,6 +622,9 @@ def convert_conv_to_fc(op, arch, nng): reshape_op.set_output_tensor(orig_ofm_tensor) # Replace this ops OFM to point to the 2D tensor op.outputs[0] = fc_ofm_tensor + # Record optimisation in debug database + DebugDatabase.add_optimised(op, reshape_op) + DebugDatabase.add_optimised(op, op) return op @@ -670,6 +679,10 @@ def fixup_act_reorder(op, arch, nng): # Mark the op so that it will be removed as passthrough later on op.type = Op.Identity + + # Record optimisation in debug database + DebugDatabase.add_optimised(op, act_op) + DebugDatabase.add_optimised(op, op) return op @@ -788,6 +801,10 @@ def convert_mul_max_to_abs_or_lrelu(op, arch, nng): op.name = op.name.replace("Maximum", new_op.name) op.outputs[0].name = op.outputs[0].name.replace("Maximum", new_op.name) op.inputs = [shared_in] + + # Record optimisation in debug database + DebugDatabase.add_optimised(op, op) + return op @@ -812,6 +829,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_alpha.add_input_tensor(alpha_tens) fm_alpha = ofm.clone(op.name + "_alpha") mul_alpha.set_output_tensor(fm_alpha) + DebugDatabase.add_optimised(op, mul_alpha) if check_quantized_tens_scaling_equal(ifm, ofm): # No identity multiplication is needed @@ -832,6 +850,7 @@ def convert_lrelu_to_mul_max(op, arch): mul_identity.add_input_tensor(identity_tens) fm_id = ofm.clone(op.name + "_id") mul_identity.set_output_tensor(fm_id) + DebugDatabase.add_optimised(op, mul_alpha) # Convert LeakyRelu to Max, add the results of the multiplication(s) as inputs op.type = Op.Maximum @@ -840,6 +859,8 @@ def convert_lrelu_to_mul_max(op, arch): ifm.consumer_list.remove(op) op.add_input_tensor(fm_alpha) op.add_input_tensor(fm_id) + + 
DebugDatabase.add_optimised(op, op) return op @@ -1012,6 +1033,7 @@ def fuse_activation_function_with_prev(op, arch, nng): prev_op.set_activation_lut(op.activation_lut) # Bypass op prev_op.set_output_tensor(ofm) + DebugDatabase.add_optimised(op, prev_op) return op @@ -1052,6 +1074,11 @@ def supported_operator_check(op, arch, nng): return op +def _record_optimised(op, arch): + if op.type != Op.Const: + DebugDatabase.add_optimised(op, op) + + def optimise_graph_a(nng, arch, verbose_graph=False): if verbose_graph: nng.print_graph() @@ -1093,6 +1120,10 @@ def optimise_graph_a(nng, arch, verbose_graph=False): nng, sg, arch, [remove_passthrough_tensor], [fuse_activation_function_with_prev, add_padding_fields] ) + # Post-optimisation operator debug tracing + for sg in nng.subgraphs: + rewrite_graph.visit_graph_post_order(sg.output_tensors, arch, [], [_record_optimised]) + if verbose_graph: nng.print_graph() return nng diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index 5673c2df..59376a85 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -18,6 +18,7 @@ import collections import enum +from .debug_database import DebugDatabase from .nn_graph import Pass from .nn_graph import PassPlacement from .operation import create_avgpool_nop @@ -430,7 +431,6 @@ def pack_into_passes(nng, arch, verbose_packing=False): # Configure a 1x1 AvgPool and attach the op onto it op = op_list[0] inp = op.inputs[0] - avgpool_op = create_avgpool_nop(op.name + "_avgpool") avgpool_op.add_input_tensor(inp) avgpool_out = inp.clone("_avgpooled") @@ -440,6 +440,7 @@ def pack_into_passes(nng, arch, verbose_packing=False): op.inputs[0] = avgpool_out op_list.insert(0, avgpool_op) + DebugDatabase.add_optimised(op, avgpool_op) return avgpool_op return None diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index e5e4fb14..e3fedfcc 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -32,6 +32,7 @@ from .architecture_features import SharedBufferArea from .architecture_features import SHRAMElements from .data_type import BaseType from .data_type import DataType +from .debug_database import DebugDatabase from .ethos_u55_regs.ethos_u55_regs import acc_format from .ethos_u55_regs.ethos_u55_regs import activation from .ethos_u55_regs.ethos_u55_regs import cmd0 @@ -96,10 +97,13 @@ class IFM2Broadcast(IntEnum): class CommandStreamEmitter: + WORD_SIZE = 4 + def __init__(self): self.cmd_stream = [] self.reg_machine = [RegisterMachine(), RegisterMachine()] self.last_absolute_wait = defaultdict(int) + self.offset = 0 def get_reg_machine(self, cmd): if "DMA" in cmd.name: @@ -110,7 +114,7 @@ class CommandStreamEmitter: def size_in_bytes(self): sz = 0 for cmd in self.cmd_stream: - sz += len(cmd) * 4 + sz += len(cmd) * CommandStreamEmitter.WORD_SIZE return sz def to_list(self): @@ -154,6 +158,7 @@ class CommandStreamEmitter: # This is not a redundant command, actually write it self.cmd_stream.append((command,)) + self.offset += CommandStreamEmitter.WORD_SIZE def cmd1_with_offset(self, cmd, offset, param=0x0): offset = int(offset) & 0xFFFFFFFFF @@ -164,17 +169,20 @@ class CommandStreamEmitter: # This is not a redundant command, actually write it self.cmd_stream.append((command, offset)) + self.offset += CommandStreamEmitter.WORD_SIZE * 2 def cmd_wait(self, cmd, channel, outstanding_count): param = (16 * channel) + outstanding_count command = ((param & 0xFFFF) << 16) | 
cmd.value self.cmd_stream.append((command,)) + self.offset += CommandStreamEmitter.WORD_SIZE def cmd_do_operation(self, cmd, param=0): param = int(param) command = ((param & 0xFFFF) << 16) | cmd.value self.cmd_stream.append((command,)) + self.offset += CommandStreamEmitter.WORD_SIZE self.get_reg_machine(cmd).switch_bank() @@ -378,6 +386,9 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): dep_watermark = Watermark(0, 0) + stream_id = DebugDatabase.add_stream(sg) + DebugDatabase.set_stream_offset(sg, 0) # Default to zero, can only set during file writing + for cmd_index, cmd in enumerate(cmd_stream): dep_watermark, cmd_waits = get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, dep_watermark) @@ -1077,6 +1088,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): prev_cmd = cmd emit_cmd_waits(cmd_waits) + DebugDatabase.add_command(stream_id, emit.offset, primary_op) if npu_block_type == NpuBlockType.ConvolutionMxN: emit.cmd_do_operation(cmd0.NPU_OP_CONV) diff --git a/ethosu/vela/softmax.py b/ethosu/vela/softmax.py index 12c20164..efd91a35 100644 --- a/ethosu/vela/softmax.py +++ b/ethosu/vela/softmax.py @@ -25,6 +25,7 @@ import numpy as np from . import fp_math from . import scaling from .data_type import DataType +from .debug_database import DebugDatabase from .operation import Op from .operation import Operation from .tensor import create_const_tensor @@ -220,6 +221,9 @@ class SoftMax: def get_graph_8bit(self, ifm, ofm): exp_lut = self.generate_exp_table(self.op.attrs.get("beta", 1.0), ifm.quantization.scale_f32) + ifm = create_reshape_tensor(ifm, ifm.get_full_shape()) + DebugDatabase.add_optimised(self.op, ifm.ops[0]) + ofm = create_reshape_tensor(ofm, ofm.get_full_shape(), False) no_scale_quant = ifm.quantization.clone() no_scale_quant.scale_f32 = None no_scale_quant.zero_point = 0 @@ -245,6 +249,7 @@ class SoftMax: ifm_max = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0") ifm_max.quantization = no_scale_quant maxpool_op.set_output_tensor(ifm_max) + DebugDatabase.add_optimised(self.op, maxpool_op) # PASS 1 - Sub+LUT(exp) sub_op = Operation(Op.Sub, self.op.name + "_sub1") @@ -261,6 +266,7 @@ class SoftMax: ifm_exp.quantization.quant_min = -128 ifm_exp.quantization.quant_max = 127 sub_op.set_output_tensor(ifm_exp) + DebugDatabase.add_optimised(self.op, sub_op) # PASS 2 - SHR shr2_op = Operation(Op.SHR, self.op.name + "_shr2") @@ -274,6 +280,7 @@ class SoftMax: rescaled_exp = Tensor(ifm.shape, ifm_exp.dtype, shr2_op.name + "_0") rescaled_exp.quantization = no_scale_quant shr2_op.set_output_tensor(rescaled_exp) + DebugDatabase.add_optimised(self.op, shr2_op) # PASS 3 - Reduce sum reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum3") @@ -290,6 +297,7 @@ class SoftMax: sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0") sum_of_exp.quantization = no_scale_quant reduce_sum_op.set_output_tensor(sum_of_exp) + DebugDatabase.add_optimised(self.op, reduce_sum_op) # PASS 4 - CLZ clz_op = Operation(Op.CLZ, self.op.name + "_clz4") @@ -297,6 +305,7 @@ class SoftMax: headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0") headroom_plus_one.quantization = no_scale_quant clz_op.set_output_tensor(headroom_plus_one) + DebugDatabase.add_optimised(self.op, clz_op) # PASS 5 - Sub sub5_op = Operation(Op.Sub, self.op.name + "_sub5") @@ -314,6 +323,7 @@ class SoftMax: right_shift = Tensor(reduce_sum_shape, DataType.int32, sub5_op.name + "_0") right_shift.quantization = 
no_scale_quant sub5_op.set_output_tensor(right_shift) + DebugDatabase.add_optimised(self.op, sub5_op) # PASS 6 - Sub one = create_const_tensor("one_const", [1, 1, 1, 1], DataType.int32, [1], np.int32, quantization=no_scale_quant) @@ -323,6 +333,7 @@ class SoftMax: headroom = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0") headroom.quantization = no_scale_quant sub6_op.set_output_tensor(headroom) + DebugDatabase.add_optimised(self.op, sub6_op) # PASS 7 - SHL shl7_op = Operation(Op.SHL, self.op.name + "_shl7") @@ -331,6 +342,7 @@ class SoftMax: shifted_sum = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0") shifted_sum.quantization = no_scale_quant shl7_op.set_output_tensor(shifted_sum) + DebugDatabase.add_optimised(self.op, shl7_op) # PASS 8 - Sub sub8_op = Operation(Op.Sub, self.op.name + "_sub8") @@ -343,6 +355,7 @@ class SoftMax: shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0") shifted_sum_minus_one.quantization = no_scale_quant sub8_op.set_output_tensor(shifted_sum_minus_one) + DebugDatabase.add_optimised(self.op, sub8_op) # PASS 9 - SHL shl9_op = Operation(Op.SHL, self.op.name + "_shl9") @@ -351,6 +364,7 @@ class SoftMax: shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0") shifted_sum_minus_one.quantization = no_scale_quant shl9_op.set_output_tensor(shifted_sum_minus_one) + DebugDatabase.add_optimised(self.op, shl9_op) # PASS 10 - Add add10_op = Operation(Op.Add, self.op.name + "_add10") @@ -364,6 +378,7 @@ class SoftMax: half_denominator = Tensor(reduce_sum_shape, DataType.int32, add10_op.name + "_0") half_denominator.quantization = one_scale_quant add10_op.set_output_tensor(half_denominator) + DebugDatabase.add_optimised(self.op, add10_op) # PASS 11 - Multiply mul11_op = Operation(Op.Mul, self.op.name + "_mul11") @@ -382,6 +397,7 @@ class SoftMax: rescaled.quantization = one_scale_quant.clone() rescaled.quantization.scale_f32 = 2.0 mul11_op.set_output_tensor(rescaled) + DebugDatabase.add_optimised(self.op, mul11_op) # PASS 12 - Add add12_op = Operation(Op.Add, self.op.name + "_add12") @@ -394,6 +410,7 @@ class SoftMax: rescale_w_offset = Tensor(reduce_sum_shape, DataType.int32, add12_op.name + "_0") rescale_w_offset.quantization = one_scale_quant add12_op.set_output_tensor(rescale_w_offset) + DebugDatabase.add_optimised(self.op, add12_op) nr_x = rescale_w_offset F2_one = create_const_tensor( @@ -411,6 +428,7 @@ class SoftMax: half_denominator_times_x.quantization = one_scale_quant.clone() half_denominator_times_x.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(half_denominator_times_x) + DebugDatabase.add_optimised(self.op, mul_op) # PASS 14, 19, 24 - SUB sub_op = Operation(Op.Sub, self.op.name + "_sub%d" % (14 + i * 5)) sub_op.add_input_tensor(F2_one) @@ -418,6 +436,7 @@ class SoftMax: one_minus_half_denominator_times_x = Tensor(reduce_sum_shape, DataType.int32, sub_op.name + "_0") one_minus_half_denominator_times_x.quantization = one_scale_quant sub_op.set_output_tensor(one_minus_half_denominator_times_x) + DebugDatabase.add_optimised(self.op, sub_op) # PASS 15, 20, 25 - MUL mul_op = Operation(Op.Mul, self.op.name + "_mul%d" % (15 + i * 5)) mul_op.add_input_tensor(nr_x) @@ -426,6 +445,7 @@ class SoftMax: to_rescale.quantization = one_scale_quant.clone() to_rescale.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(to_rescale) + DebugDatabase.add_optimised(self.op, mul_op) # PASS 16, 21, 26 - MUL shl_op = Operation(Op.Mul, self.op.name + "_mul%d" % (16 + i * 5)) 
shl_op.add_input_tensor(to_rescale) @@ -433,6 +453,7 @@ class SoftMax: to_add = Tensor(reduce_sum_shape, DataType.int32, shl_op.name + "_0") to_add.quantization = no_scale_quant shl_op.set_output_tensor(to_add) + DebugDatabase.add_optimised(self.op, shl_op) # PASS 17, 22, 27 - ADD add_op = Operation(Op.Add, self.op.name + "_add%d" % (17 + i * 5)) add_op.add_input_tensor(nr_x) @@ -440,6 +461,7 @@ class SoftMax: nr_x = Tensor(reduce_sum_shape, DataType.int32, add_op.name + "_0") nr_x.quantization = one_scale_quant add_op.set_output_tensor(nr_x) + DebugDatabase.add_optimised(self.op, add_op) # PASS 28 - Multiply mul28_op = Operation(Op.Mul, self.op.name + "_mul28") @@ -450,6 +472,7 @@ class SoftMax: scale_factor = Tensor(reduce_sum_shape, DataType.int32, mul28_op.name + "_0") scale_factor.quantization = one_scale_quant mul28_op.set_output_tensor(scale_factor) + DebugDatabase.add_optimised(self.op, mul28_op) # PASS 29 - Multiply mul_op = Operation(Op.Mul, self.op.name + "_mul29") @@ -459,6 +482,7 @@ class SoftMax: scaled_exp.quantization = one_scale_quant.clone() scaled_exp.quantization.scale_f32 = 2.0 mul_op.set_output_tensor(scaled_exp) + DebugDatabase.add_optimised(self.op, mul_op) # PASS 30 - SHR shr30_op = Operation(Op.SHR, self.op.name + "_shr30") @@ -466,6 +490,7 @@ class SoftMax: shr30_op.add_input_tensor(scaled_exp) shr30_op.add_input_tensor(right_shift) shr30_op.set_output_tensor(ofm) + DebugDatabase.add_optimised(self.op, shr30_op) return shr30_op @@ -476,6 +501,7 @@ class SoftMax: # PASS 0 - Depthwise Maxpool maxpool_op = self.op.clone("_maxpool0") maxpool_op.type = Op.MaxPool + DebugDatabase.add_optimised(self.op, maxpool_op) maxpool_h = ifm.shape[1] * ifm.shape[2] maxpool_w = ifm.shape[3] maxpool_ifm_shape = [1, maxpool_h, maxpool_w, 1] @@ -490,6 +516,7 @@ class SoftMax: maxpool_ofm = Tensor([1, maxpool_h, 1, 1], ifm.dtype, maxpool_op.name + "_0") maxpool_ofm.quantization = no_scale_quant maxpool_op.set_output_tensor(maxpool_ofm) + DebugDatabase.add_optimised(self.op, maxpool_op) # PASS 1 - Sub sub1_op = Operation(Op.Sub, self.op.name + "_sub1") @@ -498,6 +525,7 @@ class SoftMax: sub1_ofm = Tensor(ifm.shape, DataType.int32, sub1_op.name + "_0") sub1_ofm.quantization = ifm.quantization.clone() sub1_op.set_output_tensor(sub1_ofm) + DebugDatabase.add_optimised(self.op, sub1_op) # PASS 2 - Mul beta = self.op.attrs.get("beta", 1.0) @@ -516,6 +544,7 @@ class SoftMax: mul2_ofm.quantization = ofm.quantization.clone() mul2_ofm.quantization.scale_f32 = mul2_out_range mul2_op.set_output_tensor(mul2_ofm) + DebugDatabase.add_optimised(self.op, mul2_op) # PASS 3 - Add+LUT(exp) add_op = Operation(Op.Add, self.op.name + "_add3") @@ -533,6 +562,7 @@ class SoftMax: exp_ofm = Tensor(mul2_ofm.shape, DataType.int16, add_op.name + "_0") exp_ofm.quantization = mul2_ofm.quantization.clone() add_op.set_output_tensor(exp_ofm) + DebugDatabase.add_optimised(self.op, add_op) # PASS 4 - Reduce sum reduce_sum_op = Operation(Op.ReduceSum, self.op.name + "_reduce_sum4") @@ -549,6 +579,7 @@ class SoftMax: sum_of_exp = Tensor(reduce_sum_shape, DataType.int32, reduce_sum_op.name + "_0") sum_of_exp.quantization = no_scale_quant reduce_sum_op.set_output_tensor(sum_of_exp) + DebugDatabase.add_optimised(self.op, reduce_sum_op) # PASS 5 - CLZ clz_op = Operation(Op.CLZ, self.op.name + "_clz5") @@ -556,6 +587,7 @@ class SoftMax: headroom_plus_one = Tensor(reduce_sum_shape, DataType.int32, clz_op.name + "_0") headroom_plus_one.quantization = no_scale_quant clz_op.set_output_tensor(headroom_plus_one) + 
DebugDatabase.add_optimised(self.op, clz_op) # PASS 6 - Sub sub6_op = Operation(Op.Sub, self.op.name + "_sub6") @@ -568,6 +600,7 @@ class SoftMax: reciprocal_right_shift = Tensor(reduce_sum_shape, DataType.int32, sub6_op.name + "_0") reciprocal_right_shift.quantization = no_scale_quant sub6_op.set_output_tensor(reciprocal_right_shift) + DebugDatabase.add_optimised(self.op, sub6_op) # PASS 7 - SHL shl7_op = Operation(Op.SHL, self.op.name + "_shl7") @@ -580,6 +613,7 @@ class SoftMax: constant_one = Tensor(reduce_sum_shape, DataType.int32, shl7_op.name + "_0") constant_one.quantization = no_scale_quant shl7_op.set_output_tensor(constant_one) + DebugDatabase.add_optimised(self.op, shl7_op) # PASS 8 - Sub sub8_op = Operation(Op.Sub, self.op.name + "_sub8") @@ -588,6 +622,7 @@ class SoftMax: sum_of_exps_minus_one = Tensor(reduce_sum_shape, DataType.int32, sub8_op.name + "_0") sum_of_exps_minus_one.quantization = no_scale_quant sub8_op.set_output_tensor(sum_of_exps_minus_one) + DebugDatabase.add_optimised(self.op, sub8_op) # PASS 9 - SHL shl9_op = Operation(Op.SHL, self.op.name + "_shl9") @@ -596,6 +631,7 @@ class SoftMax: shifted_sum_minus_one = Tensor(reduce_sum_shape, DataType.int32, shl9_op.name + "_0") shifted_sum_minus_one.quantization = no_scale_quant shl9_op.set_output_tensor(shifted_sum_minus_one) + DebugDatabase.add_optimised(self.op, shl9_op) # PASS 10 - SHR shr10_op = Operation(Op.SHR, self.op.name + "_shr10") @@ -608,6 +644,7 @@ class SoftMax: shifted_sum_minus_one_16 = Tensor(reduce_sum_shape, DataType.int32, shr10_op.name + "_0") shifted_sum_minus_one_16.quantization = shifted_sum_minus_one.quantization.clone() shr10_op.set_output_tensor(shifted_sum_minus_one_16) + DebugDatabase.add_optimised(self.op, shr10_op) # PASS 11 - Sub+LUT(one over one plus x) sub11_op = Operation(Op.Sub, self.op.name + "_sub11") @@ -630,6 +667,7 @@ class SoftMax: reciprocal_scale = Tensor(reduce_sum_shape, DataType.int16, sub11_op.name + "_0") reciprocal_scale.quantization = no_scale_quant sub11_op.set_output_tensor(reciprocal_scale) + DebugDatabase.add_optimised(self.op, sub11_op) # PASS 12 - Multiply mul_op = Operation(Op.Mul, self.op.name + "_mul12") @@ -638,11 +676,13 @@ class SoftMax: mul_ofm = Tensor(exp_ofm.shape, DataType.int32, mul_op.name + "_0") mul_ofm.quantization = no_scale_quant mul_op.set_output_tensor(mul_ofm) + DebugDatabase.add_optimised(self.op, mul_op) # PASS 13 - SHR shr13_op = Operation(Op.SHR, self.op.name + "_shr13") shr13_op.add_input_tensor(mul_ofm) shr13_op.add_input_tensor(reciprocal_right_shift) shr13_op.set_output_tensor(ofm) + DebugDatabase.add_optimised(self.op, shr13_op) return shr13_op diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py index 4b43751a..5df20d22 100644 --- a/ethosu/vela/vela.py +++ b/ethosu/vela/vela.py @@ -31,6 +31,7 @@ from . import scheduler from . import stats_writer from . 
import tflite_writer from ._version import __version__ +from .debug_database import DebugDatabase from .errors import InputFileError from .nn_graph import PassPlacement from .nn_graph import TensorAllocator @@ -39,14 +40,18 @@ from .tensor import MemArea from .tensor import Tensor -def process(fname, arch, model_reader_options, compiler_options, scheduler_options): +def process(input_name, enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options): if compiler_options.timing: start = time.time() - nng = model_reader.read_model(fname, model_reader_options) + os.makedirs(compiler_options.output_dir, exist_ok=True) + output_basename = os.path.join(compiler_options.output_dir, os.path.splitext(os.path.basename(input_name))[0]) + DebugDatabase.show_warnings = enable_debug_db + + nng = model_reader.read_model(input_name, model_reader_options) if not nng: - raise InputFileError(fname, "input file could not be read") + raise InputFileError(input_name, "input file could not be read") if compiler_options.verbose_operators: nng.print_operators() @@ -58,16 +63,21 @@ def process(fname, arch, model_reader_options, compiler_options, scheduler_optio compiler_driver.compiler_driver(nng, arch, compiler_options, scheduler_options) - passes_csv_file = "%s/%s_pass-breakdown_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config) + passes_csv_file = "{0}_pass-breakdown_{1}.csv".format(output_basename, arch.system_config) stats_writer.write_pass_metrics_csv(nng, passes_csv_file) - summary_csv_file = "%s/%s_summary_%s.csv" % (compiler_options.output_dir, nng.name, arch.system_config) + summary_csv_file = "{0}_summary_{1}.csv".format(output_basename, arch.system_config) stats_writer.write_summary_metrics_csv(nng, summary_csv_file, arch) stats_writer.print_performance_metrics(nng, show_cpu_operations=compiler_options.show_cpu_operations, arch=arch) - if fname.endswith(".tflite"): - tflite_writer.write_tflite(nng, "%s/%s_vela.tflite" % (compiler_options.output_dir, nng.name)) + output_filename = output_basename + "_vela.tflite" + if input_name.endswith(".tflite"): + tflite_writer.write_tflite(nng, output_filename) + + if enable_debug_db: + debug_filename = output_basename + "_debug.xml" + DebugDatabase.write(debug_filename, input_name, output_filename) if compiler_options.timing: stop = time.time() @@ -123,6 +133,13 @@ def main(args=None): parser.add_argument( "--output-dir", type=str, default="output", help="Output directory to write files to (default: %(default)s)" ) + parser.add_argument( + "--enable-debug-db", + action="store_true", + default=None, + help="Enables the calculation and writing of a network debug database to output directory", + ) + parser.add_argument("--config", type=str, help="Location of vela configuration file") parser.add_argument("--verbose-graph", action="store_true", help="Verbose graph rewriter") @@ -319,9 +336,7 @@ def main(args=None): model_reader_options = model_reader.ModelReaderOptions() - os.makedirs(args.output_dir, exist_ok=True) - - nng = process(args.network, arch, model_reader_options, compiler_options, scheduler_options) + nng = process(args.network, args.enable_debug_db, arch, model_reader_options, compiler_options, scheduler_options) if args.show_subgraph_io_summary: print_subgraph_io_summary(nng) diff --git a/setup.py b/setup.py index 07ab2d13..cc306360 100644 --- a/setup.py +++ b/setup.py @@ -56,7 +56,7 @@ setup( keywords=["ethos-u", "vela compiler", "tflite", "npu"], packages=find_namespace_packages(include=["ethosu.*"]), 
     python_requires="~=3.6",  # We support only 3.6+
-    install_requires=["flatbuffers==1.11.0", "numpy>=1.16.6"],
+    install_requires=["flatbuffers==1.11.0", "numpy>=1.16.6", "lxml>=4.6.1"],
     entry_points={"console_scripts": ["vela = ethosu.vela.vela:main"]},
     ext_modules=[mlw_module],
     setup_requires=["setuptools_scm"],
-- 
cgit v1.2.1
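
For reference, the XML written by `DebugDatabase.write()` stores each table ("source", "optimised", "queue", "cmdstream") as a CSV payload inside a CDATA section, with the column headers in the first row. A minimal sketch of reading the database back with `lxml` and `csv` follows; the `output/network_debug.xml` path is an assumption based on the default `--output-dir` and an input file named `network.tflite`:

```python
import csv
import io

import lxml.etree as xml


def read_debug_db(file_path):
    """Return {table_name: [row dicts]} parsed from a Vela network debug database XML file."""
    root = xml.parse(file_path).getroot()
    tables = {}
    for table in root.findall("table"):
        # Each <table> element holds a CSV payload (CDATA) whose first row is the header
        reader = csv.DictReader(io.StringIO(table.text))
        tables[table.get("name")] = list(reader)
    return tables


if __name__ == "__main__":
    # Assumes the default --output-dir ("output") and an input named network.tflite,
    # compiled with: vela network.tflite --enable-debug-db
    db = read_debug_db("output/network_debug.xml")
    for row in db["optimised"]:
        print(row["id"], row["source_id"], row["operator"])
```

The "optimised" rows carry a `source_id` that refers back to an "id" in the "source" table, and the "queue" rows map command stream offsets to "optimised" ids, so the tables can be joined to trace a command back to its original network operator.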