diff options
-rw-r--r-- | python/serializer/__init__.py | 3 | ||||
-rw-r--r-- | python/serializer/tosa_serializer.py (renamed from python/tosa_serializer.py) | 70 |
2 files changed, 32 insertions, 41 deletions
diff --git a/python/serializer/__init__.py b/python/serializer/__init__.py new file mode 100644 index 0000000..39e9ecc --- /dev/null +++ b/python/serializer/__init__.py @@ -0,0 +1,3 @@ +"""Namespace.""" +# Copyright (c) 2021-2022 Arm Limited. +# SPDX-License-Identifier: Apache-2.0 diff --git a/python/tosa_serializer.py b/python/serializer/tosa_serializer.py index f294ba3..b29f963 100644 --- a/python/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2021, ARM Limited. +# Copyright (c) 2020-2022, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -12,42 +12,37 @@ # See the License for the specific language governing permissions and # limitations under the License. -#!/usr/bin/env python3 - import os -import sys import json import flatbuffers import numpy as np import struct -from enum import Enum, IntEnum, unique +from enum import IntEnum, unique from tosa import ( TosaGraph, TosaBasicBlock, TosaTensor, TosaOperator, - DType, - Op, - ResizeMode, Version, ) -from tosa_ref_run import TosaReturnCode - -import tosa +import tosa.DType as TosaDType +import tosa.Op as TosaOp # Keep version number in sync with the version default value with schema/tosa.fbs TOSA_VERSION_MAJOR = 0 TOSA_VERSION_MINOR = 24 TOSA_VERSION_PATCH = 0 TOSA_VERSION_DRAFT = True -TOSA_VERSION = [TOSA_VERSION_MAJOR, - TOSA_VERSION_MINOR, - TOSA_VERSION_PATCH, - TOSA_VERSION_DRAFT] +TOSA_VERSION = [ + TOSA_VERSION_MAJOR, + TOSA_VERSION_MINOR, + TOSA_VERSION_PATCH, + TOSA_VERSION_DRAFT, +] # With the way flatc generates its python types, there is no programatic way # to get string names for the integer types. Manually maintain a string table # here. -DType = tosa.DType.DType() +DType = TosaDType.DType() DTypeNames = [ "UNKNOWN", "BOOL", @@ -76,10 +71,12 @@ class TosaSerializerUnion: def __init__(self): - # A tuple of the start and end functions. Set by the options constructors below + # A tuple of the start and end functions. + # Set by the options constructors below self.optFcns = None - # The type from the tosa.Options enumeration. Set by the options constructors below. + # The type from the tosa.Options enumeration. + # Set by the options constructors below. self.utype = None # Each of these lists is a tuple of the add function and the @@ -310,8 +307,9 @@ class TosaSerializerAttribute(TosaSerializerUnion): self.intvecs.append((a.AddTable, table)) + class TosaSerializerQuantInfo(TosaSerializerUnion): - """This class handles encapsulating all of the enumerated types for quantinfo types""" + """This class handles encapsulating all of the enumerated types for quantinfo""" def __init__(self): super().__init__() @@ -377,9 +375,10 @@ class TosaSerializerTensor: self.data = None # Filename for placeholder tensors. These get generated by the test generation - # process and are written to disk, but are considered input tensors by the network - # so they do not appear in the TOSA serialiazation. However, if we want to form a unit - # test around these input tensors, we can get the filename from here. + # process and are written to disk, but are considered input tensors by the + # network so they do not appear in the TOSA serialiazation. However, if we + # want to form a unit test around these input tensors, we can get the filename + # from here. self.placeholderFilename = placeholderFilename def __str__(self): @@ -528,10 +527,7 @@ class TosaSerializerBasicBlock: data=None, placeholderFilename=None, ): - try: - # Someone already added this tensor. - tens = self.tensors[name] - except KeyError: + if name not in self.tensors: self.tensors[name] = TosaSerializerTensor( name, shape, dtype, data, placeholderFilename ) @@ -601,7 +597,7 @@ class TosaSerializer: self.currResultIdx = 0 # Is this an illegal test that is expected to fail? - self.expectedReturnCode = TosaReturnCode.VALID + self.expectedReturnCode = 0 self.expectedFailure = False self.expectedFailureDesc = "" @@ -633,12 +629,11 @@ class TosaSerializer: raise Exception("addTensor called without valid basic block") name = "const-{}".format(self.currInputIdx) - filename = "{}.npy".format(name) self.currInputIdx = self.currInputIdx + 1 tens = self.currBasicBlock.addTensor(name, shape, dtype, vals) # Add the operator now - self.currBasicBlock.addOperator(tosa.Op.Op().CONST, [], name) + self.currBasicBlock.addOperator(TosaOp.Op().CONST, [], name) return tens @@ -674,24 +669,18 @@ class TosaSerializer: def addOperator(self, op, inputs, outputs, attributes=None, quant_info=None): - if op == tosa.Op.Op().CONST: + if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") return self.currBasicBlock.addOperator( op, inputs, outputs, attributes, quant_info ) - def setExpectedReturnCode(self, val, desc=""): + def setExpectedReturnCode(self, val, fail, desc=""): self.expectedReturnCode = val self.expectedFailureDesc = desc - - if val == TosaReturnCode.VALID: - self.expectedFailure = False - else: - # Unpredictable or error results are considered expected failures - # for conformance - self.expectedFailure = True + self.expectedFailure = fail def serialize(self): @@ -734,8 +723,8 @@ class TosaSerializer: ifm_file.append(b.tensors[i].placeholderFilename) for o in b.outputs: ofm_name.append(o) - # Make up an OFM filename here. One isn't generated until the reference tool is - # run, so any name is a good name + # Make up an OFM filename here. One isn't generated until the + # reference tool is run, so any name is a good name ofm_file.append("ref-{}.npy".format(o)) test_desc["ifm_name"] = ifm_name @@ -811,4 +800,3 @@ class TosaSerializer: return val else: return [val] - |