Description
🐛 Describe the bug
I am using benchmark_with_validation_stream to split my benchmark so that the original train_stream is divided into a new train_stream and a valid_stream. However, training fails as soon as I pass eval_streams=[cl_val_stream] to my strategy's train() call.
Error: AttributeError: 'DatasetExperience' object has no attribute 'benchmark'
When inspecting the streams, it seems that train_stream and valid_stream (more precisely, their experiences, which are DatasetExperience objects according to the error) no longer have the benchmark attribute, whereas test_stream still has it.
I think the issue is that the streams generated by benchmark_with_validation_stream, i.e. train_stream and valid_stream, are EagerCLStream instances, while test_stream is still an NCStream. As a result, train_stream and valid_stream lose the benchmark attribute after calling benchmark_with_validation_stream, while test_stream retains it.
🐜 To Reproduce
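from torch.nn import CrossEntropyLoss
from torch.optim import SGD

from avalanche.benchmarks.classic import SplitMNIST
# The module that defines benchmark_with_validation_stream differs between
# Avalanche versions; importing it from avalanche.benchmarks is assumed here.
from avalanche.benchmarks import benchmark_with_validation_stream
from avalanche.evaluation.metrics import accuracy_metrics, loss_metrics
from avalanche.logging import InteractiveLogger
from avalanche.models import SimpleMLP
from avalanche.training.plugins import EvaluationPlugin
from avalanche.training.supervised import Naive

cl_mnist = SplitMNIST(
    n_experiences=5,
    return_task_id=False,
    seed=42,
    fixed_class_order=[0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
)

# Move 10% of each training experience into a separate validation stream.
cl_mnist_with_val = benchmark_with_validation_stream(
    cl_mnist,
    validation_size=0.1,
    shuffle=True,
    seed=42,
)

cl_train_stream = cl_mnist_with_val.train_stream
cl_test_stream = cl_mnist_with_val.test_stream
cl_val_stream = cl_mnist_with_val.valid_stream

# baseline_eval_plugin is not shown in the original report;
# this minimal definition is an assumption added so the script runs.
baseline_eval_plugin = EvaluationPlugin(
    accuracy_metrics(epoch=True, experience=True, stream=True),
    loss_metrics(epoch=True, experience=True, stream=True),
    loggers=[InteractiveLogger()],
)

baseline_model = SimpleMLP(num_classes=10)
baseline_optimizer = SGD(baseline_model.parameters(), lr=0.001, momentum=0.9)
baseline_criterion = CrossEntropyLoss()

baseline_naive_strategy = Naive(
    model=baseline_model,
    optimizer=baseline_optimizer,
    criterion=baseline_criterion,
    train_mb_size=64,
    train_epochs=5,
    eval_mb_size=64,
    eval_every=0,
    evaluator=baseline_eval_plugin,
)

baseline_results = []
for exp in cl_train_stream:
    # <- The AttributeError is raised here
    res = baseline_naive_strategy.train(exp, eval_streams=[cl_val_stream])
    baseline_results.append(res)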
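In the script above, validation_size=0.1 with shuffle=True moves roughly 10% of the samples of each training experience into valid_stream. The toy function below sketches that kind of split on plain index lists. It is an illustrative assumption, not Avalanche's actual implementation, and split_train_valid is a hypothetical name. A float validation_size is read as a fraction of the samples, as in the report; treating an int as an absolute sample count is an extra assumption of this sketch.

```python
import random
from typing import List, Optional, Tuple, Union


def split_train_valid(
    n_samples: int,
    validation_size: Union[float, int],
    shuffle: bool = True,
    seed: Optional[int] = None,
) -> Tuple[List[int], List[int]]:
    """Toy per-experience split of indices 0..n_samples-1 into (train, valid).

    A float validation_size is a fraction of n_samples; an int is an
    absolute number of samples (the int case is an assumption of this sketch).
    """
    indices = list(range(n_samples))
    if shuffle:
        # Seeded shuffle: the same seed always yields the same split.
        random.Random(seed).shuffle(indices)
    if isinstance(validation_size, float):
        if not 0.0 <= validation_size <= 1.0:
            raise ValueError("a fractional validation_size must be in [0, 1]")
        n_valid = int(validation_size * n_samples)
    else:
        if not 0 <= validation_size <= n_samples:
            raise ValueError("validation_size is larger than the number of samples")
        n_valid = validation_size
    n_train = n_samples - n_valid
    # Disjoint subsets: the first n_train indices train, the remaining ones validate.
    return indices[:n_train], indices[n_train:]
```

In this sketch, applying the function once per training experience gives every training experience a matching validation experience, which corresponds to how the report pairs train_stream with valid_stream.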
🐝 Expected behavior
I should be able to train the model on the training set, validate it during training on the validation set, and run inference on the test set afterwards. However, since cl_val_stream lacks the benchmark attribute, training fails whenever eval_streams=[cl_val_stream] is used.