aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--driver_library/src/ethosu.cpp13
1 files changed, 13 insertions, 0 deletions
diff --git a/driver_library/src/ethosu.cpp b/driver_library/src/ethosu.cpp
index 6b30827..2c498a8 100644
--- a/driver_library/src/ethosu.cpp
+++ b/driver_library/src/ethosu.cpp
@@ -57,12 +57,25 @@ size_t getShapeSize(const flatbuffers::Vector<int32_t> *shape) {
return size;
}
+size_t getTensorTypeSize(const enum tflite::TensorType type) {
+ switch (type) {
+ case tflite::TensorType::TensorType_UINT8:
+ case tflite::TensorType::TensorType_INT8:
+ return 1;
+ case tflite::TensorType::TensorType_INT16:
+ return 2;
+ default:
+ throw EthosU::Exception("Unsupported tensor type");
+ }
+}
+
vector<size_t> getSubGraphDims(const tflite::SubGraph *subgraph, const flatbuffers::Vector<int32_t> *tensorMap) {
vector<size_t> dims;
for (auto index = tensorMap->begin(); index != tensorMap->end(); ++index) {
auto tensor = subgraph->tensors()->Get(*index);
size_t size = getShapeSize(tensor->shape());
+ size *= getTensorTypeSize(tensor->type());
if (size > 0) {
dims.push_back(size);