author    Tim Hall <tim.hall@arm.com>  2020-10-20 18:54:20 +0100
committer Tim Hall <tim.hall@arm.com>  2020-10-21 15:23:33 +0100
commit    4ed38bce498e1b9a5ae917316323de444792521a (patch)
tree      f3721d7131eeafa14c33cf0339d579de99a3c66a
parent    9358296a51b9186335304a53bd7ea5dfbe5322d8 (diff)
vela: Refactor operators to use Kernel objects
- Normalise kernel availability by requiring that all operators offer a
  kernel describing how much data they consume from the source, per OFM
  element, regardless of whether a kernel is relevant to the operation.

Signed-off-by: Tim Hall <tim.hall@arm.com>
Change-Id: Idbcff64879fc2eccf292b6208a7d2038eb388017
-rw-r--r--  ethosu/vela/architecture_features.py              | 15
-rw-r--r--  ethosu/vela/npu_performance.py                     |  3
-rw-r--r--  ethosu/vela/operation.py                           | 42
-rw-r--r--  ethosu/vela/register_command_stream_generator.py   | 26
-rw-r--r--  ethosu/vela/shared_buffer_allocation.py            | 24
5 files changed, 52 insertions(+), 58 deletions(-)
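In practice the refactor replaces the free function get_op_kernel(ps) with a kernel property on Operation, so call sites reduce to an attribute access. A minimal sketch of the call-site change, not part of the patch itself (ps is a scheduled pass, as in the code below):

    # Before: a free function with per-block-type special cases,
    # returning None when the pass has no primary op.
    kernel = get_op_kernel(ps)

    # After: every Operation carries its own kernel, so callers only
    # need to guard against a missing primary op.
    kernel = ps.primary_op.kernel if ps.primary_op else None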
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 04c1c62..b77205b 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -25,16 +25,15 @@ from .errors import OptionError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
from .numeric_util import round_up
from .numeric_util import round_up_divide
+from .operation import Kernel
from .operation import NpuBlockType
+from .operation import PointXYZ
from .supported_operators import SupportedOperators
from .tensor import MemArea
from .tensor import MemType
from .tensor import TensorFormat
from .tensor import TensorPurpose
-PointXY = namedtuple("PointXY", "x y")
-PointXYZ = namedtuple("PointXYZ", "x y z")
-
class Block:
def __init__(self, w, h, d):
@@ -79,16 +78,6 @@ class Rect:
return "<Rect: ({0},{1},{2}) ({3},{4},{5})>".format(self.x, self.y, self.z, self.x2, self.y2, self.z2)
-class Kernel:
- def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1):
- assert sx > 0 and sy > 0
- assert dx > 0 and dy > 0
- self.width = w
- self.height = h
- self.stride = PointXY(sx, sy)
- self.dilation = PointXY(dx, dy)
-
-
class SHRAMElements:
IFM8 = 0
IFM16 = 1
diff --git a/ethosu/vela/npu_performance.py b/ethosu/vela/npu_performance.py
index e71e95b..24b4c68 100644
--- a/ethosu/vela/npu_performance.py
+++ b/ethosu/vela/npu_performance.py
@@ -31,7 +31,6 @@ from .nn_graph import PassPlacement
from .nn_graph import SchedulerRewrite
from .operation import NpuBlockType
from .operation import Op
-from .register_command_stream_generator import get_op_kernel
from .tensor import MemArea
from .tensor import shape_num_elements
from .tensor import TensorBlockTraversal
@@ -40,7 +39,7 @@ from .tensor import TensorPurpose
def rolling_buffer_dims_from_passes(arch, ps1, block_config_ps1, ps2, block_config_ps2):
ofm_block = Block(block_config_ps2[-3], block_config_ps2[-4], block_config_ps2[-1])
- kernel = get_op_kernel(ps2)
+ kernel = ps2.primary_op.kernel
if ps2.npu_block_type in set((NpuBlockType.ConvolutionMxN, NpuBlockType.VectorProduct)):
op = ps2.primary_op
diff --git a/ethosu/vela/operation.py b/ethosu/vela/operation.py
index 6e5b482..cc52ff4 100644
--- a/ethosu/vela/operation.py
+++ b/ethosu/vela/operation.py
@@ -18,6 +18,11 @@
from collections import namedtuple
from enum import Enum
+from .numeric_util import full_shape
+
+PointXY = namedtuple("PointXY", "x y")
+PointXYZ = namedtuple("PointXYZ", "x y z")
+
class NpuBlockType(Enum):
Default = 0
@@ -29,6 +34,26 @@ class NpuBlockType(Enum):
ReduceSum = 6
+class Kernel:
+ def __init__(self, w, h, sx=1, sy=1, dx=1, dy=1):
+ assert sx > 0 and sy > 0
+ assert dx > 0 and dy > 0
+ self.width = w
+ self.height = h
+ self.stride = PointXY(sx, sy)
+ self.dilation = PointXY(dx, dy)
+ self.upscale = 1
+
+ def elements_wh(self):
+ return self.width * self.height
+
+ def area_width(self):
+ return (self.width - 1) * self.dilation.x + 1
+
+ def area_height(self):
+ return (self.height - 1) * self.dilation.y + 1
+
+
# Classifies operators of type Custom
class CustomType(Enum):
ThirdPartyOp = 0 # Third party custom op
@@ -330,6 +355,7 @@ class Operation:
"memory_function",
"forced_output_quantization",
"activation_lut",
+ "_kernel",
)
def __init__(self, op_type, name):
@@ -350,6 +376,7 @@ class Operation:
self.scheduled_pass = None
self.op_index = None # input network operator index
self.activation_lut = None
+ self._kernel = None
def clone(self, suffix="_clone"):
res = Operation(self.type, self.name + suffix)
@@ -372,6 +399,21 @@ class Operation:
__repr__ = __str__
+ @property
+ def kernel(self):
+ strides = self.attrs.get("strides", (1, 1, 1, 1))
+ dilation = self.attrs.get("dilation", (1, 1, 1, 1))
+ weights = self.weights
+ if weights and self.type.npu_block_type in (NpuBlockType.ConvolutionDepthWise, NpuBlockType.ConvolutionMxN):
+ weight_shape = full_shape(4, weights.shape, 1)
+ k_h = weight_shape[-4]
+ k_w = weight_shape[-3]
+ else:
+ k_h = self.attrs.get("filter_height", 1)
+ k_w = self.attrs.get("filter_width", 1)
+ self._kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+ return self._kernel
+
def get_ifm_ifm2_weights_ofm(self):
return self.ifm, self.ifm2, self.weights, self.ofm
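Taken together, the new property and geometry helpers can be exercised on their own. A minimal sketch, assuming vela's in-tree modules are importable and that numeric_util.full_shape left-pads a shape with the fill value (both assumptions, not shown in this patch):

    # Sketch only: hypothetical shapes, not taken from the patch.
    from ethosu.vela.numeric_util import full_shape
    from ethosu.vela.operation import Kernel

    # full_shape pads a weight shape up to 4D before k_h/k_w are read from it
    # (left-padding behaviour assumed):
    assert full_shape(4, [5, 8, 16], 1) == [1, 5, 8, 16]

    # The geometry helpers give the dilated input footprint per OFM element:
    k = Kernel(3, 3, sx=1, sy=1, dx=2, dy=2)  # 3x3 kernel, stride 1, dilation 2
    assert k.elements_wh() == 9   # 3 * 3 weight positions
    assert k.area_width() == 5    # (3 - 1) * 2 + 1 input columns
    assert k.area_height() == 5   # (3 - 1) * 2 + 1 input rows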
diff --git a/ethosu/vela/register_command_stream_generator.py b/ethosu/vela/register_command_stream_generator.py
index 4f3fe7d..0abd882 100644
--- a/ethosu/vela/register_command_stream_generator.py
+++ b/ethosu/vela/register_command_stream_generator.py
@@ -27,7 +27,6 @@ import numpy as np
from . import scaling
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
-from .architecture_features import Kernel
from .architecture_features import Rect
from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
@@ -239,26 +238,6 @@ def get_cmd_wait_dependency(arch, cmd_stream, memory_accesses, cmd_index, waterm
return watermark, outstanding
-def get_op_kernel(ps):
- if ps.primary_op is None:
- return None
-
- strides = ps.primary_op.attrs.get("strides", (1, 1, 1, 1))
- dilation = ps.primary_op.attrs.get("dilation", (1, 1, 1, 1))
- if ps.weight_tensor:
- if ps.npu_block_type in set((NpuBlockType.VectorProduct, NpuBlockType.ElementWise)):
- k_h = 1
- k_w = 1
- else:
- k_h = ps.weight_tensor.shape[0]
- k_w = ps.weight_tensor.shape[1]
- else:
- k_h = ps.primary_op.attrs.get("filter_height", 1)
- k_w = ps.primary_op.attrs.get("filter_width", 1)
-
- return Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
-
-
def has_prev_op_dependency(prev_cmd, cmd):
if prev_cmd is None:
return False
@@ -462,7 +441,7 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
prev_ofm_rect = cur_ofm_rect
prev_ofm_block = cur_ofm_block
prev_kernel = cur_kernel
- cur_kernel = get_op_kernel(ps)
+ cur_kernel = ps.primary_op.kernel if ps.primary_op else None
block_config = ps.block_config
emit.cmd0_with_param(cmd0.NPU_SET_OFM_BLK_HEIGHT_M1, block_config[0] - 1)
@@ -585,7 +564,8 @@ def generate_register_command_stream(nng, sg, arch, verbose=False):
# Set IFM2_IB_START to the latter half of the IB space
ifm_ib_start = shared_buffer.bank_locations[SharedBufferArea.IFM]
emit.cmd0_with_param(
- cmd0.NPU_SET_IFM2_IB_START, (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start
+ cmd0.NPU_SET_IFM2_IB_START,
+ (shram_required - ifm_ib_start) // shared_buffer.ifm_count + ifm_ib_start,
)
emit.cmd0_with_param(cmd0.NPU_SET_IFM2_BROADCAST, ifm2_broadcast)
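The reflowed NPU_SET_IFM2_IB_START computation above places the second IFM's input-buffer start in the latter half of the IB space. A worked example with illustrative values (the numbers are hypothetical, chosen only to make the arithmetic visible):

    shram_required = 48  # total SHRAM banks needed by the pass (hypothetical)
    ifm_ib_start = 16    # start of the IFM input-buffer region (hypothetical)
    ifm_count = 2        # two IFMs for a binary elementwise op

    ifm2_ib_start = (shram_required - ifm_ib_start) // ifm_count + ifm_ib_start
    assert ifm2_ib_start == 32  # the midpoint: IFM2 gets the latter half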
diff --git a/ethosu/vela/shared_buffer_allocation.py b/ethosu/vela/shared_buffer_allocation.py
index f52d3a9..484c34b 100644
--- a/ethosu/vela/shared_buffer_allocation.py
+++ b/ethosu/vela/shared_buffer_allocation.py
@@ -19,13 +19,12 @@ import numpy as np
from .architecture_features import ArchitectureFeatures
from .architecture_features import Block
-from .architecture_features import Kernel
from .architecture_features import SharedBufferArea
from .architecture_features import SHRAMElements
from .errors import VelaError
from .ethos_u55_regs.ethos_u55_regs import resampling_mode
+from .operation import Kernel
from .operation import NpuBlockType
-from .operation import Op
from .range_set import MemoryRangeSet
from .tensor import MemArea
@@ -42,34 +41,19 @@ class SharedBufferAllocation:
scales = [t.quantization.scale_f32 for t in tensors if t.quantization is not None]
has_scale = len(tensors) == len(scales) and None not in scales
- strides = (1, 1, 1, 1)
- dilation = (1, 1, 1, 1)
self.kernel = Kernel(1, 1)
self.is_elementwise = ps.npu_block_type == NpuBlockType.ElementWise
self.uses_lut = False
self.ifm_count = 1
if ps.primary_op:
- strides = ps.primary_op.attrs.get("strides", strides)
- dilation = ps.primary_op.attrs.get("dilation", dilation)
- k_h = 1
- k_w = 1
- if weight_tensor:
- if ps.primary_op.type != Op.FullyConnected:
- k_h = weight_tensor.shape[0]
- k_w = weight_tensor.shape[1]
- else:
- k_h = ps.primary_op.attrs.get("filter_height", 1)
- k_w = ps.primary_op.attrs.get("filter_width", 1)
-
- self.kernel = Kernel(k_w, k_h, strides[2], strides[1], dilation[2], dilation[1])
+ self.kernel = ps.primary_op.kernel
self.uses_lut = ps.primary_op.activation_lut is not None
self.is_equal_depth_op = self.is_elementwise or ps.npu_block_type in (
NpuBlockType.ConvolutionDepthWise,
NpuBlockType.Pooling,
)
- self.strides = strides
self.use_accumulator_element = SHRAMElements.Acc32
if self.is_elementwise:
@@ -89,11 +73,11 @@ class SharedBufferAllocation:
if self.is_elementwise:
self.ifm_count = 2
- if ifm_tensor.shape == []: # Scalar in ifm1
+ if ifm_tensor.shape == []:  # Scalar in ifm1
assert ifm2_tensor
self.ifm_depth = ifm2_tensor.shape[-1]
self.ifm_count = 1
- elif not ifm2_tensor or ifm2_tensor.shape == []: # Scalar in ifm2
+ elif not ifm2_tensor or ifm2_tensor.shape == []:  # Scalar in ifm2
self.ifm_count = 1
if self.ifm_bits == 16:
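The elementwise bookkeeping at the end of SharedBufferAllocation reduces to a simple rule: buffer two IFMs unless one operand is a scalar. A condensed sketch of that rule (ifm_count_for is a hypothetical helper; shapes are illustrative, and an absent ifm2 is modelled as None):

    def ifm_count_for(ifm_shape, ifm2_shape):
        count = 2                     # binary elementwise: two input buffers
        if ifm_shape == []:           # scalar in ifm1; only ifm2 is buffered
            count = 1
        elif ifm2_shape is None or ifm2_shape == []:  # scalar or absent ifm2
            count = 1
        return count

    assert ifm_count_for([1, 8, 8, 16], [1, 8, 8, 16]) == 2
    assert ifm_count_for([], [1, 8, 8, 16]) == 1
    assert ifm_count_for([1, 8, 8, 16], []) == 1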