TensorFlow

To capture model metadata easily, you can use our inference runners. They run a dataset through the model and log the metadata katiML needs, without you having to call upload_to_lake() or the REST API.

from dioptra.inference.tf.classifier_runner import ClassifierRunner
class ClassifierRunner(
    model: Model,
    embeddings_layers: List[str],
    logits_layer: str,
    class_names: List[str],
    metadata: Optional[List[object]]
)
Arguments

model: a tensorflow model
embeddings_layers: a list of names of embeddings layers. Names and indexes can be combined, for example: [0].embeddings
logits_layer: the name of the logits layer
class_names: a list of class names corresponding to the logits indexes
metadata: a list of metadata to be added to each datapoint. The index in this list should match the index in the dataset. metadata can include any of the fields accepted by Dioptra and will override model-inferred values
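Layers can be referenced by name, by index, or a dot-separated mix of both. For instance, assuming a model whose first child layer is a Sequential with sub-layers named embeddings and high_embeddings (as in the complete example below), all of the following are valid references:

embeddings_layers=['embeddings']        # a top-level layer, by name
embeddings_layers=['[0].embeddings']    # first child layer, then a sub-layer by name
logits_layer='[0].[-1]'                 # first child layer, then its last sub-layer by index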

def run(
    self,
    dataset: Dataset
)
Arguments

dataset: an iterable dataset to run the model inference on. The iterator should return only the features, not the groundtruth; groundtruth should be passed as metadata. The dataset should not be shuffled if used with metadata (see the sketch below)

def wait_for_uploads(self) -> List[object]
# Waits on all metadata generated during inference to be uploaded to Dioptra.
# Returns the list of uploads generated by the runner during the inference.
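Because run() must iterate over features only, without shuffling, a common pattern is to strip the labels out of a tf.data pipeline and carry the groundtruth in metadata, in the same order. A minimal sketch, with made-up tensors and class names:

import tensorflow as tf

class_names = ['cat', 'dog']                    # illustrative class names
features = tf.random.uniform((10, 224, 224, 3)) # dummy images
labels = tf.constant([0, 1] * 5)                # dummy integer labels

raw_dataset = tf.data.Dataset.from_tensor_slices((features, labels))

# carry the groundtruth as metadata, in the same (unshuffled) order as the dataset...
metadata = [
    {'groundtruth': {'class_name': class_names[int(label)]}}
    for _, label in raw_dataset
]

# ...and keep a features-only dataset for the runner
features_only_dataset = raw_dataset.map(lambda feature, label: feature).batch(5)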

Example usage

import os

os.environ['DIOPTRA_API_KEY'] = 'my_api_key'
os.environ['DIOPTRA_UPLOAD_BUCKET'] = 'my_upload_bucket'

my_model = ...    # a tensorflow model
my_dataset = ...  # some dataset that we can iterate over
my_classes = ...  # class names corresponding to the logits indexes
my_metadata = ... # optional, one metadata dict per datapoint

my_runner = ClassifierRunner(
    model=my_model,
    embeddings_layers=['embeddings'],
    logits_layer='logits',
    class_names=my_classes,
    metadata=my_metadata
)

my_runner.run(my_dataset)

my_runner.wait_for_uploads()

print('Metadata generated and uploaded!')
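Since wait_for_uploads() returns the list of uploads generated during the inference, you can also capture its return value, for example to check how many uploads completed (the exact shape of each upload record is not documented here):

uploads = my_runner.wait_for_uploads()
print(f'{len(uploads)} uploads completed')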

You can also send metadata such as groundtruth, image URIs, or tags. Here is a complete example.

import os

os.environ['DIOPTRA_API_KEY'] = 'my_api_key'
os.environ['DIOPTRA_UPLOAD_BUCKET'] = 'my_upload_bucket'

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers, Model
from tensorflow.keras.applications import EfficientNetB0

from dioptra.inference.tf.classifier_runner import ClassifierRunner

def pre_process_data(datapoint):
    img = tf.image.resize(datapoint['image'], (224, 224))
    return img

class DLBinary(Model):
    def __init__(self, imgs_size):
        super().__init__()
        self.model = tf.keras.Sequential([
            layers.Input(shape=(imgs_size, imgs_size, 3), dtype=tf.float32),
            EfficientNetB0(weights='imagenet', include_top=False),
            layers.Layer(name='embeddings'), # pass-through layer: name your embeddings layer to reference it easily
            layers.Flatten(),
            layers.Dense(4096, activation='relu'),
            layers.Dropout(0.1),
            layers.Dense(4096, activation='relu', name='high_embeddings'), # you can log multiple embeddings layers
            layers.Dense(2, activation=None) # one logit per class name
        ])

    def call(self, x):
        return self.model(x)

dataset = tfds.load(name='cats_vs_dogs', split=['train'], as_supervised=False)[0].take(10)
my_dataset = dataset.map(lambda datapoint: pre_process_data(datapoint)).batch(5)

model = DLBinary(224)

# You can add metadata to each datapoint: a uri to view the data,
# tags to filter the data, or groundtruth
metadata = []
for _ in range(len(dataset)):
    metadata.append({
        'image_metadata': {
            'uri': 's3://....'
        },
        'tags': {
            'model_id': 'efficientnet',
            'dataset_id': 'cats_vs_dogs'
        },
        'groundtruth': {
            'class_name': '...'
        }
    })

my_runner = ClassifierRunner(
    model=model, 
    embeddings_layers=['[0].embeddings', '[0].high_embeddings'], # log multiple embeddings layers
    logits_layer='[0].[-1]', # reference your layers by name or indexes
    class_names=['cat', 'dog'],
    metadata=metadata
)

my_runner.run(my_dataset)

my_runner.wait_for_uploads()

print('Metadata generated and uploaded!')
