PyTorch
Inference Runner
To easily capture model metadata, you can use our inference runners. These runners run a dataset through the model and log the metadata katiML needs, without having to call upload_to_lake() or a REST API. A runner is configured with the following parameters:
- a PyTorch model
- the type of model to be used: CLASSIFICATION or SEGMENTATION
- the name of the model. It should include a version number if run more than once
- a list of names of embedding layers. It can combine names and indexes, for example: [0].embeddings
- the name of the logits layer
- a list of datapoint ids to be updated. The list should be in the same order as the dataset
- a list of metadata to be added to each datapoint. The list should be in the same order as the dataset
- a metadata object to be added to all datapoints
- an optional transform method to be applied to each batch of the dataset. Useful, for example, to remove the groundtruth from the dataset iterator
- how many Monte Carlo dropout samples to take. If > 0, this is a Monte Carlo dropout experiment and the model will be put in train mode before running inference
- a string controlling where to run the inference: cpu or cuda
- a list of class names corresponding to the logits indexes
- an iterable dataset to run the model inference on. The iterator should only return the features, not the groundtruth; groundtruth should be passed as metadata. The dataset should not be shuffled if used with metadata
Example usage
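Below is a minimal sketch of calling a runner. The import path, the class name `InferenceRunner`, and the keyword names are assumptions mirroring the parameter list above, not the confirmed katiML API; check the reference for the exact names.

```python
import torch
import torchvision

# Hypothetical import path and class name -- check the katiML reference.
from katiml.inference.torch import InferenceRunner

model = torchvision.models.resnet18(weights="DEFAULT")

# A toy dataset: the iterator must yield features only, no groundtruth.
dataset = [torch.randn(3, 224, 224) for _ in range(8)]

runner = InferenceRunner(
    model=model,
    model_type="CLASSIFICATION",
    model_name="resnet18-v1",        # include a version if run more than once
    embeddings_layers=["avgpool"],   # can mix layer names and indexes
    logits_layer="fc",
    class_names=[f"class_{i}" for i in range(1000)],
    device="cuda" if torch.cuda.is_available() else "cpu",
    mc_dropout_samples=0,            # > 0 turns this into an MC dropout run
)
runner.run(dataset)  # runs inference and logs the metadata katiML needs
```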
You can also send metadata like groundtruth, image URIs, or tags. Here is a complete example.
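The sketch below extends the previous one with per-datapoint and dataset-wide metadata. The keyword names (`datapoints_metadata`, `dataset_metadata`) and the metadata structure are assumptions based on the parameter descriptions above.

```python
# One metadata entry per datapoint, in the same order as the dataset.
# Field names here are illustrative, not the confirmed katiML schema.
metadata = [
    {
        "groundtruths": [{"class_name": "cat"}],   # groundtruth goes in metadata,
        "uri": "s3://my-bucket/images/0.png",      # not in the dataset iterator
        "tags": {"split": "test"},
    }
    for _ in range(len(dataset))
]

runner = InferenceRunner(
    model=model,
    model_type="CLASSIFICATION",
    model_name="resnet18-v1",
    embeddings_layers=["avgpool"],
    logits_layer="fc",
    class_names=[f"class_{i}" for i in range(1000)],
    datapoints_metadata=metadata,                 # hypothetical keyword name
    dataset_metadata={"dataset": "my-test-set"},  # added to every datapoint
    device="cpu",
)
runner.run(dataset)  # do not shuffle the dataset when metadata is used
```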
Object Store Dataset
To easily use katiML in your training pipeline, we provide a wrapper around torch.utils.data.Dataset. This dataset can download object store images into a local cache to optimize loading time. Additional pre-processing also happens for SEGMENTATION models, to decompress and load the mask under groundtruths[0].encoded_segmentation_class_mask.
The dataset can prefetch images using a multiprocessing thread pool via the prefetch_images method, or it can be wrapped in a torch.utils.data.DataLoader with several workers to stream the data to the notebook and still feed the GPU. The dataset takes the following parameters:
- the data frame to be used as a data source. It can come from download_from_lake() or dataset.download()
- a torchvision.transform method to be applied to each item when iterating over the dataset
- the number of workers used to prefetch the data. Default is 1
Example usage
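A minimal sketch of constructing the dataset. The class name `ObjectStoreDataset` and both import paths are assumptions; `download_from_lake()` is the source function mentioned above.

```python
import torchvision.transforms as T

# Hypothetical class name and import paths -- check the katiML reference.
from katiml.torch import ObjectStoreDataset
from katiml import download_from_lake

df = download_from_lake(...)  # or dataset.download()

ds = ObjectStoreDataset(
    df,
    transform=T.Compose([T.Resize((224, 224)), T.ToTensor()]),
    num_workers=4,  # hypothetical name for the prefetch worker count
)

ds.prefetch_images()  # warm the local cache with the thread pool
image = ds[0]         # subsequent accesses load from the local cache
```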
Example in a pipeline
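And a sketch of wrapping the dataset in a DataLoader so several workers stream and decode images while the GPU stays busy, assuming the transform yields image tensors:

```python
import torch
from torch.utils.data import DataLoader

loader = DataLoader(ds, batch_size=32, num_workers=4, pin_memory=True)

device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()

with torch.no_grad():
    for batch in loader:
        preds = model(batch.to(device))  # workers keep the GPU fed
```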