about summary refs log tree commit diff
diff options
context:
space:
mode:
authorJeremy Johnson <jeremy.johnson@arm.com>2021-12-14 16:34:47 +0000
committerJeremy Johnson <jeremy.johnson@arm.com>2022-01-06 11:27:52 +0000
commit9b22517ba0cd6f767123583ce56e864f50e9d758 (patch)
treec97160837c8f7d2229c236f81e4b059c366d064a
parent8b3903a3bce3b1082cd8bebb7fde8eef3ae203e9 (diff)
downloadserialization_lib-9b22517ba0cd6f767123583ce56e864f50e9d758.tar.gz
Add python package support
Move tosa_serializer into its own namespace.
Fix up for pre-commit black/flake8.
Remove import dependency on reference model.

Signed-off-by: Jeremy Johnson <jeremy.johnson@arm.com>
Change-Id: I8693fb7c00d224142a66dcb19eac74ac77c6224b
-rw-r--r--python/serializer/__init__.py3
-rw-r--r--python/serializer/tosa_serializer.py (renamed from python/tosa_serializer.py)70
2 files changed, 32 insertions, 41 deletions
diff --git a/python/serializer/__init__.py b/python/serializer/__init__.py
new file mode 100644
index 0000000..39e9ecc
--- /dev/null
+++ b/python/serializer/__init__.py
@@ -0,0 +1,3 @@
+"""Namespace."""
+# Copyright (c) 2021-2022 Arm Limited.
+# SPDX-License-Identifier: Apache-2.0
diff --git a/python/tosa_serializer.py b/python/serializer/tosa_serializer.py
index f294ba3..b29f963 100644
--- a/python/tosa_serializer.py
+++ b/python/serializer/tosa_serializer.py
@@ -1,4 +1,4 @@
-# Copyright (c) 2020-2021, ARM Limited.
+# Copyright (c) 2020-2022, ARM Limited.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -12,42 +12,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-#!/usr/bin/env python3
-
import os
-import sys
import json
import flatbuffers
import numpy as np
import struct
-from enum import Enum, IntEnum, unique
+from enum import IntEnum, unique
from tosa import (
TosaGraph,
TosaBasicBlock,
TosaTensor,
TosaOperator,
- DType,
- Op,
- ResizeMode,
Version,
)
-from tosa_ref_run import TosaReturnCode
-
-import tosa
+import tosa.DType as TosaDType
+import tosa.Op as TosaOp
# Keep version number in sync with the version default value with schema/tosa.fbs
TOSA_VERSION_MAJOR = 0
TOSA_VERSION_MINOR = 24
TOSA_VERSION_PATCH = 0
TOSA_VERSION_DRAFT = True
-TOSA_VERSION = [TOSA_VERSION_MAJOR,
- TOSA_VERSION_MINOR,
- TOSA_VERSION_PATCH,
- TOSA_VERSION_DRAFT]
+TOSA_VERSION = [
+ TOSA_VERSION_MAJOR,
+ TOSA_VERSION_MINOR,
+ TOSA_VERSION_PATCH,
+ TOSA_VERSION_DRAFT,
+]
# With the way flatc generates its python types, there is no programatic way
# to get string names for the integer types. Manually maintain a string table
# here.
-DType = tosa.DType.DType()
+DType = TosaDType.DType()
DTypeNames = [
"UNKNOWN",
"BOOL",
@@ -76,10 +71,12 @@ class TosaSerializerUnion:
def __init__(self):
- # A tuple of the start and end functions. Set by the options constructors below
+ # A tuple of the start and end functions.
+ # Set by the options constructors below
self.optFcns = None
- # The type from the tosa.Options enumeration. Set by the options constructors below.
+ # The type from the tosa.Options enumeration.
+ # Set by the options constructors below.
self.utype = None
# Each of these lists is a tuple of the add function and the
@@ -310,8 +307,9 @@ class TosaSerializerAttribute(TosaSerializerUnion):
self.intvecs.append((a.AddTable, table))
+
class TosaSerializerQuantInfo(TosaSerializerUnion):
- """This class handles encapsulating all of the enumerated types for quantinfo types"""
+ """This class handles encapsulating all of the enumerated types for quantinfo"""
def __init__(self):
super().__init__()
@@ -377,9 +375,10 @@ class TosaSerializerTensor:
self.data = None
# Filename for placeholder tensors. These get generated by the test generation
- # process and are written to disk, but are considered input tensors by the network
- # so they do not appear in the TOSA serialiazation. However, if we want to form a unit
- # test around these input tensors, we can get the filename from here.
+ # process and are written to disk, but are considered input tensors by the
+ # network so they do not appear in the TOSA serialiazation. However, if we
+ # want to form a unit test around these input tensors, we can get the filename
+ # from here.
self.placeholderFilename = placeholderFilename
def __str__(self):
@@ -528,10 +527,7 @@ class TosaSerializerBasicBlock:
data=None,
placeholderFilename=None,
):
- try:
- # Someone already added this tensor.
- tens = self.tensors[name]
- except KeyError:
+ if name not in self.tensors:
self.tensors[name] = TosaSerializerTensor(
name, shape, dtype, data, placeholderFilename
)
@@ -601,7 +597,7 @@ class TosaSerializer:
self.currResultIdx = 0
# Is this an illegal test that is expected to fail?
- self.expectedReturnCode = TosaReturnCode.VALID
+ self.expectedReturnCode = 0
self.expectedFailure = False
self.expectedFailureDesc = ""
@@ -633,12 +629,11 @@ class TosaSerializer:
raise Exception("addTensor called without valid basic block")
name = "const-{}".format(self.currInputIdx)
- filename = "{}.npy".format(name)
self.currInputIdx = self.currInputIdx + 1
tens = self.currBasicBlock.addTensor(name, shape, dtype, vals)
# Add the operator now
- self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name)
+ self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name)
return tens
@@ -674,24 +669,18 @@ class TosaSerializer:
def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None):
- if op == tosa.Op.Op().CONST:
+ if op == TosaOp.Op().CONST:
raise Exception("Use addConstTensor() to add CONST ops")
return self.currBasicBlock.addOperator(
op, inputs, outputs, attributes, quant_info
)
- def setExpectedReturnCode(self, val, desc=""):
+ def setExpectedReturnCode(self, val, fail, desc=""):
self.expectedReturnCode = val
self.expectedFailureDesc = desc
-
- if val == TosaReturnCode.VALID:
- self.expectedFailure = False
- else:
- # Unpredictable or error results are considered expected failures
- # for conformance
- self.expectedFailure = True
+ self.expectedFailure = fail
def serialize(self):
@@ -734,8 +723,8 @@ class TosaSerializer:
ifm_file.append(b.tensors[i].placeholderFilename)
for o in b.outputs:
ofm_name.append(o)
- # Make up an OFM filename here. One isn't generated until the reference tool is
- # run, so any name is a good name
+ # Make up an OFM filename here. One isn't generated until the
+ # reference tool is run, so any name is a good name
ofm_file.append("ref-{}.npy".format(o))
test_desc["ifm_name"] = ifm_name
@@ -811,4 +800,3 @@ class TosaSerializer:
return val
else:
return [val]
-