Adding Class Objects To Pytorch Dataloader: Batch Must Contain Tensors

I have a custom PyTorch dataset that returns a dictionary containing a class object 'queries'.

class QueryDataset(torch.utils.data.Dataset):
    def __init__(self, queries, values

Solution 1:

You need to define your own collate_fn to do this. A rough sketch, just to show how it works, would be something like this:

import torch
class DeviceDict:
    def __init__(self, data):
        self.data = data

    def print_data(self):
        print(self.data)

class QueryDataset(torch.utils.data.Dataset):

    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        # Report the real number of samples instead of a hard-coded 5
        return len(self.queries)

    def __getitem__(self, idx):
        sample = {'query': self.queries[idx],
                  'values': self.values[idx],
                  'targets': self.targets[idx]}
        return sample

def custom_collate(batch):
    # batch is the list of samples returned by __getitem__
    return DeviceDict(batch)

dt = QueryDataset("q","v","t")
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()

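Basically, collate_fn lets you implement custom batching or add support for custom data types, as explained in the link I previously provided. The DataLoader passes it the list of samples drawn for one batch, and whatever it returns becomes the batch. As you can see, the code above only shows the concept; you need to adapt it to your own needs.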
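To make the failure mode concrete, here is a minimal sketch of why the default collate function rejects custom objects and what a collate_fn actually receives. The Query class, the sample data, and the list_collate name are hypothetical, made up for illustration, and are not from the original question.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical stand-in for the custom class object stored under 'query'."""

    def __init__(self, text):
        self.text = text


# A plain list of dict samples can act as a map-style dataset,
# since it provides __len__ and __getitem__.
samples = [{"query": Query(f"q{i}"), "targets": torch.tensor(float(i))} for i in range(4)]

# The default collate function does not know how to batch Query objects.
try:
    default_collate(samples[:2])
    default_error = None
except TypeError as exc:
    default_error = str(exc)


def list_collate(batch):
    # `batch` is a plain list holding the samples drawn for this step;
    # returning it unchanged skips tensor stacking entirely.
    return batch


loader = DataLoader(samples, batch_size=2, collate_fn=list_collate)
first_batch = next(iter(loader))

print(default_error)
print(type(first_batch), len(first_batch))
```

Returning the list unchanged is the crudest possible collate_fn; in practice you would usually still stack the tensor fields and only pass the custom objects through untouched.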

Solution 2:

For those curious, this is the DeviceDict and custom collate function that I used to get things to work.

class DeviceDict(dict):

    def __init__(self, *args):
        super(DeviceDict, self).__init__(*args)

    def to(self, device):
        dd = DeviceDict()
        for k, v in self.items():
            if torch.is_tensor(v):
                dd[k] = v.to(device)
            else:
                dd[k] = v
        return dd


def collate_helper(elems, key):
    if key == "query":
        return elems
    else:
        return torch.utils.data.dataloader.default_collate(elems)


def custom_collate(batch):
    elem = batch[0]
    return DeviceDict({key: collate_helper([d[key] for d in batch], key) for key in elem})
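For completeness, here is a sketch of how this might be wired into a DataLoader and moved to a device. It restates DeviceDict and custom_collate in compact form so that it runs on its own; the Query class and the sample data are hypothetical and only there for illustration.

```python
import torch
from torch.utils.data import DataLoader
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical custom object; not part of the original answer."""

    def __init__(self, text):
        self.text = text


class DeviceDict(dict):
    # Same idea as the class above: move tensors, pass everything else through.
    def to(self, device):
        return DeviceDict(
            {k: v.to(device) if torch.is_tensor(v) else v for k, v in self.items()}
        )


def custom_collate(batch):
    # Keep "query" objects as a list; let default_collate stack the other fields.
    return DeviceDict({
        key: [d[key] for d in batch] if key == "query"
        else default_collate([d[key] for d in batch])
        for key in batch[0]
    })


# Hypothetical samples: one custom object plus two tensor fields each.
samples = [
    {"query": Query(f"q{i}"), "values": torch.ones(3) * i, "targets": torch.tensor(i)}
    for i in range(4)
]
loader = DataLoader(samples, batch_size=2, collate_fn=custom_collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = next(iter(loader)).to(device)

print(batch["values"].shape)             # torch.Size([2, 3])
print(batch["targets"].shape)            # torch.Size([2])
print([q.text for q in batch["query"]])  # ['q0', 'q1']
```

The tensor fields come back stacked along a new batch dimension, while the "query" entry stays a plain Python list of the original objects, so .to(device) only touches the tensors.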
