Created by: klshuster
Patch Description
Pytorch would not broadcast the src_tokens
correctly to tokens
when utilizing beam search with beam_size > 1.
Testing steps Tested with sequence generation and beam search. I included a test but there are several issues going on with testing, and figured i'd just include this PR to help anyone else if they've experienced this
BEFORE, with a beam size of 5
2022-06-16 19:46:05 | INFO | metaseq.hub_utils | Executing generation on input tensor size torch.Size([2, 121])
.
.
.
File "metaseq_public/metaseq/sequence_generator.py", line 93, in generate
return self._generate(sample, **kwargs)
File "metaseq_public/metaseq/sequence_generator.py", line 218, in _generate
tokens[:, :start_step] = src_tokens
RuntimeError: The expanded size of the tensor (10) must match the existing size (2) at non-singleton dimension 0. Target sizes: [10, 121]. Tensor sizes: [2, 121]
AFTER It works