From 0b8268a0dac80aa22133ca83ed6912d3b565439a Mon Sep 17 00:00:00 2001 From: Louis Verhaard Date: Wed, 5 Aug 2020 16:11:29 +0200 Subject: MLBEDSW-2688: Improved LUT support - Support for more than one 256-byte LUT in SHRAM - No DMA is performed for a LUT that is already located in SHRAM - Added MemArea.Shram, used for LUT, to avoid false address collision asserts during SRAM tensor allocation - Added read access to LUT in memory access calculation Change-Id: If4d1eded5ed029d253f4f5efb2d80495fc3eac99 Signed-off-by: Louis Verhaard --- ethosu/mlw_codec/test/test_mlw_codec.py | 1 + ethosu/vela/architecture_features.py | 11 ++ ethosu/vela/compiler_driver.py | 2 + ethosu/vela/greedy_allocation.py | 4 +- ethosu/vela/high_level_command_stream.py | 13 +- ethosu/vela/high_level_command_stream_generator.py | 2 +- ethosu/vela/insert_dma.py | 8 +- ethosu/vela/lut.py | 120 ++++++++++++++ ethosu/vela/numeric_util.py | 4 + ethosu/vela/pass_packing.py | 9 +- ethosu/vela/register_command_stream_generator.py | 15 +- ethosu/vela/tensor.py | 12 +- ethosu/vela/tensor_allocation.py | 6 + ethosu/vela/test/test_live_range.py | 1 + ethosu/vela/test/test_lut.py | 180 +++++++++++++++++++++ ethosu/vela/test/test_model_reader.py | 1 + ethosu/vela/test/test_tflite_reader.py | 1 + ethosu/vela/test/testutil.py | 70 ++++++++ 18 files changed, 436 insertions(+), 24 deletions(-) create mode 100644 ethosu/vela/lut.py create mode 100644 ethosu/vela/test/test_lut.py create mode 100644 ethosu/vela/test/testutil.py diff --git a/ethosu/mlw_codec/test/test_mlw_codec.py b/ethosu/mlw_codec/test/test_mlw_codec.py index 31a3bc08..d37462d1 100644 --- a/ethosu/mlw_codec/test/test_mlw_codec.py +++ b/ethosu/mlw_codec/test/test_mlw_codec.py @@ -16,6 +16,7 @@ # limitations under the License. # Simple example of the usage of mlw_codec. 
import pytest + from ethosu import mlw_codec diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py index 021597e6..265af426 100644 --- a/ethosu/vela/architecture_features.py +++ b/ethosu/vela/architecture_features.py @@ -316,6 +316,9 @@ Note the difference between ArchitectureFeatures and CompilerOptions self.shram_reserved_unused_banks = 2 if accel_config.shram_banks > 16 else 0 self.shram_total_banks = accel_config.shram_banks - self.shram_reserved_unused_banks self.shram_bank_granules = np.array(accel_config.shram_granules, np.int32) + self.shram_lut_size = 2048 + # SHRAM base address of the activation lookup table + self.shram_lut_address = self.shram_bank_size * self.available_shram_banks(True) # Build a map of acceptable IFM/OFM block configurations up to the maximum # IFM/OFM block size. @@ -326,6 +329,14 @@ Note the difference between ArchitectureFeatures and CompilerOptions # Setup supported operators and restriction checkers class self.supported_operators = SupportedOperators(softmax_support) + # Returns available number of SHRAM banks depending on activation lookup table + # being used or not + def available_shram_banks(self, uses_activation_lut): + banks = self.shram_total_banks + if uses_activation_lut and self.shram_reserved_unused_banks == 0: + banks -= 2 + return banks + # Calculate block configuration for ALL known IFM operations and # accumulator sizes. Consumers will need to select their preferred # operation and bit-width at read-time. diff --git a/ethosu/vela/compiler_driver.py b/ethosu/vela/compiler_driver.py index f407fdc4..5e9e38fb 100644 --- a/ethosu/vela/compiler_driver.py +++ b/ethosu/vela/compiler_driver.py @@ -22,6 +22,7 @@ from . import graph_optimiser from . import high_level_command_stream_generator from . import insert_dma from . import live_range +from . import lut from . import mark_tensors from . import npu_performance from . 
import npu_serialisation @@ -198,6 +199,7 @@ def compiler_driver(nng, arch, options, scheduler_options): high_level_command_stream_generator.generate_high_level_command_stream( nng, sg, arch, options.verbose_high_level_command_stream ) + lut.optimize_high_level_cmd_stream(sg, arch) register_command_stream_generator.generate_register_command_stream( nng, sg, arch, options.verbose_register_command_stream ) diff --git a/ethosu/vela/greedy_allocation.py b/ethosu/vela/greedy_allocation.py index e0176875..1cbfce3f 100644 --- a/ethosu/vela/greedy_allocation.py +++ b/ethosu/vela/greedy_allocation.py @@ -77,9 +77,7 @@ class GreedyAllocator: for m in lrs: if n != m and n.overlaps_ranges(m): overlap, tens_n, tens_m = n.overlaps_address(m) - if overlap and not ( - tens_n.equivalence_id == tens_m.equivalence_id and tens_n.address == tens_m.address - ): + if overlap and not (tens_n.equivalent(tens_m) and tens_n.address == tens_m.address): print("Solution failed, overlapping buffer!") print(tens_n.address, tens_n.address + n.size, n.name) print(tens_m.address, tens_m.address + m.size, m.name) diff --git a/ethosu/vela/high_level_command_stream.py b/ethosu/vela/high_level_command_stream.py index c6698297..95af1ccb 100644 --- a/ethosu/vela/high_level_command_stream.py +++ b/ethosu/vela/high_level_command_stream.py @@ -23,6 +23,9 @@ from .numeric_util import round_up_divide from .operation import NpuBlockType from .range_set import AccessDirection from .range_set import MemoryAccessSet +from .range_set import MemoryRangeSet +from .tensor import MemArea +from .tensor import TensorPurpose class Box: @@ -233,6 +236,13 @@ class NpuStripe(Command): ), AccessDirection.Read, ) + # Add read access to SHRAM by any LUT-s + for tens in self.ps.intermediates: + if tens.purpose == TensorPurpose.LUT and tens.mem_area == MemArea.Shram: + res.add( + MemoryRangeSet(tens.mem_area, tens.address, tens.address + tens.storage_size()), + AccessDirection.Read, + ) return res def is_npu_pass_command(self): 
@@ -359,8 +369,9 @@ class NpuStripe(Command): class DMA(Command): - def __init__(self, in_tensor, out_tensor, box): + def __init__(self, ps, in_tensor, out_tensor, box): self.cmdtype = CommandType.DMA + self.ps = ps self.in_tensor = in_tensor self.out_tensor = out_tensor self.box = box diff --git a/ethosu/vela/high_level_command_stream_generator.py b/ethosu/vela/high_level_command_stream_generator.py index d34fb75b..d5a6341b 100644 --- a/ethosu/vela/high_level_command_stream_generator.py +++ b/ethosu/vela/high_level_command_stream_generator.py @@ -32,7 +32,7 @@ def dma_if_necessary(ps, box, tensor): if tensor.needs_dma(): dma_op = tensor.ops[0] in_tensor = dma_op.inputs[0] - yield DMA(in_tensor, tensor, box) + yield DMA(ps, in_tensor, tensor, box) def match_tensor(source, derived): diff --git a/ethosu/vela/insert_dma.py b/ethosu/vela/insert_dma.py index 6c5c8031..6cd2202c 100644 --- a/ethosu/vela/insert_dma.py +++ b/ethosu/vela/insert_dma.py @@ -61,13 +61,7 @@ def insert_dma_cmd(op, arch): dma_cmd.attrs["destination"] = new_tens.mem_area dma_cmd.run_on_npu = True if tens.purpose == TensorPurpose.LUT: - # TODO: Add support more than one LUT at a time - # Reserve last 2 blocks for LUT - if arch.shram_reserved_unused_banks == 0: - arch.shram_reserved_unused_banks = 2 - arch.shram_total_banks -= arch.shram_reserved_unused_banks - # Place the LUT in the last 2 blocks of SHRAM - new_tens.address = arch.shram_bank_size * arch.shram_total_banks + new_tens.mem_area = MemArea.Shram op.inputs[idx] = new_tens return op diff --git a/ethosu/vela/lut.py b/ethosu/vela/lut.py new file mode 100644 index 00000000..39101fac --- /dev/null +++ b/ethosu/vela/lut.py @@ -0,0 +1,120 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# Description: +# Functionality for lookup table support. +import uuid +from functools import lru_cache + +from . import numeric_util +from .high_level_command_stream import CommandType +from .tensor import TensorPurpose + + +@lru_cache(maxsize=None) +def create_equivalence_id(key): + # Generates equivalence_id based on key. + # The DMA optimization of LUT-s assumes that 2 LUT tensors are identical + # if they have the same equivalence_id. + # So for example all created 256-byte tanh LUT tensors should have + # the same equivalence id. + return uuid.uuid4() + + +class LUTState: + # Tracks which LUT-s are located in SHRAM. + def __init__(self): + self.tensors = [] + + def get_equivalent(self, lut_tens): + # Returns existing lut with same equivalence id, None if not found + for t in self.tensors: + if t.equivalent(lut_tens): + return t + return None + + def put(self, lut_tens): + # Returns new LUT state containing given tensor + all tensors in this state + # that do not overlap with the given tensor + new_state = LUTState() + new_state.tensors.append(lut_tens) + start = lut_tens.address + end = start + lut_tens.storage_size() + for tens in self.tensors: + start2 = tens.address + end2 = start2 + tens.storage_size() + if not numeric_util.overlaps(start, end, start2, end2): + new_state.tensors.append(tens) + return new_state + + def find_best_address(self, start, stop, step): + # Finds the address in the given range that overlaps with the minimum number of + # currently present LUT-s. 
+ # An improvement would be to also take future LUT usage into account + best_addr = start + best_nr_overlaps = stop + for addr in range(start, stop, step): + nr_overlaps = 0 + for tens in self.tensors: + start2 = tens.address + end2 = start2 + tens.storage_size() + if numeric_util.overlaps(addr, addr + step, start2, end2): + nr_overlaps += 1 + if nr_overlaps < best_nr_overlaps: + best_nr_overlaps = nr_overlaps + best_addr = addr + return best_addr + + +def get_lut_index(arch, lut_tensor): + # Returns the index in SHRAM where the given LUT is stored, a value between 0 and 7 + slot = (lut_tensor.address - arch.shram_lut_address) // lut_tensor.storage_size() + assert 0 <= slot < 8 + return slot + + +def optimize_high_level_cmd_stream(sg, arch): + # - Allocates SHRAM address/lut index to LUT tensors + # - Removes unnecessary DMA operations of LUT-s that are already present in SHRAM from sg's command stream + cmd_stream = [] # will contain existing command stream minus unneeded DMA operations + lut_state = LUTState() + slot_size = 256 + lut_start = arch.shram_lut_address + lut_end = lut_start + arch.shram_lut_size + for cmd in sg.high_level_command_stream: + if cmd.cmdtype == CommandType.NpuStripe and cmd.ps.lut_tensor is None and arch.shram_reserved_unused_banks == 0: + # The command overwrites the last 2 banks containing the LUT; next LUT operation will require DMA + # TODO: check the command's SHRAM usage in more detail to determine if the LUT is overwritten or not + lut_state = LUTState() + if cmd.cmdtype != CommandType.DMA or cmd.out_tensor.purpose != TensorPurpose.LUT: + # Non-LUT operation; leave untouched + cmd_stream.append(cmd) + continue + # LUT DMA operation + lut_tens = cmd.out_tensor + existing_tens = lut_state.get_equivalent(lut_tens) + if existing_tens is not None: + # LUT is already in SHRAM, no need to perform DMA + lut_tens.address = existing_tens.address + cmd.ps.primary_op.attrs["lut_index"] = get_lut_index(arch, existing_tens) + continue + # Place 
the LUT in the last 2 blocks of SHRAM + # Alignment is always on the size of the LUT, 256 for 256-byte LUT, 1K for 1K LUT, etc + address = lut_state.find_best_address(lut_start, lut_end, lut_tens.storage_size()) + lut_tens.address = address + cmd.ps.primary_op.attrs["lut_index"] = (address - lut_start) // slot_size + lut_state = lut_state.put(lut_tens) + cmd_stream.append(cmd) + sg.high_level_command_stream = cmd_stream diff --git a/ethosu/vela/numeric_util.py b/ethosu/vela/numeric_util.py index 70209fba..4ebef8e5 100644 --- a/ethosu/vela/numeric_util.py +++ b/ethosu/vela/numeric_util.py @@ -89,3 +89,7 @@ def clamp_sigmoid(x): def full_shape(dim, shape, fill): return ([fill] * (dim - len(shape))) + shape + + +def overlaps(start1, end1, start2, end2): + return start1 < end2 and start2 < end1 diff --git a/ethosu/vela/pass_packing.py b/ethosu/vela/pass_packing.py index 8e108dbf..7b69e35d 100644 --- a/ethosu/vela/pass_packing.py +++ b/ethosu/vela/pass_packing.py @@ -381,12 +381,18 @@ def pack_into_passes(nng, arch, verbose_packing=False): input_set.add(input_tens) ordered_input_list = [] + # Keep LUT-s in a separate list and add as inputs at the end + # to avoid that they would accidentally be assigned as ifm or ifm2 + lut_list = [] input_refcounts = collections.defaultdict(int) for op in ops_list: for inp in op.inputs: if inp in input_set: if input_refcounts[inp] == 0: - ordered_input_list.append(inp) + if inp.purpose == TensorPurpose.LUT: + lut_list.append(inp) + else: + ordered_input_list.append(inp) input_refcounts[inp] += 1 name = ops_list[0].name @@ -416,6 +422,7 @@ def pack_into_passes(nng, arch, verbose_packing=False): ps.weight_tensor = ps.get_primary_op_ifm_weights()[1] ps.scale_tensor = ps.get_primary_op_ifm_weights_biases_ofm()[2] ps.lut_tensor = ps.get_primary_op_lut() + ps.inputs.extend(lut_list) for op in ps.ops: op.scheduled_pass = ps diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py index 
09348811..4a9b0719 100644 --- a/ethosu/vela/register_command_stream_generator.py +++ b/ethosu/vela/register_command_stream_generator.py @@ -277,10 +277,10 @@ def has_prev_op_dependency(prev_cmd, cmd): if prev_cmd is None: return False if (prev_cmd.cmdtype == cmd.cmdtype == CommandType.NpuStripe) and (prev_cmd.ps != cmd.ps): - if prev_cmd.ofm_tensor.equivalence_id == cmd.ifm_tensor.equivalence_id: + if prev_cmd.ofm_tensor.equivalent(cmd.ifm_tensor): return True elif cmd.ifm2_tensor is not None: - return prev_cmd.ofm_tensor.equivalence_id == cmd.ifm2_tensor.equivalence_id + return prev_cmd.ofm_tensor.equivalent(cmd.ifm2_tensor) return False @@ -560,12 +560,13 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): else: emit.cmd1_with_offset(cmd1.NPU_SET_OFM_SCALE, 1, 0) - # For elementwise set the required SHRAM to be equal to the total size of SHRAM - shram_required = arch.shram_total_banks + # For elementwise set the required SHRAM to be equal to the total size of available SHRAM + uses_lut = primary_op.activation_lut is not None + shram_required = arch.available_shram_banks(uses_lut) emit.cmd0_with_param(cmd0.NPU_SET_IFM_IB_END, shram_required) # Acc buffers not needed so set AB_START to size of SHRAM - emit.cmd0_with_param(cmd0.NPU_SET_AB_START, arch.shram_total_banks) + emit.cmd0_with_param(cmd0.NPU_SET_AB_START, shram_required) # Is not a unary operator if cmd.ifm2_tensor is not None: @@ -852,8 +853,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False): faf_min = quantise_float32(clamp_sigmoid(ifm_min), ofm_quant.scale_f32, ofm_quant.zero_point) faf_max = quantise_float32(clamp_sigmoid(ifm_max), ofm_quant.scale_f32, ofm_quant.zero_point) elif faf == "LUT": - lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", 0) - assert lut_index <= activation.LUT_END.value, "LUT index out of range." 
+ lut_index = int(activation.LUT_START.value) + primary_op.attrs.get("lut_index", -1) + assert activation.LUT_START.value <= lut_index <= activation.LUT_END.value, "LUT index out of range." emit.cmd0_with_param(cmd0.NPU_SET_ACTIVATION, lut_index) faf_min = ofm_quant_qmin faf_max = ofm_quant_qmax diff --git a/ethosu/vela/tensor.py b/ethosu/vela/tensor.py index ecca0e0e..312e8f35 100644 --- a/ethosu/vela/tensor.py +++ b/ethosu/vela/tensor.py @@ -54,16 +54,17 @@ class MemArea(enum.IntFlag): Dram = 2 OnChipFlash = 3 OffChipFlash = 4 - Size = OffChipFlash + 1 + Shram = 5 # for LUT + Size = Shram + 1 def display_name(self): - return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "Size")[self.value] + return ("Unknown", "SRAM", "DRAM", "On-chip Flash", "Off-chip Flash", "SHRAM", "Size")[self.value] def identifier_name(self): - return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "size")[self.value] + return ("unknown", "sram", "dram", "on_chip_flash", "off_chip_flash", "shram", "size")[self.value] def all(): - return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash) + return (MemArea.Sram, MemArea.Dram, MemArea.OnChipFlash, MemArea.OffChipFlash, MemArea.Shram) def __str__(self): return self.name @@ -728,6 +729,9 @@ class Tensor: return True return False + def equivalent(self, tens): + return self.equivalence_id == tens.equivalence_id + def set_all_shapes(self, shape): self.shape = shape self.storage_shape = shape diff --git a/ethosu/vela/tensor_allocation.py b/ethosu/vela/tensor_allocation.py index f29296d1..bb91145e 100644 --- a/ethosu/vela/tensor_allocation.py +++ b/ethosu/vela/tensor_allocation.py @@ -26,6 +26,7 @@ from .greedy_allocation import allocate_live_ranges as greedy_allocate_live_rang from .nn_graph import TensorAllocator from .tensor import MemArea from .tensor import MemType +from .tensor import TensorPurpose def linear_allocate_live_ranges(live_ranges, alloc_granularity=16): @@ -44,6 +45,11 @@ def 
linear_allocate_live_ranges(live_ranges, alloc_granularity=16): if allocated_tens.weight_compression_config == tens.weight_compression_config: address = allocated_tens.address break + if tens.purpose == TensorPurpose.LUT: + for allocated_tens in allocated_tensors: + if allocated_tens.equivalent(tens): + address = allocated_tens.address + break lr.set_address(address) allocated_tensors += lr.tensors if address == total_sz: diff --git a/ethosu/vela/test/test_live_range.py b/ethosu/vela/test/test_live_range.py index 395d0f3d..d087dd99 100644 --- a/ethosu/vela/test/test_live_range.py +++ b/ethosu/vela/test/test_live_range.py @@ -18,6 +18,7 @@ from unittest.mock import MagicMock import pytest + from ethosu.vela.live_range import LiveRange diff --git a/ethosu/vela/test/test_lut.py b/ethosu/vela/test/test_lut.py new file mode 100644 index 00000000..3b7f57be --- /dev/null +++ b/ethosu/vela/test/test_lut.py @@ -0,0 +1,180 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Description: +# Unit tests for LUT support +import numpy as np + +from ethosu.vela import insert_dma +from ethosu.vela import lut +from ethosu.vela import mark_tensors +from ethosu.vela import pass_packing +from ethosu.vela.data_type import DataType +from ethosu.vela.high_level_command_stream import DMA +from ethosu.vela.nn_graph import Graph +from ethosu.vela.rewrite_graph import verify_graph_health +from ethosu.vela.tensor import create_const_tensor +from ethosu.vela.tensor import TensorPurpose +from ethosu.vela.test import testutil + + +def set_256_lut(op, key): + values = list(range(256)) + lut_tensor = create_const_tensor( + op.name + "_lut", [1, 1, 1, 256], DataType.int8, values, np.uint8, TensorPurpose.LUT + ) + lut_tensor.equivalence_id = lut.create_equivalence_id(key) + op.set_activation_lut(lut_tensor) + + +def set_1K_lut(op, key): + values = list(range(256)) + lut_tensor = create_const_tensor( + op.name + "_lut", [1, 1, 1, 256], DataType.int32, values, np.uint32, TensorPurpose.LUT + ) + lut_tensor.equivalence_id = lut.create_equivalence_id(key) + op.set_activation_lut(lut_tensor) + + +def set_2K_lut(op, key): + values = list(range(512)) + lut_tensor = create_const_tensor( + op.name + "_lut", [1, 1, 1, 512], DataType.int32, values, np.uint32, TensorPurpose.LUT + ) + lut_tensor.equivalence_id = lut.create_equivalence_id(key) + op.set_activation_lut(lut_tensor) + + +def process(arch, op_list): + # Returns subgraph with given operations + nng = Graph() + sg = testutil.create_subgraph(op_list) + nng.subgraphs.append(sg) + assert verify_graph_health(nng) + nng = mark_tensors.mark_tensor_purpose(nng, arch, False) + assert verify_graph_health(nng) + nng = insert_dma.insert_dma_commands(nng, arch, False) + assert verify_graph_health(nng) + pass_packing.pack_into_passes(nng, arch, False) + assert verify_graph_health(nng) + # Create a DMA instruction for every op + cmd_list = [] + for ps in sg.passes: + for intermediate in ps.intermediates: + if 
intermediate.needs_dma(): + cmd_list.append(DMA(ps, intermediate.get_dma_src_tensor(), intermediate, None)) + sg.high_level_command_stream = cmd_list + return sg + + +def test_optimize_high_level_cmd_stream_2K(): + # Tests lut.optimize_high_level_cmd_stream, blending 256 byte and 2K luts + arch = testutil.create_arch() + shape = [1, 1, 1, 1] + # u8 LUT op, should lead to DMA + op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape) + set_256_lut(op0, "lut0") + # u8 LUT op, should lead to DMA + op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape) + set_256_lut(op1, "lut1") + # u8 LUT op with different LUT, should lead to DMA + op2 = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape) + set_256_lut(op2, "lut2") + # u8 LUT op with same LUT as in op1, should not lead to DMA + op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape) + set_256_lut(op3, "lut1") + # u8 LUT op with same LUT as in op2, should not lead to DMA + op4 = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape) + set_256_lut(op4, "lut2") + # 2K LUT op, should lead to DMA, and will overwrite all previous LUTs in SHRAM + op5_2K = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape) + set_2K_lut(op5_2K, "lut5") + # Another 2K LUT op, should lead to DMA, and will overwrite the previous LUT in SHRAM + op6_2K = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape) + set_2K_lut(op6_2K, "lut6") + # u8 LUT op with same LUT as in op1, should lead to DMA + op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape) + set_256_lut(op7, "lut1") + + op_list = [op0, op1, op2, op3, op4, op5_2K, op6_2K, op7] + sg = process(arch, op_list) + orig_cmd_list = sg.high_level_command_stream + sg.high_level_command_stream = orig_cmd_list + lut.optimize_high_level_cmd_stream(sg, arch) + cmd_list = sg.high_level_command_stream + # Check that only the needed DMA commands are left + expected_dma_ops = [op0, op1, op2, 
op5_2K, op6_2K, op7] + for (cmd, op) in zip(cmd_list, expected_dma_ops): + assert cmd.in_tensor == op.activation_lut + # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored on different addresses + assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address + assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address + assert orig_cmd_list[1].out_tensor.address != orig_cmd_list[2].out_tensor.address + # Check that lut1 in op1 and op3 have same address + assert orig_cmd_list[1].out_tensor.address == orig_cmd_list[3].out_tensor.address + # Check that lut2 in op2 and op4 have same address + assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[4].out_tensor.address + # Check that lut-s for 16 bit (op5 and op6) are stored on same address + assert orig_cmd_list[5].out_tensor.address == orig_cmd_list[6].out_tensor.address + + +def test_optimize_high_level_cmd_stream_1K(): + # Tests lut.optimize_high_level_cmd_stream, blending 256 and 1K luts + arch = testutil.create_arch() + shape = [1, 1, 1, 1] + # u8 LUT op, should lead to DMA + op0 = testutil.create_elemwise_op("AddAct", "op0", shape, shape, shape) + set_256_lut(op0, "lut0") + # u8 LUT op, should lead to DMA + op1 = testutil.create_elemwise_op("AddAct", "op1", shape, shape, shape) + set_256_lut(op1, "lut1") + # 1K LUT op with different LUT, should lead to DMA + op2_1K = testutil.create_elemwise_op("AddAct", "op2", shape, shape, shape) + set_1K_lut(op2_1K, "lut2") + # u8 LUT op with same LUT as in op1, should not lead to DMA + op3 = testutil.create_elemwise_op("AddAct", "op3", shape, shape, shape) + set_256_lut(op3, "lut1") + # 1K LUT op with same LUT as in op2, should not lead to DMA + op4_1K = testutil.create_elemwise_op("AddAct", "op4", shape, shape, shape) + set_1K_lut(op4_1K, "lut2") + # 1K LUT op, should lead to DMA, and will overwrite lut2 + op5_2K = testutil.create_elemwise_op("AddAct", "op5", shape, shape, shape) + set_1K_lut(op5_2K, "lut5") + # u8 LUT 
op, lut0 should still be present, should not lead to DMA + op6 = testutil.create_elemwise_op("AddAct", "op6", shape, shape, shape) + set_256_lut(op6, "lut0") + # 1K LUT op with same LUT as in op2, should lead to DMA + op7 = testutil.create_elemwise_op("AddAct", "op7", shape, shape, shape) + set_1K_lut(op7, "lut2") + + op_list = [op0, op1, op2_1K, op3, op4_1K, op5_2K, op6, op7] + sg = process(arch, op_list) + orig_cmd_list = sg.high_level_command_stream + sg.high_level_command_stream = orig_cmd_list + lut.optimize_high_level_cmd_stream(sg, arch) + cmd_list = sg.high_level_command_stream + # Check that only the needed DMA commands are left + expected_dma_ops = [op0, op1, op2_1K, op5_2K, op7] + for (cmd, op) in zip(cmd_list, expected_dma_ops): + assert cmd.in_tensor == op.activation_lut + # Check that lut0, lut1 and lut2 in op0, op1, op2 are stored on different addresses + assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[1].out_tensor.address + assert orig_cmd_list[0].out_tensor.address != orig_cmd_list[2].out_tensor.address + assert orig_cmd_list[1].out_tensor.address != orig_cmd_list[2].out_tensor.address + # Check that lut1 in op1 and op3 have same address + assert orig_cmd_list[1].out_tensor.address == orig_cmd_list[3].out_tensor.address + # Check that lut2 in op2 and op4 and op7 have same address + assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[4].out_tensor.address + assert orig_cmd_list[2].out_tensor.address == orig_cmd_list[7].out_tensor.address diff --git a/ethosu/vela/test/test_model_reader.py b/ethosu/vela/test/test_model_reader.py index 23e7e90b..bd7ca377 100644 --- a/ethosu/vela/test/test_model_reader.py +++ b/ethosu/vela/test/test_model_reader.py @@ -16,6 +16,7 @@ # Description: # Unit tests for model_reader. 
import pytest + from ethosu.vela import model_reader from ethosu.vela.errors import InputFileError diff --git a/ethosu/vela/test/test_tflite_reader.py b/ethosu/vela/test/test_tflite_reader.py index 898e3840..1ba07423 100644 --- a/ethosu/vela/test/test_tflite_reader.py +++ b/ethosu/vela/test/test_tflite_reader.py @@ -16,6 +16,7 @@ # Description: # Contains unit tests for tflite_reader import pytest + from ethosu.vela.tflite_reader import TFLiteSubgraph diff --git a/ethosu/vela/test/testutil.py b/ethosu/vela/test/testutil.py new file mode 100644 index 00000000..116afa40 --- /dev/null +++ b/ethosu/vela/test/testutil.py @@ -0,0 +1,70 @@ +# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved. +# +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the License); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an AS IS BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# Description: +# Utilities used in vela unit tests +import numpy as np + +from ethosu.vela import architecture_features +from ethosu.vela.data_type import DataType +from ethosu.vela.nn_graph import Subgraph +from ethosu.vela.operation import NpuBlockType +from ethosu.vela.operation import Operation +from ethosu.vela.tensor import create_const_tensor +from ethosu.vela.tensor import MemArea +from ethosu.vela.tensor import Tensor + + +def create_arch(): + return architecture_features.ArchitectureFeatures( + vela_config=None, + system_config=None, + accelerator_config=architecture_features.Accelerator.Ethos_U55_128.value, + permanent_storage=MemArea.OnChipFlash, + override_block_config=None, + block_config_limit=None, + global_memory_clock_scale=1.0, + max_blockdep=0, + softmax_support=True, + ) + + +def create_elemwise_op(type, name, ifm_shape, ifm2_shape, ofm_shape, datatype=DataType.uint8): + # Creates elementwise operation with constant IFM/IFM2 + if datatype.size_in_bytes() == 1: + np_type = np.uint8 + elif datatype.size_in_bytes() == 2: + np_type = np.int16 + else: + np_type = np.int32 + op = Operation(type, name) + op.add_input_tensor(create_const_tensor(name + "_ifm", ifm_shape, datatype, np.zeros(ifm_shape), np_type)) + op.add_input_tensor(create_const_tensor(name + "_ifm2", ifm2_shape, datatype, np.zeros(ifm2_shape), np_type)) + ofm = Tensor(ofm_shape, datatype, name + "_ofm") + op.set_output_tensor(ofm) + op.attrs["npu_block_type"] = NpuBlockType.ElementWise + return op + + +def create_subgraph(op_list): + # Creates subgraph using the given list of operations + sg = Subgraph() + all_inputs = set(tens for op in op_list for tens in op.inputs) + # Reversing, so that the resulting subgraph has same order as op_list + for op in op_list[::-1]: + for tens in op.outputs: + if tens not in all_inputs and tens not in sg.output_tensors: + sg.output_tensors.append(tens) + return sg -- cgit v1.2.1