blob: 1612447ff3122f31e58a4b56307818e4c9fb7cf5 (
plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
|
# SPDX-FileCopyrightText: Copyright 2022-2023, Arm Limited and/or its affiliates.
# SPDX-FileCopyrightText: Copyright The TensorFlow Authors. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
"""Collection of useful functions for optimizations."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import tensorflow as tf
def get_tf_tensor_shape(model: str) -> list:
    """Return the flattened input dimensions of a SavedModel's default signature.

    The model is loaded with ``tf.saved_model.load`` and the signature stored
    under the default serving key is looked up. The dimensions of each input
    of that signature are appended, in iteration order, to one flat list.
    Inputs whose shape object evaluates as falsy are skipped.

    Args:
        model: Location of the SavedModel to load.

    Returns:
        A single flat list holding the dimensions of all signature inputs.

    Raises:
        KeyError: If the default serving signature cannot be found.
    """
    saved_model = tf.saved_model.load(model)
    signature_key = tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY

    try:
        signature_inputs = saved_model.signatures[signature_key].inputs
    except KeyError as err:
        raise KeyError(f"Signature '{signature_key}' not found.") from err

    dimensions: list = []
    for input_tensor in signature_inputs:
        tensor_shape = input_tensor.get_shape()
        if tensor_shape:
            # Flatten: every dimension of every input lands in the same list.
            dimensions.extend(tensor_shape)
    return dimensions
def save_keras_model(
    model: tf.keras.Model, save_path: str | Path, include_optimizer: bool = True
) -> None:
    """Persist a Keras model to ``save_path``.

    Delegates to the model's own ``save`` method.

    Args:
        model: Keras model to be written out.
        save_path: Destination handed to ``model.save`` unchanged.
        include_optimizer: Forwarded to ``model.save`` as the keyword argument
            of the same name.
    """
    save_options = {"include_optimizer": include_optimizer}
    model.save(save_path, **save_options)
def save_tflite_model(tflite_model: bytes, save_path: str | Path) -> None:
    """Write the serialized TensorFlow Lite model to ``save_path``.

    The destination is opened in binary write mode, so an existing file at
    that location is overwritten.

    Args:
        tflite_model: Raw bytes of the TensorFlow Lite model.
        save_path: File the bytes are written to.
    """
    with Path(save_path).open("wb") as output:
        output.write(tflite_model)
def is_tflite_model(model: str | Path) -> bool:
    """Tell whether ``model`` is named like a TensorFlow Lite model file.

    Only the file name is inspected; the file system is not accessed.

    Args:
        model: Path to examine.

    Returns:
        True when the path's suffix is exactly ``.tflite``, False otherwise.
    """
    return Path(model).suffix == ".tflite"
def is_keras_model(model: str | Path) -> bool:
    """Tell whether ``model`` refers to a Keras model.

    A directory counts as a Keras model when it contains a
    ``keras_metadata.pb`` entry. Anything that is not a directory is judged
    by its suffix alone, which must be ``.h5`` or ``.hdf5``.

    Args:
        model: Path to examine.

    Returns:
        True when the path matches one of the rules above, False otherwise.
    """
    candidate = Path(model)
    if not candidate.is_dir():
        return candidate.suffix in {".h5", ".hdf5"}
    return (candidate / "keras_metadata.pb").exists()
def is_saved_model(model: str | Path) -> bool:
    """Tell whether ``model`` refers to a SavedModel.

    A path qualifies when it is a directory that ``is_keras_model`` does not
    recognise as a Keras model.

    Args:
        model: Path to examine.

    Returns:
        True for a directory that is not a Keras model, False otherwise.
    """
    if not Path(model).is_dir():
        return False
    return not is_keras_model(model)
def get_tflite_model_type_map(model_filename: str | Path) -> dict[str, type]:
    """Map each input of a TensorFlow Lite model to its data type.

    The model is opened with ``tf.lite.Interpreter``, tensors are allocated
    and the interpreter's input details are read.

    Args:
        model_filename: Location of the TensorFlow Lite model file.

    Returns:
        Dictionary keyed by the ``"name"`` entry of each input detail, with
        the matching ``"dtype"`` entry as value.
    """
    interpreter = tf.lite.Interpreter(str(model_filename))
    interpreter.allocate_tensors()

    type_by_input: dict[str, Any] = {}
    for detail in interpreter.get_input_details():
        type_by_input[detail["name"]] = detail["dtype"]
    return type_by_input
def check_tflite_datatypes(model_filename: str | Path, *allowed_types: type) -> None:
    """Check if the model only has the given allowed datatypes.

    The data types of the model inputs, as reported by
    ``get_tflite_model_type_map``, are compared against ``allowed_types``.

    Args:
        model_filename: Location of the TensorFlow Lite model file.
        *allowed_types: Data types the model inputs are permitted to use.

    Raises:
        TypeError: If at least one model input uses a data type that is not
            listed in ``allowed_types``.
    """
    found = set(get_tflite_model_type_map(model_filename).values())
    allowed = set(allowed_types)
    unexpected = found - allowed
    if not unexpected:
        return

    def type_names(type_set: set[type]) -> list[str]:
        # Sets iterate in arbitrary order; sorting keeps the error message
        # stable from one run to the next.
        return sorted(item.__name__ for item in type_set)

    raise TypeError(
        f"Model {model_filename} has "
        f"unexpected data types: {type_names(unexpected)}. "
        f"Only {type_names(allowed)} are allowed."
    )
|