Created by: stephenroller
Patch Description Some of our consolidation logic is extremely slow. The concats happen on CPU, and PyTorch's default is to parallelize them with OMP threads. However, we (or slurm) have already spawned multiple processes, so all the threads end up fighting and thrashing over the same CPUs.
We could set this individually in every script if we wanted to be careful about optimizing for CPU-only cases, but since we basically only ever run GPU workloads, we can bake this assumption directly into our distributed code.
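A minimal sketch of the idea, assuming the fix is to cap intra-op parallelism in the distributed setup path (the helper name `limit_cpu_threads` and its placement are hypothetical, not the actual patch):

```python
import os

# Cap OMP threading before torch builds its thread pools; each rank is
# already its own process, so per-process threads only cause thrashing.
# (Hypothetical placement: in the real patch this lives in the shared
# distributed-init code, not at module top level.)
os.environ.setdefault("OMP_NUM_THREADS", "1")


def limit_cpu_threads() -> None:
    """Pin this rank's CPU ops (e.g. checkpoint concats) to one thread.

    With N ranks per node, letting each rank use M OMP threads means
    N x M threads contending for the same cores during CPU-side work.
    """
    import torch  # imported lazily so the env var above takes effect

    torch.set_num_threads(1)
```

`torch.set_num_threads` controls intra-op parallelism at runtime, while `OMP_NUM_THREADS` must be set before the first parallel region, which is why both appear here.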
Testing steps Before the change, loading checkpoints pegged all CPUs at 100%. After the change, CPU usage stays low and loading runs faster.