aboutsummaryrefslogtreecommitdiff
path: root/ethosu/vela/test/extapi/test_extapi_encode_weights.py
diff options
context:
space:
mode:
authorLouis Verhaard <louis.verhaard@arm.com>2020-11-02 18:04:27 +0100
committerLouis Verhaard <louis.verhaard@arm.com>2020-11-23 13:39:35 +0100
commitaeae56770f3c19182d32cc63fd32396e061a7648 (patch)
tree95ca2e6c90d81ba8910c8ca9ced68ffa132b7dab /ethosu/vela/test/extapi/test_extapi_encode_weights.py
parent27d36f003d35413beb51c1de8f33259ddeca7543 (diff)
downloadethos-u-vela-aeae56770f3c19182d32cc63fd32396e061a7648.tar.gz
MLBEDSW-3424: Expose API through separate file
All external APIs are now exposed by api.py. Signed-off-by: Louis Verhaard <louis.verhaard@arm.com> Change-Id: I33f480e424692ac30e9c7d791f583199f31164a7
Diffstat (limited to 'ethosu/vela/test/extapi/test_extapi_encode_weights.py')
-rw-r--r--ethosu/vela/test/extapi/test_extapi_encode_weights.py18
1 files changed, 5 insertions, 13 deletions
diff --git a/ethosu/vela/test/extapi/test_extapi_encode_weights.py b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
index 854d14c0..6367cb30 100644
--- a/ethosu/vela/test/extapi/test_extapi_encode_weights.py
+++ b/ethosu/vela/test/extapi/test_extapi_encode_weights.py
@@ -14,25 +14,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Description:
-# Contains unit tests for encode_weights API for an external consumer
+# Contains unit tests for npu_encode_weights API for an external consumer
import numpy as np
import pytest
-from ethosu.vela import weight_compressor
+from ethosu.vela.api import npu_encode_weights
+from ethosu.vela.api import NpuAccelerator
from ethosu.vela.api import NpuBlockTraversal
-from ethosu.vela.architecture_features import Accelerator
@pytest.mark.parametrize(
- "arch",
- [
- Accelerator.Ethos_U55_32,
- Accelerator.Ethos_U55_64,
- Accelerator.Ethos_U55_128,
- Accelerator.Ethos_U55_256,
- Accelerator.Ethos_U65_256,
- Accelerator.Ethos_U65_512,
- ],
+ "arch", list(NpuAccelerator),
)
@pytest.mark.parametrize("dilation_x", [1, 2])
@pytest.mark.parametrize("dilation_y", [1, 2])
@@ -56,7 +48,7 @@ def test_encode_weights(
block_traversal = NpuBlockTraversal.PART_KERNEL_FIRST if depth_control == 3 else NpuBlockTraversal.DEPTH_FIRST
dilation_xy = (dilation_x, dilation_y)
- encoded_stream = weight_compressor.encode_weights(
+ encoded_stream = npu_encode_weights(
accelerator=arch,
weights_volume=weights_ohwi,
dilation_xy=dilation_xy,