From 1eb8504e43d987cb9584149b8fdb0d37eb82964e Mon Sep 17 00:00:00 2001 From: Jerry Ge Date: Fri, 6 Jan 2023 14:19:14 -0800 Subject: Add TosaSerializerRegion to python version of serialization_lib Signed-off-by: Jerry Ge Change-Id: Ibd15f21aa24168730c904224f08fd55e27aae41f --- python/serializer/tosa_serializer.py | 165 ++++++++++++++++++++++------------- 1 file changed, 106 insertions(+), 59 deletions(-) diff --git a/python/serializer/tosa_serializer.py b/python/serializer/tosa_serializer.py index 85955aa..2d03d49 100644 --- a/python/serializer/tosa_serializer.py +++ b/python/serializer/tosa_serializer.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, ARM Limited. +# Copyright (c) 2020-2023, ARM Limited. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import struct from enum import IntEnum, unique from tosa import ( TosaGraph, + TosaRegion, TosaBasicBlock, TosaTensor, TosaOperator, @@ -404,12 +405,12 @@ class TosaSerializerTensor: self.placeholderFilename = placeholderFilename def __str__(self): - str = "TosaSerializerTensor name: {} shape: {} dtype: {}".format( + concatString = "TosaSerializerTensor name: {} shape: {} dtype: {}".format( self.name, self.shape, DTypeNames[self.dtype], ) - return str + return concatString def setDtype(self, dtype): self.dtype = dtype @@ -500,14 +501,14 @@ class TosaSerializerOperator: self.outputs = TosaSerializer.toList(outputs) def __str__(self): - str = "Op {}\n----\n".format(self.op) + concatString = "Op {}\n----\n".format(self.op) for i in self.inputs: - str = str + " Input: {}\n".format(i) + concatString = concatString + " Input: {}\n".format(i) for o in self.outputs: - str = str + " Output: {}\n".format(o) + concatString = concatString + " Output: {}\n".format(o) - return str + return concatString def serialize(self, builder): fb_inputs = TosaSerializer.serializeStrVec( @@ -532,9 +533,11 @@ class TosaSerializerOperator: class TosaSerializerBasicBlock: - def __init__(self, name): + def __init__(self, name, pathPrefix, saveConstsToFile=False): self.name = name + self.pathPrefix = pathPrefix self.operators = [] + self.saveConstsToFile = saveConstsToFile # Dict assures uniqueness, but allows us to look up by name self.tensors = dict() @@ -592,44 +595,33 @@ class TosaSerializerBasicBlock: return TosaBasicBlock.End(builder) -@unique -class TensorDir(IntEnum): - PLACEHOLDER = 0 - CONST = 1 - INTERMEDIATE = 2 - RESULT = 3 - - -class TosaSerializer: - def __init__(self, pathPrefix, saveConstsToFile=False): - self.add_compat_methods() - # Get the global TOSA version if not already defined - - self.builder = flatbuffers.Builder(0) - +class TosaSerializerRegion: + def __init__(self, name, pathPrefix, saveConstsToFile=False): + self.name = name self.basicBlocks = [] - self.startBasicBlock("main") - self.pathPrefix = pathPrefix - - # Enables inspection of constant data outside of graph - self.saveConstsToFile = saveConstsToFile - - # Indicies used for adding/naming tensors self.currInputIdx = 0 self.currConstIdx = 0 self.currLayerIdx = 1 self.currResultIdx = 0 + self.pathPrefix = pathPrefix + self.saveConstsToFile = saveConstsToFile - # Is this an illegal test that is expected to fail? - self.expectedReturnCode = 0 - self.expectedFailure = False - self.expectedFailureDesc = "" + def addBasicBlock(self, name, pathPrefix, saveConstsToFile): + self.currBasicBlock = TosaSerializerBasicBlock( + name, pathPrefix, saveConstsToFile + ) + self.basicBlocks.append(self.currBasicBlock) - def __str__(self): - str = "" - for bb in self.basicBlocks: - str = str + bb.__str__() - return str + def serialize(self, builder): + fb_name = builder.CreateString(self.name) + fbv_basicBlocks = TosaSerializer.serializeObjVec( + builder, self.basicBlocks, TosaRegion.StartBlocksVector + ) + + TosaRegion.Start(builder) + TosaRegion.AddName(builder, fb_name) + TosaRegion.AddBlocks(builder, fbv_basicBlocks) + return TosaRegion.End(builder) def addPlaceholder(self, shape, dtype, vals): if not self.currBasicBlock: @@ -666,7 +658,6 @@ class TosaSerializer: return tens def addIntermediate(self, shape, dtype): - if not self.currBasicBlock: raise Exception("addTensor called without valid basic block") @@ -696,7 +687,6 @@ class TosaSerializer: return tens def addOperator(self, op, inputs, outputs, attributes=None): - if op == TosaOp.Op().CONST: raise Exception("Use addConstTensor() to add CONST ops") @@ -707,6 +697,62 @@ class TosaSerializer: attributes, ) + +@unique +class TensorDir(IntEnum): + PLACEHOLDER = 0 + CONST = 1 + INTERMEDIATE = 2 + RESULT = 3 + + +class TosaSerializer: + def __init__(self, pathPrefix, saveConstsToFile=False): + self.add_compat_methods() + # Get the global TOSA version if not already defined + + self.builder = flatbuffers.Builder(0) + + self.regions = [] + self.startRegion("main", pathPrefix, saveConstsToFile) + + # Enables inspection of constant data outside of graph + self.saveConstsToFile = saveConstsToFile + + self.currRegion.addBasicBlock("main", pathPrefix, self.saveConstsToFile) + + # Is this an illegal test that is expected to fail? + self.expectedReturnCode = 0 + self.expectedFailure = False + self.expectedFailureDesc = "" + + def __str__(self): + concatString = "" + for region in self.regions: + concatString = concatString + str(region) + return concatString + + def addPlaceholder(self, shape, dtype, vals): + return self.currRegion.addPlaceholder(shape, dtype, vals) + + def addConst(self, shape, dtype, vals): + return self.currRegion.addConst(shape, dtype, vals) + + def addIntermediate(self, shape, dtype): + return self.currRegion.addIntermediate(shape, dtype) + + def addInputTensor(self, tensor): + self.currRegion.addInputTensor(tensor) + + def addOutputTensor(self, tensor): + self.currRegion.addOutputTensor(tensor) + + def addOutput(self, shape, dtype): + return self.currRegion.addOutput(shape, dtype) + + def addOperator(self, op, inputs, outputs, attributes=None): + return self.currRegion.addOperator(op, inputs, outputs, attributes) + def setExpectedReturnCode(self, val, fail, desc=""): self.expectedReturnCode = val @@ -724,13 +770,13 @@ class TosaSerializer: Version.Add_draft(builder, TOSA_VERSION[3]) version = Version.End(builder) - fbv_bb = TosaSerializer.serializeObjVec( - builder, self.basicBlocks, TosaGraph.StartBlocksVector + fbv_region = TosaSerializer.serializeObjVec( + builder, self.regions, TosaGraph.StartRegionsVector ) TosaGraph.Start(builder) TosaGraph.AddVersion(builder, version) - TosaGraph.AddBlocks(builder, fbv_bb) + TosaGraph.AddRegions(builder, fbv_region) graph = TosaGraph.End(builder) self.builder.Finish(graph, TOSA_GRAPH_IDENTIFIER) @@ -747,16 +793,17 @@ class TosaSerializer: ofm_name = [] ofm_file = [] - for b in self.basicBlocks: - if b.name == "main": - for i in b.inputs: - ifm_name.append(i) - ifm_file.append(b.tensors[i].placeholderFilename) - for o in b.outputs: - ofm_name.append(o) - # Make up an OFM filename here. One isn't generated until the - # reference tool is run, so any name is a good name - ofm_file.append("ref-{}.npy".format(o)) + for region in self.regions: + for block in region.basicBlocks: + if block: + for i in block.inputs: + ifm_name.append(i) + ifm_file.append(block.tensors[i].placeholderFilename) + for o in block.outputs: + ofm_name.append(o) + # Make up an OFM filename here. One isn't generated until the + # reference tool is run, so any name is a good name + ofm_file.append("ref-{}.npy".format(o)) test_desc["ifm_name"] = ifm_name test_desc["ifm_file"] = ifm_file @@ -769,9 +816,9 @@ class TosaSerializer: return json.dumps(test_desc, indent=" ") - def startBasicBlock(self, name): - self.currBasicBlock = TosaSerializerBasicBlock(name) - self.basicBlocks.append(self.currBasicBlock) + def startRegion(self, name, pathPrefix, saveConstsToFile): + self.currRegion = TosaSerializerRegion(name, pathPrefix, saveConstsToFile) + self.regions.append(self.currRegion) @staticmethod def serializeStrVec(builder, vec, start_fcn): @@ -1090,8 +1137,8 @@ class TosaSerializer: if not hasattr(TosaGraph, "Start"): TosaGraph.Start = TosaGraph.TosaGraphStart TosaGraph.AddVersion = TosaGraph.TosaGraphAddVersion - TosaGraph.AddBlocks = TosaGraph.TosaGraphAddBlocks - TosaGraph.StartBlocksVector = TosaGraph.TosaGraphStartBlocksVector + TosaGraph.AddRegions = TosaGraph.TosaGraphAddRegions + TosaGraph.StartRegionsVector = TosaGraph.TosaGraphStartRegionsVector TosaGraph.End = TosaGraph.TosaGraphEnd from tosa import TosaOperator -- cgit v1.2.1