Adding Class Objects To Pytorch Dataloader: Batch Must Contain Tensors
I have a custom PyTorch dataset that returns a dictionary containing a class object 'queries'.
class QueryDataset(torch.utils.data.Dataset):
    def __init__(self, queries, values
Solution 1:
You need to define your own collate_fn to do this. A rough approach, just to show how it works, would be something like this:
import torch


class DeviceDict:
    """Minimal wrapper around whatever the collate function receives."""

    def __init__(self, data):
        self.data = data

    def print_data(self):
        print(self.data)


class QueryDataset(torch.utils.data.Dataset):
    def __init__(self, queries, values, targets):
        super().__init__()
        self.queries = queries
        self.values = values
        self.targets = targets

    def __len__(self):
        # Use the real length so indexing never runs past the data.
        return len(self.queries)

    def __getitem__(self, idx):
        sample = {"query": self.queries[idx],
                  "values": self.values[idx],
                  "targets": self.targets[idx]}
        return sample


def custom_collate(batch):
    # batch is the list of samples drawn for this batch.
    return DeviceDict(batch)


dt = QueryDataset("q", "v", "t")
dl = torch.utils.data.DataLoader(dt, batch_size=1, collate_fn=custom_collate)
t = next(iter(dl))
t.print_data()
Basically, collate_fn allows you to implement custom batching or to add support for custom data types, as explained in the link I previously provided.
As you can see, this only shows the concept; you need to adapt it to your own needs.
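To make the role of collate_fn more concrete, here is a minimal sketch. It assumes a hypothetical Query class as the non-tensor object and a toy dataset, neither of which comes from the original question. It illustrates that collate_fn simply receives a list of samples, and that the default collation raises a TypeError when it meets such an object.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class Query:
    """Hypothetical stand-in for an arbitrary non-tensor class object."""

    def __init__(self, text):
        self.text = text


class ToyDataset(Dataset):
    """Toy dataset (an assumption for illustration) mixing objects and tensors."""

    def __init__(self):
        self.queries = [Query(f"q{i}") for i in range(4)]
        self.targets = torch.arange(4)

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        return {"query": self.queries[idx], "target": self.targets[idx]}


def identity_collate(batch):
    # `batch` is a plain list holding `batch_size` samples from __getitem__.
    return batch


# The default collate function does not know how to batch Query objects.
try:
    next(iter(DataLoader(ToyDataset(), batch_size=2)))
    default_failed = False
except TypeError:
    default_failed = True

# With a custom collate_fn the samples pass through untouched.
batch = next(iter(DataLoader(ToyDataset(), batch_size=2, collate_fn=identity_collate)))
```

Returning the list unchanged is rarely what you want in practice, but it shows where the per-key batching logic has to go.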
Solution 2:
For those curious, this is the DeviceDict and custom collate function that I used to get things to work.
import torch


class DeviceDict(dict):
    def __init__(self, *args):
        super().__init__(*args)

    def to(self, device):
        # Move tensor values to the device; leave everything else untouched.
        dd = DeviceDict()
        for k, v in self.items():
            if torch.is_tensor(v):
                dd[k] = v.to(device)
            else:
                dd[k] = v
        return dd


def collate_helper(elems, key):
    # Keep the class objects as a plain list; batch the rest as usual.
    if key == "query":
        return elems
    else:
        return torch.utils.data.dataloader.default_collate(elems)


def custom_collate(batch):
    elem = batch[0]
    return DeviceDict({key: collate_helper([d[key] for d in batch], key) for key in elem})
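A possible end-to-end usage is sketched here, assuming a hypothetical Query class and made-up tensor shapes (neither is from the original post). DeviceDict and custom_collate are restated in condensed form only so that the snippet runs on its own.

```python
import torch
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.dataloader import default_collate


class Query:
    """Hypothetical non-tensor object stored under the "query" key."""

    def __init__(self, text):
        self.text = text


class DeviceDict(dict):
    def to(self, device):
        # Only tensors are moved; other values are carried over as they are.
        return DeviceDict(
            {k: v.to(device) if torch.is_tensor(v) else v for k, v in self.items()}
        )


def custom_collate(batch):
    # "query" stays a list of objects; every other key is collated normally.
    return DeviceDict({
        key: [d[key] for d in batch] if key == "query"
        else default_collate([d[key] for d in batch])
        for key in batch[0]
    })


class QueryDataset(Dataset):
    def __init__(self, queries, values, targets):
        self.queries, self.values, self.targets = queries, values, targets

    def __len__(self):
        return len(self.queries)

    def __getitem__(self, idx):
        return {"query": self.queries[idx],
                "values": self.values[idx],
                "targets": self.targets[idx]}


# Illustrative data: three samples, four features each (assumed shapes).
dataset = QueryDataset(
    queries=[Query("a"), Query("b"), Query("c")],
    values=torch.randn(3, 4),
    targets=torch.tensor([0, 1, 0]),
)
loader = DataLoader(dataset, batch_size=2, collate_fn=custom_collate)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = next(iter(loader)).to(device)
```

Here batch["values"] and batch["targets"] are stacked tensors on the chosen device, while batch["query"] remains a Python list of Query objects.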