Description
🐛 Describe the bug
I am working with a multimodal signal dataset composed of time-series data (signals) and structured data (process parameters). The samples are returned as a dict with two keys (`x` and `proc_data`), which is then fed to the model. Training fails as soon as Avalanche tries to move the minibatch to the GPU in `_unpack_minibatch`, because that part of the minibatch is a `dict`, which has no `.to()` method. In this run, the call is reached through the periodic evaluation that runs before training starts. The error is the following:
```
  File "/workspace/example/continual_training.py", line 325, in main
    strategy.train(scenario.train_stream)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/supervised/joint_training.py", line 152, in train
    self._before_training(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base.py", line 326, in _before_training
    trigger_plugins(self, "before_training", **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/utils.py", line 75, in trigger_plugins
    getattr(p, event)(strategy, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base_sgd.py", line 616, in before_training
    self._peval(strategy, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base_sgd.py", line 637, in _peval
    strategy.eval(el, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base_sgd.py", line 228, in eval
    super().eval(exp_list, **kwargs)
  File "/opt/conda/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base.py", line 212, in eval
    self._eval_exp(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base_sgd.py", line 232, in _eval_exp
    self.eval_epoch(**kwargs)
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/base_sgd.py", line 270, in eval_epoch
    self._unpack_minibatch()
  File "/opt/conda/lib/python3.11/site-packages/avalanche/training/templates/problem_type/supervised_problem.py", line 63, in _unpack_minibatch
    mbatch[i] = mbatch[i].to(self.device, non_blocking=True) # type: ignore
                ^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'to'
```
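For context, the failure can be illustrated without Avalanche. The sketch below assumes a plain PyTorch `Dataset` with the structure described above and PyTorch's default collation; the class name `MultimodalSignalDataset`, the shapes and the random data are hypothetical. Avalanche minibatches typically also carry a task label, which is omitted here. It shows that the collated input stays a `dict` of stacked tensors, so calling `.to()` on it fails.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class MultimodalSignalDataset(Dataset):
    """Hypothetical stand-in for the real dataset: each sample's input is a dict."""

    def __init__(self, n_samples=8, n_channels=3, seq_len=100, n_params=4):
        self.signals = torch.randn(n_samples, n_channels, seq_len)  # time series
        self.params = torch.randn(n_samples, n_params)  # process parameters
        self.targets = torch.randint(0, 2, (n_samples,))

    def __len__(self):
        return len(self.targets)

    def __getitem__(self, idx):
        x = {"x": self.signals[idx], "proc_data": self.params[idx]}
        return x, self.targets[idx]


loader = DataLoader(MultimodalSignalDataset(), batch_size=4)
inputs, targets = next(iter(loader))

# Default collation keeps the dict structure and stacks each value.
print(type(inputs))               # <class 'dict'>
print(inputs["x"].shape)          # torch.Size([4, 3, 100])
print(inputs["proc_data"].shape)  # torch.Size([4, 4])

# A plain dict has no .to(), which is what _unpack_minibatch trips on.
print(hasattr(inputs, "to"))      # False
```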
🐜 To Reproduce
The following repository has a minimal example to reproduce the error: https://github.com/spartanjoax/continual_example. It can be run with the following command:
```
python example.py
```
🐝 Expected behavior
I would expect Avalanche to support multimodal inputs, where the model input is a dict of tensors. I would appreciate some guidance on how to modify `_unpack_minibatch` so that it does something like the following:

```python
import torch

# Inside _unpack_minibatch, where mbatch is the current minibatch (a list)
for i in range(len(mbatch)):
    if isinstance(mbatch[i], dict):
        # Move all tensor values in the dictionary to the specified device
        mbatch[i] = {key: value.to(self.device, non_blocking=True) for key, value in mbatch[i].items()}
    elif isinstance(mbatch[i], torch.Tensor):
        # Directly move the tensor to the specified device
        mbatch[i] = mbatch[i].to(self.device, non_blocking=True)
    else:
        # Optionally handle other cases (e.g., raise an error if unsupported types are encountered)
        raise TypeError(f"Unsupported type in batch: {type(mbatch[i])}")
```
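A slightly more general variant of the proposed change is a recursive helper that also handles nested dicts, lists and tuples. This is a minimal sketch; `move_to_device` is a hypothetical name, not an Avalanche function. It uses a duck-typed `hasattr(obj, "to")` check instead of `isinstance(obj, torch.Tensor)`, so it has no hard dependency on torch, and it returns non-tensor leaves (ints, strings, `None`) unchanged instead of raising `TypeError`.

```python
from collections.abc import Mapping


def move_to_device(obj, device, non_blocking=True):
    """Recursively move every object exposing .to() (e.g. torch.Tensor) to `device`.

    Mappings are rebuilt as plain dicts, lists and tuples keep their type,
    and any other leaf without a .to() method is returned unchanged.
    """
    if hasattr(obj, "to"):
        return obj.to(device, non_blocking=non_blocking)
    if isinstance(obj, Mapping):
        return {key: move_to_device(value, device, non_blocking) for key, value in obj.items()}
    if isinstance(obj, list):
        return [move_to_device(item, device, non_blocking) for item in obj]
    if isinstance(obj, tuple):
        moved = [move_to_device(item, device, non_blocking) for item in obj]
        # namedtuples take positional fields; plain tuples take an iterable
        return type(obj)(*moved) if hasattr(obj, "_fields") else tuple(moved)
    return obj
```

To use it without editing the installed package, one option (assuming `_unpack_minibatch` remains an overridable method on the strategy, as the traceback suggests) is to subclass the strategy in use, e.g. `JointTraining`, and override `_unpack_minibatch` so it applies `move_to_device` to each element of `self.mbatch`. Other code paths or plugins that assume the input is a single tensor (for example, when reading the batch size) may still need changes.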