From 5e0ae5598ab1d7debd603bdd32c7e8f9cad9d581 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Johan=20Alfv=C3=A9n?=
Date: Wed, 9 Feb 2022 21:20:10 +0100
Subject: MLBEDSW-6148: Reduce SRAM usage for elementwise op

Reduce memory footprint when using optimization strategy Size
for elementwise operations.

Signed-off-by: Johan Alfven
Change-Id: I30380aed587c31adbf7615f74179b4c5da686773
---
 ethosu/vela/scheduler.py | 51 ++++++++++++++++++++++++++++++++++++++++++++++---
 1 file changed, 49 insertions(+), 2 deletions(-)

diff --git a/ethosu/vela/scheduler.py b/ethosu/vela/scheduler.py
index d1607779..8f2426c1 100644
--- a/ethosu/vela/scheduler.py
+++ b/ethosu/vela/scheduler.py
@@ -18,6 +18,7 @@
 # The scheduler creates and searches for an optimal plan for the network, selecting block configurations and
 # subdivisions for the Operators
 import copy
+from collections import namedtuple
 from enum import auto
 from enum import IntEnum
 from typing import Dict
@@ -342,6 +343,45 @@ class Scheduler:
         self.max_schedule = None
         self.scheduler_options = options
 
+    def avoid_nhcwb16_for_ofm(self, tens, ps, arch):
+        # Only run this check for opt strategy Size
+        if self.scheduler_options.optimization_strategy == OptimizationStrategy.Performance:
+            return False
+
+        op = ps.primary_op
+        if not op.type.is_elementwise_op():
+            return False
+
+        depth = op.ofm_shapes[0][-1]
+        if (depth % 16) == 0:
+            return False
+
+        # Check if overwriting the inputs can be allowed
+        OpShapeTens = namedtuple("OpShapeTens", ["op_shape", "tens"])
+        outp = OpShapeTens(op.ofm_shapes[0], op.ofm)
+        inps = []
+        if op.ifm is not None:
+            inps.append(OpShapeTens(op.ifm_shapes[0], op.ifm))
+        if op.ifm2 is not None:
+            inps.append(OpShapeTens(op.ifm_shapes[1], op.ifm2))
+
+        # Find an input tensor that can be overwritten by the output
+        for inp in inps:
+            if (
+                # check op input and output shapes allow overlapping
+                inp.op_shape == outp.op_shape
+                # check input tensor is valid
+                and inp.tens is not None
+                and inp.tens.shape != []
+                # check input and output tensors are compatible
+                and inp.tens.format == outp.tens.format
+                and inp.tens.dtype == outp.tens.dtype
+            ):
+                if inp.tens.format == TensorFormat.NHWC:
+                    return True
+
+        return False
+
     def create_scheduler_representation(self, arch: ArchitectureFeatures):
         """Creates a Scheduler Graph representation"""
         # Temporary dict for creating connections between the Operations
@@ -354,8 +394,15 @@ class Scheduler:
             for output in ps.outputs:
                 if output in self.sg.output_tensors or output.purpose != TensorPurpose.FeatureMap:
                     continue
-                if not output.needs_linear_format:
-                    output.set_format(TensorFormat.NHCWB16, arch)
+
+                if output.needs_linear_format:
+                    continue
+
+                if self.avoid_nhcwb16_for_ofm(output, ps, arch):
+                    output.needs_linear_format = True
+                    continue
+
+                output.set_format(TensorFormat.NHCWB16, arch)
 
         # Create SchedulerOperations
         op = SchedulerOperation(ps, arch, self.nng)
-- 
cgit v1.2.1
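
For context, a minimal standalone sketch (separate from the patch; the shape and helper names below are illustrative assumptions) of why the depth % 16 check matters: NHCWB16 stores the channel axis in bricks of 16, so an OFM whose depth is not a multiple of 16 is effectively padded up to the next brick, while NHWC keeps the exact element count and, per the new check, may let the output overwrite a shape- and dtype-compatible input buffer.

    # Rough per-tensor element counts, assuming NHCWB16 pads the channel
    # (depth) axis up to the next multiple of 16. Illustrative only.
    def nhwc_elements(n, h, w, c):
        return n * h * w * c

    def nhcwb16_elements(n, h, w, c):
        c_padded = ((c + 15) // 16) * 16  # round depth up to a full brick of 16
        return n * h * w * c_padded

    if __name__ == "__main__":
        shape = (1, 56, 56, 3)  # hypothetical elementwise OFM, depth 3
        print(nhwc_elements(*shape))     # 9408 elements
        print(nhcwb16_elements(*shape))  # 50176 elements, over 5x the linear footprint

Under the Size strategy the patch therefore forces such OFMs to stay linear (NHWC) instead of converting them to NHCWB16.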