aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--ethosu/vela/architecture_features.py46
-rw-r--r--ethosu/vela/errors.py18
-rw-r--r--ethosu/vela/test/extapi/test_extapi_encode_weights.py73
-rw-r--r--ethosu/vela/vela.py2
-rw-r--r--ethosu/vela/weight_compressor.py85
5 files changed, 197 insertions, 27 deletions
diff --git a/ethosu/vela/architecture_features.py b/ethosu/vela/architecture_features.py
index 6460c527..43b32109 100644
--- a/ethosu/vela/architecture_features.py
+++ b/ethosu/vela/architecture_features.py
@@ -120,6 +120,19 @@ class SharedBufferArea(enum.IntEnum):
Size = Accumulators + 1
+class Accelerator(enum.Enum):
+ Ethos_U55_32 = "ethos-u55-32"
+ Ethos_U55_64 = "ethos-u55-64"
+ Ethos_U55_128 = "ethos-u55-128"
+ Ethos_U55_256 = "ethos-u55-256"
+ Yoda_256 = "yoda-256"
+ Yoda_512 = "yoda-512"
+
+ @classmethod
+ def member_list(cls):
+ return [e.value for e in cls]
+
+
class ArchitectureFeatures:
"""This class is a container for various parameters of the Ethos-U55 core
and system configuration that can be tuned, either by command line
@@ -136,15 +149,28 @@ Note the difference between ArchitectureFeatures and CompilerOptions
"ArchitectureConfig", "macs cores ofm_ublock ifm_ublock shram_banks shram_granules elem_units"
)
accelerator_configs = {
- "yoda-512": ArchitectureConfig(256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
- "yoda-256": ArchitectureConfig(256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
- "ethos-u55-256": ArchitectureConfig(256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8),
- "ethos-u55-128": ArchitectureConfig(128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4),
- "ethos-u55-64": ArchitectureConfig(64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2),
- "ethos-u55-32": ArchitectureConfig(32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1),
+ Accelerator.Yoda_512: ArchitectureConfig(
+ 256, 2, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+ ),
+ Accelerator.Yoda_256: ArchitectureConfig(
+ 256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+ ),
+ Accelerator.Ethos_U55_256: ArchitectureConfig(
+ 256, 1, Block(2, 2, 8), Block(2, 2, 8), 48, [8, 8, 8, 8, 8, 16, 20], 8
+ ),
+ Accelerator.Ethos_U55_128: ArchitectureConfig(
+ 128, 1, Block(2, 1, 8), Block(2, 2, 8), 24, [4, 4, 4, 4, 4, 8, 12], 4
+ ),
+ Accelerator.Ethos_U55_64: ArchitectureConfig(
+ 64, 1, Block(1, 1, 8), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 8], 2
+ ),
+ Accelerator.Ethos_U55_32: ArchitectureConfig(
+ 32, 1, Block(1, 1, 4), Block(1, 1, 8), 16, [2, 2, 2, 2, 4, 4, 4], 1
+ ),
}
OFMSplitDepth = 16
+ SubKernelMax = Block(8, 8, 65536)
def __init__(
self,
@@ -159,20 +185,18 @@ Note the difference between ArchitectureFeatures and CompilerOptions
):
accelerator_config = accelerator_config.lower()
self.vela_config = vela_config
- self.accelerator_config = accelerator_config
- if self.accelerator_config not in ArchitectureFeatures.accelerator_configs:
+ if accelerator_config not in Accelerator.member_list():
raise OptionError("--accelerator-config", self.accelerator_config, "Unknown accelerator configuration")
+ self.accelerator_config = Accelerator(accelerator_config)
accel_config = ArchitectureFeatures.accelerator_configs[self.accelerator_config]
self.config = accel_config
self.system_config = system_config
-
- self.is_yoda_system = "yoda-" in self.accelerator_config
+ self.is_yoda_system = self.accelerator_config in (Accelerator.Yoda_256, Accelerator.Yoda_512)
self.ncores = accel_config.cores
self.ofm_ublock = accel_config.ofm_ublock
self.ifm_ublock = accel_config.ifm_ublock
- self.subkernel_max = Block(8, 8, 65536)
self.ofm_block_max = Block(64, 32, 128)
self.override_block_config = override_block_config
self.block_config_limit = block_config_limit
diff --git a/ethosu/vela/errors.py b/ethosu/vela/errors.py
index 2c93fbc6..59740aac 100644
--- a/ethosu/vela/errors.py
+++ b/ethosu/vela/errors.py
@@ -15,6 +15,7 @@
# limitations under the License.
# Description:
# Defines custom exceptions.
+import inspect
import sys
from .operation import Operation
@@ -121,3 +122,20 @@ def TensorError(tens, msg):
print("Error: {}".format(data))
sys.exit(1)
+
+
+def typecheck(func):
+ def wrapper(*args, **kwargs):
+ fsig = inspect.signature(func)
+ args_zipped = zip(kwargs.values(), fsig.parameters.keys())
+ for actual, expected in args_zipped:
+ expected_type = fsig.parameters[expected].annotation
+ actual_type = type(actual)
+ if expected_type is inspect.Parameter.empty:
+ raise TypeError("Please provide type info for {}, hint = {}".format(expected, actual_type))
+ if expected_type is not actual_type:
+ raise TypeError("expected : {}, but got {}".format(expected_type, actual_type))
+ # Actual execution
+ return func(*args, **kwargs)
+
+ return wrapper
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
new file mode 100644
index 00000000..47ca02b8
--- /dev/null
+++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
@@ -0,0 +1,73 @@
+# Copyright (C) 2020 Arm Limited or its affiliates. All rights reserved.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the License); you may
+# not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an AS IS BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# Description:
+# Contains unit tests for encode_weights API for an external consumer
+import numpy as np
+import pytest
+
+from ethosu.vela import weight_compressor
+from ethosu.vela.architecture_features import Accelerator
+
+
+@pytest.mark.parametrize(
+ "arch",
+ [
+ Accelerator.Ethos_U55_32,
+ Accelerator.Ethos_U55_64,
+ Accelerator.Ethos_U55_128,
+ Accelerator.Ethos_U55_256,
+ Accelerator.Yoda_256,
+ Accelerator.Yoda_512,
+ ],
+)
+@pytest.mark.parametrize("dilation_x", [1, 2])
+@pytest.mark.parametrize("dilation_y", [1, 2])
+@pytest.mark.parametrize("ifm_bitdepth", [8, 16])
+@pytest.mark.parametrize("depth_control", [1, 2, 3])
+@pytest.mark.parametrize("weights_shape_and_block_depth", [((16, 16, 16, 16), 8), ((3, 3, 25, 16), 8)])
+def test_encode_weights(
+ arch, weights_shape_and_block_depth, dilation_x, dilation_y, ifm_bitdepth, depth_control,
+):
+ """
+ This unit test checks the interface of the API function but not the functionality.
+ Functional correctness is tested at a system level.
+ """
+
+ weights_shape = weights_shape_and_block_depth[0]
+ ofm_block_depth = weights_shape_and_block_depth[1]
+ val_max = np.iinfo(np.uint8).max
+ weights_hwio = np.random.randint(val_max, size=weights_shape, dtype=np.uint8)
+ weights_ohwi = np.transpose(weights_hwio, (3, 0, 1, 2))
+ is_depthwise = True if depth_control == 2 else False
+ is_partkernel = True if depth_control == 3 else False
+ dilation_xy = (dilation_x, dilation_y)
+
+ encoded_stream = weight_compressor.encode_weights(
+ accelerator=arch,
+ weights_volume=weights_ohwi,
+ dilation_xy=dilation_xy,
+ ifm_bitdepth=ifm_bitdepth,
+ ofm_block_depth=ofm_block_depth,
+ is_depthwise=is_depthwise,
+ is_partkernel=is_partkernel,
+ )
+ assert type(encoded_stream) == bytearray
+
+
+if __name__ == "__main__":
+ # two test candidates for debugging purposes
+ test_encode_weights(Accelerator.Ethos_U55_256, ((3, 3, 25, 16), 8), 1, 1, 8, 0)
+ test_encode_weights(Accelerator.Ethos_U55_256, ((16, 16, 16, 16), 8), 1, 1, 8, 0)
diff --git a/ethosu/vela/vela.py b/ethosu/vela/vela.py
index 20bc525b..1766750e 100644
--- a/ethosu/vela/vela.py
+++ b/ethosu/vela/vela.py
@@ -170,7 +170,7 @@ def main(args=None):
"--accelerator-config",
type=str,
default="ethos-u55-256",
- choices=list(architecture_features.ArchitectureFeatures.accelerator_configs.keys()),
+ choices=list(architecture_features.Accelerator.member_list()),
help="Accelerator configuration to use (default: %(default)s)",
)
parser.add_argument(
diff --git a/ethosu/vela/weight_compressor.py b/ethosu/vela/weight_compressor.py
index 8ebd7511..687a0805 100644
--- a/ethosu/vela/weight_compressor.py
+++ b/ethosu/vela/weight_compressor.py
@@ -20,7 +20,10 @@ from collections import namedtuple
import numpy as np
+from .architecture_features import Accelerator
+from .architecture_features import ArchitectureFeatures
from .data_type import DataType
+from .errors import typecheck
from .errors import UnsupportedFeatureError
from .nn_graph import SchedulingStrategy
from .numeric_util import round_up
@@ -42,6 +45,55 @@ WeightCompressionConfig = namedtuple(
)
+@typecheck
+def encode_weights(
+ accelerator: Accelerator,
+ weights_volume: np.ndarray,
+ dilation_xy: tuple,
+ ifm_bitdepth: int,
+ ofm_block_depth: int,
+ is_depthwise: bool,
+ is_partkernel: bool,
+):
+ """
+ Public facing API to use the ethosu weight encoding.
+
+ :param accelerator: architecture_features.Accelerator enum to pick the correct ethosu accelerator
+ :param weights_volume: numpy.ndarray in OHWI layout with a shape of four
+ :param dilation_xy: a two element tuple of dilation attributes in x,y dimension
+ :param ifm_bitdepth: the bitdepth of input feature map
+ :param ofm_block_depth: the depth of blocks for ethosu processing
+ :param is_depthwise: a boolean indicating these weights are used for a depthwise traversal
+ :param is_partkernel: a boolean indicating these weights are traversed on sub-kernal basis
+ :return: a bytearray of compressed weights
+ """
+
+ # Checks for weight layout
+ assert len(weights_volume.shape) == 4, "weights ndarray should have a shape of 4"
+
+ # It cannot be both partkernel and depthwise
+ assert not (is_depthwise and is_partkernel), "encode_weights :: partkernel and depthwise are mutually exclusive"
+
+ # Check valid values for dilation
+ assert dilation_xy[0] in (1, 2), "encode_weights :: dilation x should be 1 or 2 not {}".format(dilation_xy[0])
+ assert dilation_xy[1] in (1, 2), "encode_weights :: dilation y should be 1 or 2 not {}".format(dilation_xy[1])
+
+ ifm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ifm_ublock
+ ofm_ublock = ArchitectureFeatures.accelerator_configs[accelerator].ofm_ublock
+ raw_stream = generate_brick(
+ ifm_ublock=ifm_ublock,
+ ofm_ublock=ofm_ublock,
+ brick_weights=weights_volume,
+ ofm_block_depth=ofm_block_depth,
+ is_depthwise=is_depthwise,
+ is_partkernel=is_partkernel,
+ ifm_bitdepth=ifm_bitdepth,
+ dilation=dilation_xy,
+ )
+ encoded_stream = encode(raw_stream)
+ return encoded_stream
+
+
def create_weight_compression_config(tens, npu_block_type, ofm_block_depth, ofm_depth_step, dilation):
# Note: for an ofm block only its depth is used in weight compression.
# And block depth > ofm depth gives same result as block depth == ofm depth
@@ -93,13 +145,12 @@ def encode(weight_stream):
return compressed
-def generate_brick(arch, brick_weights, ofm_block_depth, block_traversal, ifm_bitdepth, dilation):
- is_depthwise = block_traversal == TensorBlockTraversal.DepthWise
- is_partkernel = block_traversal == TensorBlockTraversal.PartKernelFirst
- decomp_h = arch.subkernel_max.height // dilation[0]
- decomp_w = arch.subkernel_max.width // dilation[1]
- ofm_ublock = arch.ofm_ublock
- ifm_ublock = arch.ifm_ublock
+def generate_brick(
+ ifm_ublock, ofm_ublock, brick_weights, ofm_block_depth, is_depthwise, is_partkernel, ifm_bitdepth, dilation
+):
+
+ decomp_h = ArchitectureFeatures.SubKernelMax.height // dilation[0]
+ decomp_w = ArchitectureFeatures.SubKernelMax.width // dilation[1]
# Expect weights formatted OHWI
ofm_depth = brick_weights.shape[-4]
ifm_depth = brick_weights.shape[-1]
@@ -245,6 +296,9 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
else:
tens.block_traversal = TensorBlockTraversal.DepthFirst
+ is_depthwise = tens.block_traversal == TensorBlockTraversal.DepthWise
+ is_partkernel = tens.block_traversal == TensorBlockTraversal.PartKernelFirst
+
if tens.consumer_list[0].type == "Conv2DBackpropInputSwitchedBias":
# Transpose Convoluion, reverse weights in H and W axes
weights = np.flip(weights, axis=(0, 1))
@@ -262,7 +316,6 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
substream_offsets = [0]
encoded_stream = []
- raw_size = 0
# For each core, deinterleave weights from the larger volume
# and generate separate compressed streams.
@@ -270,15 +323,17 @@ def compress_weights(arch, nng, tens, npu_block_type, ofm_block_depth, ofm_depth
core_weights = core_deinterleave(brick_weights, core, arch.ncores)
block_depth = (ofm_block_depth + arch.ncores - 1 - core) // arch.ncores
+ encoded_substream = []
if block_depth != 0:
- raw_stream = generate_brick(
- arch, core_weights, block_depth, tens.block_traversal, ifm_bitdepth, dilation
+ encoded_substream = encode_weights(
+ accelerator=arch.accelerator_config,
+ weights_volume=core_weights,
+ dilation_xy=dilation,
+ ifm_bitdepth=ifm_bitdepth,
+ ofm_block_depth=block_depth,
+ is_depthwise=is_depthwise,
+ is_partkernel=is_partkernel,
)
- else:
- raw_stream = []
-
- raw_size += len(raw_stream)
- encoded_substream = encode(raw_stream)
encoded_stream.extend(encoded_substream)
substream_offsets.append(len(encoded_stream))