Dioptra Documentation

PyTorch

Inference Runner

To easily capture model metadata, you can use our inference runners. These runners will run a dataset through the model and log the metadata katiML needs, without having to call upload_to_lake() or a REST API.

from dioptra.inference.torch.torch_runner import TorchInferenceRunner
class TorchInferenceRunner(          
    model: Model,
    model_type: str,
    model_name: str,
    embeddings_layers: Optional[List[str]],
    logits_layer: Optional[str],
    datapoint_ids: Optional[List[str]],
    datapoints_metadata: Optional[List[object]],
    dataset_metadata: Optional[object],
    data_transform: Optional[transforms],
    mc_dropout_samples: Optional[int],
    device: Optional[str],
    class_names: List[str]
)
Arguments and descriptions

  • model: a PyTorch model.
  • model_type: the type of model to be used, CLASSIFICATION or SEGMENTATION.
  • model_name: the name of the model. It should include a version number if the model is run more than once.
  • embeddings_layers: a list of names of embeddings layers. It can combine names and indexes, for example [0].embeddings.
  • logits_layer: the name of the logits layer.
  • datapoint_ids: a list of datapoint ids to be updated. The list should be in the same order as the dataset.
  • datapoints_metadata: a list of metadata to be added to each datapoint. The list should be in the same order as the dataset.
  • dataset_metadata: a metadata object to be added to all datapoints.
  • data_transform: an optional transform method to be applied to each batch of the dataset. It can be useful for removing the groundtruth from the dataset iterator.
  • mc_dropout_samples: how many Monte Carlo dropout samples to take. If > 0, the run is a Monte Carlo dropout experiment and the model is put in train mode before running inference.
  • device: a string that controls where the inference runs, cpu or cuda.
  • class_names: a list of class names corresponding to the logits indexes.
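
The mc_dropout_samples argument relies on dropout staying active at inference time, which is why the model is put in train mode. The sketch below illustrates the idea in plain Python with a toy linear model; it is not the runner's implementation, and dropout_forward and mc_dropout_predict are hypothetical names introduced only for this illustration.

```python
import random
from statistics import mean, pvariance


def dropout_forward(features, weights, drop_p, training, rng):
    """Toy linear 'model' with dropout applied to its inputs.

    In training mode each input is zeroed with probability drop_p and the
    survivors are rescaled by 1 / (1 - drop_p) (inverted dropout).
    In eval mode dropout is a no-op, so the output is deterministic.
    """
    total = 0.0
    for x, w in zip(features, weights):
        if training:
            if rng.random() < drop_p:
                continue  # this input is dropped for this pass
            x = x / (1.0 - drop_p)
        total += x * w
    return total


def mc_dropout_predict(features, weights, samples, drop_p=0.5, seed=0):
    """Run stochastic forward passes and summarise them as (mean, variance).

    samples > 0 switches the toy model to training mode, mirroring the
    behaviour described for mc_dropout_samples. samples == 0 runs a single
    deterministic pass.
    """
    rng = random.Random(seed)
    training = samples > 0
    passes = max(samples, 1)
    outputs = [
        dropout_forward(features, weights, drop_p, training, rng)
        for _ in range(passes)
    ]
    return mean(outputs), pvariance(outputs)


features = [0.5, 1.0, -1.5, 2.0, 0.25, -0.75, 1.25, 3.0]
weights = [0.1, -0.2, 0.3, 0.4, -0.5, 0.6, 0.7, -0.8]

eval_mean, eval_var = mc_dropout_predict(features, weights, samples=0)
mc_mean, mc_var = mc_dropout_predict(features, weights, samples=20)

print(f"eval mode:  mean={eval_mean:.3f} variance={eval_var:.3f}")
print(f"mc dropout: mean={mc_mean:.3f} variance={mc_var:.3f}")
```

The spread across the stochastic passes is what a Monte Carlo dropout experiment treats as an uncertainty signal: a single deterministic pass has no spread at all.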

def run(
    self,
    dataset: Dataset
)
Arguments and descriptions

  • dataset: an iterable dataset to run the model inference on. The iterator should return only the features, not the groundtruth; pass the groundtruth as metadata instead. The dataset should not be shuffled if it is used with metadata.
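
Because the dataset must yield features only and the metadata must follow the same order, it can help to split a labeled dataset in one pass. The following is a minimal sketch in plain Python. The helper split_features_and_metadata and the groundtruth_label key are hypothetical placeholders, not part of the katiML schema; only the tags / datapoint_id layout is taken from the examples on this page.

```python
import uuid


def split_features_and_metadata(samples):
    """Split (features, label) pairs into two aligned lists.

    Returns the features in their original order, plus one metadata dict
    per sample stored at the same index.
    """
    features, metadata = [], []
    for sample_features, label in samples:
        features.append(sample_features)
        metadata.append({
            # Hypothetical key: check the supported fields for the real
            # groundtruth format expected by katiML.
            'groundtruth_label': label,
            'tags': {'datapoint_id': str(uuid.uuid4())},
        })
    return features, metadata


labeled_samples = [
    ([0.1, 0.2], 'cat'),
    ([0.3, 0.4], 'dog'),
    ([0.5, 0.6], 'cat'),
]

features, metadata = split_features_and_metadata(labeled_samples)

# Index i of `metadata` describes index i of `features`, so neither list
# may be shuffled independently of the other.
for i, (x, meta) in enumerate(zip(features, metadata)):
    print(i, x, meta['groundtruth_label'])
```

Shuffling the feature iterator after this split would silently attach each metadata entry to the wrong datapoint, which is why the dataset should keep its order.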

def wait_for_uploads(self) -> List[object]
# Waits for all metadata generated during inference to be uploaded to Dioptra.
# Returns the list of uploads generated by the runner during the inference.

Example usage

import os

os.environ['DIOPTRA_API_KEY'] = 'my_api_key'
os.environ['DIOPTRA_UPLOAD_BUCKET'] = 'my_upload_bucket'

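from dioptra.inference.torch.torch_runner import TorchInferenceRunner

my_model = ...  # a PyTorch model
my_dataset = ...  # some dataset that we can iterate over
my_classes = ...  # class names, in the same order as the logits indexes

my_runner = TorchInferenceRunner(
    model=my_model,
    model_type='CLASSIFICATION',
    model_name='my_model_v1',  # placeholder name; include a version number
    embeddings_layers=['embeddings'],
    logits_layer='logits',
    class_names=my_classes
)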

my_runner.run(my_dataset)

my_runner.wait_for_uploads()

print('Metadata generated and uploaded!')

You can also send metadata such as groundtruth, image URIs, or tags. Here is a complete example.

import os

os.environ['DIOPTRA_API_KEY'] = 'my_api_key'
os.environ['DIOPTRA_UPLOAD_BUCKET'] = 'my_upload_bucket'

from datasets import load_dataset
from torchvision import transforms
import uuid
from dioptra.inference.torch.torch_runner import TorchInferenceRunner

preprocess = transforms.Compose([
    transforms.Lambda(lambda x:x.numpy()),
    transforms.ToPILImage(),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def my_transform(batch):
    processed = preprocess(batch['image']).unsqueeze(0)
    return processed

dataset = load_dataset('cats_vs_dogs').with_format('torch')['train'].select(range(10))

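model = ...  # a PyTorch model exposing 'layer4' and 'fc' layers, e.g. a ResNet
categories = ...  # class names, in the same order as the logits indexes

# One metadata entry per datapoint, in the same order as the dataset
metadata = [{
    'tags': {
        'datapoint_id': str(uuid.uuid4()),
        'model_id': 'test_model_torch'
    }
} for _ in range(10)]

my_runner = TorchInferenceRunner(
    model=model,
    model_type='CLASSIFICATION',
    model_name='test_model_torch',
    embeddings_layers=['layer4'],
    logits_layer='fc',
    class_names=categories,
    data_transform=my_transform,
    device='cuda',
    datapoints_metadata=metadata
)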

my_runner.run(dataset)

my_runner.wait_for_uploads()

print('Metadata generated and uploaded!')

Object Store Dataset

To easily use katiML in your training pipeline, we provide a wrapper around torch.utils.data.Dataset. This dataset can download object store images into a local cache to optimize loading time.

For SEGMENTATION, additional preprocessing decompresses and loads the mask under groundtruths[0].encoded_segmentation_class_mask.

The dataset can prefetch images with a multiprocessing thread pool through the prefetch_images method, or it can be wrapped in a torch.utils.data.DataLoader with several workers to stream the data to the notebook and still feed the GPU.

from dioptra.lake.torch.object_store_datasets import ImageDataset
class ImageDataset(
    dataframe: DataFrame,
    transform: torchvision.transforms
)
Arguments and descriptions

  • dataframe: the data frame to be used as a data source. It can come from download_from_lake() or dataset.download().
  • transform: a torchvision.transforms method to be applied to each item when iterating over the dataset.

def prefetch_images(
    self,
    num_workers: int = 1
)
Arguments and descriptions

  • num_workers: the number of workers to be used to prefetch the data. Default is 1.
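
To give a feel for what prefetching into a local cache with several workers involves, here is a minimal sketch using only the Python standard library. It is not the ImageDataset implementation: cache_path, fetch_to_cache, prefetch, and fake_download are hypothetical names, the hash-based file naming is an assumption, and a real download would read from the object store instead of returning fake bytes.

```python
import hashlib
import os
import tempfile
from concurrent.futures import ThreadPoolExecutor


def cache_path(uri, cache_dir):
    """Map a URI to a stable file name in cache_dir (assumed naming scheme)."""
    digest = hashlib.sha256(uri.encode('utf-8')).hexdigest()
    return os.path.join(cache_dir, digest)


def fetch_to_cache(uri, cache_dir, download):
    """Download a URI once; later calls reuse the file already on disk."""
    path = cache_path(uri, cache_dir)
    if not os.path.exists(path):
        data = download(uri)
        with open(path, 'wb') as f:
            f.write(data)
    return path


def prefetch(uris, cache_dir, download, num_workers=1):
    """Fill the cache using num_workers threads; returns paths in input order."""
    os.makedirs(cache_dir, exist_ok=True)
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(
            lambda uri: fetch_to_cache(uri, cache_dir, download), uris))


def fake_download(uri):
    # Stand-in for a real object store read.
    return f'bytes of {uri}'.encode('utf-8')


uris = [f's3://my-bucket/image_{i}.jpg' for i in range(5)]

with tempfile.TemporaryDirectory() as cache_dir:
    paths = prefetch(uris, cache_dir, fake_download, num_workers=4)
    cached_files = sorted(os.listdir(cache_dir))
    with open(paths[0], 'rb') as f:
        first_bytes = f.read()

print(len(paths), 'files cached')
```

Threads suit this kind of work because downloads are I/O bound. Once the cache is warm, iterating over the dataset only touches local files.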

Example usage

from dioptra.lake.utils import select_datapoints
from dioptra.lake.torch.object_store_datasets import ImageDataset

unlabeled_df = select_datapoints(filters=[{
    'left': 'tags.data_split',
    'op': '=',
    'right': 'train'
}], fields=['image_metadata.uri', 'tags.datapoint_id', 'request_id'])

unlabeled_dataset = ImageDataset(unlabeled_df)

Example in a pipeline

import torch
from torchvision import transforms
from torch.utils.data import DataLoader

from dioptra.inference.torch.torch_runner import TorchInferenceRunner
from dioptra.lake.utils import select_datapoints
from dioptra.lake.torch.object_store_datasets import ImageDataset

unlabeled_df = select_datapoints(filters=[{
    'left': 'tags.data_split',
    'op': '=',
    'right': 'train'
}], fields=['image_metadata.uri', 'tags.datapoint_id', 'request_id'])

unlabeled_dataset = ImageDataset(unlabeled_df)

transform_pipe = transforms.Compose([
    transforms.Lambda(lambda x: x.convert('RGB')),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1) if x.shape[0] == 1 else x),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

def transform(row):
    return transform_pipe(row['image'])
    
unlabeled_dataset.transform = transform
unlabeled_dataset.load_images = True
unlabeled_dataset.prefetch_images(20)

data_loader = DataLoader(
    unlabeled_dataset, batch_size=10, num_workers=4, shuffle=False)

torch_model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet18', pretrained=True)
torch_model.to('cuda')

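first_run_metadata = ...  # one metadata entry per datapoint, in dataset order

my_runner = TorchInferenceRunner(
    model=torch_model,
    model_type='CLASSIFICATION',
    model_name='resnet18',
    embeddings_layers=['layer4'],
    device='cuda',
    datapoints_metadata=first_run_metadata
)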

my_runner.run(data_loader)

If self.use_caching is True, the dataset caches each image in the directory defined by DIOPTRA_CACHE_DIR (default is ~/.dioptra), using the path stored in the DataFrame under metadata.uri. The images are loaded using smart_open, so the proper credentials should be configured. The resulting image is a PIL image stored in the image field of the returned row.

Last updated 1 year ago