GPU 0: holds part of every param (first 1/8th)
    minibatch 0 input tokens
    reconstruct first param and do forward w/ it (activations)
    repeat for next params
    reconstruct last param and do forward w/ it
    minibatch 0 output

GPU 1: holds part of every param (second 1/8th)
    minibatch 1 input tokens
    reconstruct first param and do forward w/ it
    repeat for next params
    reconstruct last param and do forward w/ it
    minibatch 1 output

···

GPU 7: holds part of every param (last 1/8th)
    minibatch 7 input tokens
    reconstruct first param and do forward w/ it
    repeat for next params
    reconstruct last param and do forward w/ it
    minibatch 7 output
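
The per-GPU steps above can be sketched as a toy, single-process simulation. This is only an illustration under stated assumptions: there are no real GPUs or communication calls here, every param is assumed to be a small square weight matrix stored as a flat list, and the names `shard`, `all_gather`, `linear`, and `sharded_forward` are hypothetical helpers rather than any library's API. "Reconstruct" is modeled as concatenating all eight shards, which is what an all-gather collective would do across devices.

```python
# Toy simulation of a sharded forward pass (assumption: plain Python lists
# stand in for GPU memory; no real devices or collectives are involved).
NUM_GPUS = 8
DIM = 4  # each param is a DIM x DIM weight, stored row-major in a flat list


def shard(flat_param, num_shards=NUM_GPUS):
    """Split a flat param into equal contiguous pieces, one per GPU."""
    assert len(flat_param) % num_shards == 0
    size = len(flat_param) // num_shards
    return [flat_param[i * size:(i + 1) * size] for i in range(num_shards)]


def all_gather(shards):
    """Reconstruct the full param by concatenating every GPU's shard."""
    return [value for piece in shards for value in piece]


def linear(weight_flat, x):
    """Compute relu(W @ x) for a square W stored row-major in weight_flat."""
    n = len(x)
    return [
        max(0.0, sum(weight_flat[r * n + c] * x[c] for c in range(n)))
        for r in range(n)
    ]


def sharded_forward(sharded_params, minibatches):
    """Run each GPU's own minibatch through all params, first to last."""
    outputs = []
    for x in minibatches:                # one minibatch per GPU
        act = x
        for shards in sharded_params:    # first param ... last param
            full = all_gather(shards)    # reconstruct this param
            act = linear(full, act)      # do forward with it
            del full                     # drop the full copy afterwards
        outputs.append(act)
    return outputs


def make_weight(k):
    """Deterministic toy weight for layer k with entries in {-1, 0, 1}."""
    return [float((i + k) % 3 - 1) for i in range(DIM * DIM)]


full_params = [make_weight(k) for k in range(3)]
sharded_params = [shard(w) for w in full_params]  # GPU i keeps shards[i]
minibatches = [[float(i + j) for j in range(DIM)] for i in range(NUM_GPUS)]

outputs = sharded_forward(sharded_params, minibatches)

# Reference: the same forward pass with unsharded params.
reference = []
for x in minibatches:
    act = x
    for w in full_params:
        act = linear(w, act)
    reference.append(act)

print(outputs == reference)
```

In this sketch each GPU permanently stores only its own 1/8th of every param; the full param exists only for the duration of one layer's forward step, which is the memory saving the layout above is meant to convey.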