PyTorch

Inference Runner

To easily capture model metadata, you can use our inference runners. These runners will run a dataset through the model and log the metadata katiML needs, without having to call upload_to_lake() or a REST API.

```python
from dioptra.inference.torch.torch_runner import TorchInferenceRunner

class TorchInferenceRunner(
    model: Model,
    model_type: str,
    model_name: str,
    embeddings_layers: Optional[List[str]],
    logits_layer: Optional[str],
    datapoint_ids: Optional[List[str]],
    datapoints_metadata: Optional[List[object]],
    dataset_metadata: Optional[object],
    data_transform: Optional[transforms],
    mc_dropout_samples: Optional[int],
    device: Optional[str],
    class_names: List[str]
)
```
| Arguments | Description |
| --- | --- |
| model | a PyTorch model |
| model_type | the type of model to be used: CLASSIFICATION or SEGMENTATION |
| model_name | the name of the model. It should include a version number if running more than once |
| embeddings_layers | a list of embeddings layer names. It can combine names and indexes, for example [0].embeddings |
| logits_layer | the name of the logits layer |
| datapoint_ids | a list of datapoint ids to be updated. The list should be in the same order as the dataset |
| datapoints_metadata | a list of metadata to be added to each datapoint. The list should be in the same order as the dataset |
| dataset_metadata | a metadata object to be added to all datapoints |
| data_transform | an optional transform method applied to each batch of the dataset. Useful to remove the groundtruth from the dataset iterator |
| mc_dropout_samples | how many Monte Carlo dropout samples to take. If > 0, this is a Monte Carlo dropout experiment and the model will be put in train mode before running inference |
| device | where to run the inference: cpu or cuda |
| class_names | a list of class names corresponding to the logits indexes |

| Arguments | Description |
| --- | --- |
| dataset | an iterable dataset to run the model inference on. The iterator should only return the features, not the groundtruth; groundtruth should be passed as metadata. It should not be shuffled if used with metadata |

Example usage
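A minimal sketch of a classification run, assuming the runner exposes a run() method that takes the dataset argument documented above; the layer names, class names, and the data_transform lambda here are illustrative and should be adapted to your own model.

```python
import torch
import torchvision
from torch.utils.data import DataLoader, TensorDataset

from dioptra.inference.torch.torch_runner import TorchInferenceRunner

# any torchvision classification model works for this sketch
model = torchvision.models.resnet18(weights=None)

# an unshuffled dataset that yields features only (no groundtruth)
images = torch.randn(8, 3, 224, 224)
dataloader = DataLoader(TensorDataset(images), batch_size=4, shuffle=False)

runner = TorchInferenceRunner(
    model=model,
    model_type='CLASSIFICATION',
    model_name='resnet18-demo:v1',
    embeddings_layers=['avgpool'],          # layer names depend on your architecture
    logits_layer='fc',
    class_names=[f'class_{i}' for i in range(1000)],
    data_transform=lambda batch: batch[0],  # TensorDataset yields tuples; keep only the features
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

runner.run(dataloader)  # assumed entry point; it iterates the dataset and logs to katiML
```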

You can also send metadata like groundtruth, image URIs, or tags. Here is a complete example.
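Continuing from the sketch above, here is a hedged version that attaches per-datapoint and dataset-level metadata; the field names inside the metadata dictionaries (uri, groundtruths, tags, class_name) are assumptions about the katiML ingestion schema and should be checked against your setup.

```python
# one metadata entry per datapoint, in the same order as the dataset;
# the field names below are assumptions about the katiML schema
datapoints_metadata = [
    {
        'uri': f's3://my-bucket/images/{i:04d}.png',
        'groundtruths': [{'task_type': 'CLASSIFICATION', 'class_name': 'cat'}],
        'tags': {'split': 'validation'},
    }
    for i in range(len(images))
]

runner = TorchInferenceRunner(
    model=model,
    model_type='CLASSIFICATION',
    model_name='resnet18-demo:v1',
    embeddings_layers=['avgpool'],
    logits_layer='fc',
    class_names=[f'class_{i}' for i in range(1000)],
    datapoints_metadata=datapoints_metadata,               # per-datapoint metadata
    dataset_metadata={'dataset_name': 'pets-validation'},  # applied to every datapoint
    data_transform=lambda batch: batch[0],
    device='cuda' if torch.cuda.is_available() else 'cpu',
)

runner.run(dataloader)  # assumed entry point, as above
```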

Object Store Dataset

To easily use katiML in your training pipeline, we provide a wrapper around torch.utils.data.Dataset. This dataset can download object store images into a local cache to optimize loading time.

If self.use_caching is True, the dataset will use the cache directory defined by DIOPTRA_CACHE_DIR (default is ~/.dioptra) to cache the images, pulling each one from the path stored in the DataFrame under metadata.uri. The images are loaded using smart_open, so the proper credentials should be configured. The resulting image will be a PIL image stored in the image field of the returned row.

For SEGMENTATION, additional pre-processing also happens to decompress and load the mask under groundtruths[0].encoded_segmentation_class_mask.

The dataset can prefetch images with a multiprocessing thread pool via the prefetch_images method, or it can be wrapped in a torch.utils.data.DataLoader with several workers to stream the data to the notebook and still keep the GPU fed.

| Arguments | Description |
| --- | --- |
| data frame | the DataFrame to be used as a data source. It can come from download_from_lake() or dataset.download() |
| transform | a torchvision transform method applied to each item when iterating over the dataset |

| Arguments | Description |
| --- | --- |
| workers | the number of workers used to prefetch the data. Default is 1 |

Example usage
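A minimal sketch of building the dataset from a downloaded DataFrame. The import paths, the wrapper class name (ImageDataset), the transform keyword, and the prefetch_images argument name are assumptions; check your installed dioptra package for the exact names.

```python
import torchvision.transforms as T

# NOTE: the module paths and class name below are assumptions;
# check your installed dioptra package for the exact imports
from dioptra.lake.utils import download_from_lake
from dioptra.lake.torch.object_store_datasets import ImageDataset

# pull the datapoints to work with as a DataFrame
dataframe = download_from_lake(...)  # placeholder arguments

transform = T.Compose([
    T.Resize((224, 224)),
    T.ToTensor(),
])

image_dataset = ImageDataset(dataframe, transform=transform)

# prefetch the object store images into the local cache (DIOPTRA_CACHE_DIR)
image_dataset.prefetch_images(num_workers=4)  # argument name is an assumption

# each returned row carries the decoded image under its image field
row = image_dataset[0]
```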

Example in a pipeline
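A hedged sketch of streaming the dataset to a GPU model through a torch.utils.data.DataLoader with several workers, assuming each row is dict-like with the transformed tensor stored under image (as described above); the collate function is illustrative.

```python
import torch
import torchvision
from torch.utils.data import DataLoader

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torchvision.models.resnet18(weights=None).to(device).eval()

def collate(rows):
    # stack the transformed image tensors; the image field is where the
    # dataset stores each decoded (and here ToTensor-ed) image
    return torch.stack([row['image'] for row in rows])

# several workers stream the cached object store images and keep the GPU fed
loader = DataLoader(image_dataset, batch_size=32, num_workers=4, collate_fn=collate)

with torch.no_grad():
    for batch in loader:
        logits = model(batch.to(device))
        # ... rest of the pipeline (metrics, logging, upload, etc.)
```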
