This is a two-part tutorial on the Fairseq BART model. In the first part, I walked through how a Transformer model is built; please refer to part 1 for those details.

A BART class is, in essence, a FairseqTransformer class. The only difference lies in the arguments used to construct the model. Since this part is relatively straightforward, I will postpone diving into its details until the end of this article.

For now, I will discuss how a BART model is actually loaded if you follow the official BART documentation.

PyTorch Hub

The BART model is loaded via PyTorch Hub. PyTorch Hub lets researchers host their models in their own GitHub repositories. To do so, they define a hubconf.py at the root of the repository, which tells torch.hub how to retrieve the model definition and the pretrained weights.

When calling the load method, the user passes in the github and model arguments. The former specifies which GitHub repository to look in; the latter specifies which model to retrieve. More precisely, the model argument names the function defined in hubconf.py that torch.hub should call.
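
The lookup described above can be imitated without any network access. The following is only a minimal sketch of the idea, not torch.hub's actual implementation: `make_toy_hubconf`, `tiny_model` and `toy_hub_load` are hypothetical names, and the "model" is a plain dictionary.

```python
import types


def make_toy_hubconf():
    """Build a stand-in for a repository's hubconf.py (hypothetical module)."""
    hubconf = types.ModuleType("hubconf")

    def tiny_model(pretrained=False):
        # A real entry point would construct a model and optionally load weights.
        return {"name": "tiny_model", "pretrained": pretrained}

    hubconf.tiny_model = tiny_model
    return hubconf


def toy_hub_load(hubconf, model, *args, **kwargs):
    """Look up the entry point called `model` in `hubconf` and call it."""
    entry_point = getattr(hubconf, model, None)
    if not callable(entry_point):
        raise RuntimeError(f"Cannot find callable {model!r} in hubconf")
    return entry_point(*args, **kwargs)


loaded = toy_hub_load(make_toy_hubconf(), "tiny_model", pretrained=True)
print(loaded)  # {'name': 'tiny_model', 'pretrained': True}
```

The point of the sketch is that the model argument is just a string that is resolved to a callable by name; any extra arguments are forwarded to that callable.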

Fairseq’s hubconf.py contains the following routine, which registers the functions that users call to retrieve models:

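import functools

from fairseq.models import MODEL_REGISTRY

# Expose every pretrained model name as a module-level entry point
for _model_type, _cls in MODEL_REGISTRY.items():
    for model_name in _cls.hub_models().keys():
        globals()[model_name] = functools.partial(
            _cls.from_pretrained,
            model_name,
        )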

It loops over MODEL_REGISTRY to retrieve every model type name and model class. For each model class, it calls the hub_models method to get the available pretrained models. hub_models returns a dictionary whose keys are the model names and whose values are the URLs of the pretrained models. Fairseq adds each key to the module's globals as a function name, mapping it to a partial function of from_pretrained with the model name fixed as its first argument, so that calling it loads the matching pretrained model.
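
To see the registration pattern in isolation, here is a self-contained sketch that runs without Fairseq installed. Everything in it is an assumption made for illustration: `ToyModel`, `TOY_REGISTRY`, the placeholder URLs and the `entry_points` dictionary (which stands in for `globals()`) are hypothetical, and the toy `from_pretrained` only resolves a name to a location instead of downloading anything.

```python
import functools


class ToyModel:
    """Hypothetical model class exposing the two hooks the loop relies on."""

    @classmethod
    def hub_models(cls):
        # Maps model names to archive locations (placeholder URLs).
        return {
            "toy.base": "https://example.invalid/toy.base.tar.gz",
            "toy.large": "https://example.invalid/toy.large.tar.gz",
        }

    @classmethod
    def from_pretrained(cls, model_name_or_path, **kwargs):
        # Resolve a known name to its URL; anything else is treated as a path.
        location = cls.hub_models().get(model_name_or_path, model_name_or_path)
        return {"cls": cls.__name__, "location": location, "kwargs": kwargs}


TOY_REGISTRY = {"toy": ToyModel}
entry_points = {}  # stands in for globals() in hubconf.py

for _model_type, _cls in TOY_REGISTRY.items():
    for model_name in _cls.hub_models().keys():
        entry_points[model_name] = functools.partial(_cls.from_pretrained, model_name)

result = entry_points["toy.large"](beam=4)
print(result["location"])  # https://example.invalid/toy.large.tar.gz
```

Note that functools.partial binds `model_name` at the moment the partial is created, so each registered function keeps its own name even though the loop variable keeps changing.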

BARTModel::from_pretrained

BARTModel.from_pretrained actually calls hub_utils.from_pretrained(), which returns a dictionary with three key-value items: ‘args’, ‘task’ and ‘models’. Internally it uses the checkpoint_utils::load_model_ensemble_and_task() method, which is the function that builds the model and loads the state dict into it. Remember that BARTModel is an nn.Module by nature, so calling load_state_dict loads the pretrained weights into the model. The loaded model is then returned and used to construct a BARTHubInterface instance.

A closer look shows that the parameters for building the model are also saved in the checkpoint file. Loading it with the checkpoint_utils::load_checkpoint_to_cpu() method returns a state dictionary in which the args key holds the arguments used to rebuild the model and the model key holds the dictionary of trained weights.
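
The idea of storing the construction arguments next to the weights can be sketched in plain Python. This is not Fairseq's checkpoint format or its loading code: the helper names, the use of pickle and the list-based "weights" are all assumptions chosen so the example runs without PyTorch. Only the two keys, args and model, follow the description above.

```python
import argparse
import io
import pickle


def build_toy_model(args):
    """Create zero-initialised 'weights' whose shapes come from args."""
    return {f"layer{i}.weight": [0.0] * args.embed_dim for i in range(args.layers)}


def save_toy_checkpoint(args, weights):
    """Serialise the arguments and the weights together in one blob."""
    buffer = io.BytesIO()
    pickle.dump({"args": args, "model": weights}, buffer)
    return buffer.getvalue()


def load_toy_checkpoint(blob):
    """Rebuild the model from the saved args, then fill in the saved weights."""
    state = pickle.loads(blob)
    model = build_toy_model(state["args"])
    mismatched = set(model) ^ set(state["model"])
    if mismatched:
        raise KeyError(f"Mismatched keys: {sorted(mismatched)}")
    model.update(state["model"])
    return state["args"], model


train_args = argparse.Namespace(layers=2, embed_dim=3)
trained = {"layer0.weight": [0.1, 0.2, 0.3], "layer1.weight": [0.4, 0.5, 0.6]}
restored_args, restored_model = load_toy_checkpoint(save_toy_checkpoint(train_args, trained))
print(restored_args.layers, restored_model["layer1.weight"])  # 2 [0.4, 0.5, 0.6]
```

Because the arguments travel with the weights, the loader never needs to be told the architecture separately; it can rebuild an empty model of the right shape first and only then copy the weights in.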

An important side note: the task that BART uses is defined in those args. It is the denoising task (DenoisingTask), a class we will look at later.

BARTHubInterface

A BARTHubInterface is ultimately what torch.hub.load() returns. The interface provides a few useful methods that help the user work with the model.

class BARTHubInterface(nn.Module):
    # Save user defined arguments, task and model are setup in from_pretrained
    def __init__(self, args, task, model):...
    @property
    def device(self):...
    # encode a sentence/sentence pair to bpe encoding, returns a long tensor
    def encode(self, sentence: str, *addl_sentences, no_separator=True) -> torch.LongTensor:...
    # decode bpe encodings back to a normal sentence, returns a string
    def decode(self, tokens: torch.LongTensor):...
    # convert input tokens to proper encodings
    def _build_sample(self, src_tokens: List[torch.LongTensor]):...
    # The function that performs summarization task
    # See below discussion
    def sample(self, sentences: List[str], beam: int = 1, verbose: bool = False, **kwargs) -> str:...
    # utilized by sample
    def generate(self, tokens: List[torch.LongTensor], beam: int = 5, verbose: bool = False, **kwargs) -> torch.LongTensor:...
    # BART has a special head used for sentence classification. The default is
    # defined in model.py::BARTClassificationHead; a user-defined head can be added
    def register_classification_head(
        self, name: str, num_classes: int = None, embedding_size: int = None, **kwargs
    ):...
    # Interface used for sentence-level classification. It takes the output of the
    # classification head (the logits) and passes it through a log_softmax.
    def predict(self, head: str, tokens: torch.LongTensor, return_logits: bool = False):...

The sample method is divided into three steps: encode, generate, decode. The generate method calls the build_generator method defined in DenoisingTask, which inherits from FairseqTask. build_generator returns a SequenceGenerator instance, which takes a source token sequence and performs “translation”. A translation involves a lot of special-token handling (BOS, EOS), padding, masks, the model's forward pass and the search over output words.
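
The three-step split can be illustrated with a toy class. This is a sketch under heavy assumptions, not the real BARTHubInterface: `ToyHubInterface` is hypothetical, it tokenizes on whitespace instead of BPE, and its generate step is a placeholder that simply keeps the first few tokens rather than running a SequenceGenerator.

```python
BOS, EOS = "<s>", "</s>"


class ToyHubInterface:
    """Hypothetical stand-in that mirrors the encode/generate/decode split."""

    def __init__(self, vocab):
        self.itos = [BOS, EOS] + sorted(vocab)
        self.stoi = {tok: i for i, tok in enumerate(self.itos)}

    def encode(self, sentence):
        # Wrap the whitespace tokens in BOS/EOS and map them to ids.
        return [self.stoi[t] for t in [BOS, *sentence.split(), EOS]]

    def generate(self, tokens, max_len=3):
        # Placeholder "search": keep the first max_len content tokens.
        content = [t for t in tokens if self.itos[t] not in (BOS, EOS)]
        return [self.stoi[BOS], *content[:max_len], self.stoi[EOS]]

    def decode(self, tokens):
        # Drop the special tokens and join the rest back into a string.
        return " ".join(self.itos[t] for t in tokens if self.itos[t] not in (BOS, EOS))

    def sample(self, sentences, **kwargs):
        # encode -> generate -> decode, once per input sentence
        return [self.decode(self.generate(self.encode(s), **kwargs)) for s in sentences]


toy = ToyHubInterface({"the", "cat", "sat", "on", "mat"})
print(toy.sample(["the cat sat on the mat"]))  # ['the cat sat']
```

Keeping the three steps as separate methods means the generation step can be swapped out (for example for a real beam search) without touching the text handling on either side.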

BARTModel

The BARTModel is a Transformer class. Besides the aforementioned hub_models(), from_pretrained() and register_classification_head() methods and a few other support methods, it defines a (quite trivial) forward method that simply passes the input tokens through the encoder and the decoder. BART can also support sentence classification tasks, so the user may choose whether to use a sentence classification head, or even pass in a custom-defined head.
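
For the classification path, the predict method of the hub interface applies a log_softmax to the logits produced by the head. The following pure-Python sketch shows only that final step; the logits are made up, and the function is a generic numerically stable log-softmax rather than the code Fairseq runs.

```python
import math


def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    peak = max(logits)
    log_sum = peak + math.log(sum(math.exp(x - peak) for x in logits))
    return [x - log_sum for x in logits]


head_logits = [2.0, 0.5, -1.0]  # made-up output of a 3-class head
log_probs = log_softmax(head_logits)
predicted_class = max(range(len(log_probs)), key=log_probs.__getitem__)
print(predicted_class)  # 0
```

Subtracting the largest logit before exponentiating does not change the result, but it keeps math.exp from overflowing when the logits are large.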

Some important BART parameters: it has 12 TransformerEncoderLayer and 12 TransformerDecoderLayer modules, each with 16 attention heads. The input dimension is 1024 and the hidden (feed-forward) dimension is 4096. Neither the encoder nor the decoder uses LayerDrop, i.e. its probability is 0. The dropout used in the attention layers is 0.1. The activation function is GELU. The optimizer is Adam, with parameters adam_betas='(0.9, 0.999)' and adam_eps=1e-06. I have hosted the complete parameter set and model definition here.
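
The numbers above can be collected in one place to check how they relate to each other. The dataclass below is a hypothetical container written for this article, not Fairseq's own configuration object, and its field names are illustrative; only the values come from the text.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class BartParamsSketch:
    """Hyperparameters quoted in the text; not Fairseq's own config object."""

    encoder_layers: int = 12
    decoder_layers: int = 12
    attention_heads: int = 16
    embed_dim: int = 1024
    ffn_embed_dim: int = 4096  # hidden layer size from the text
    layerdrop: float = 0.0
    attention_dropout: float = 0.1
    activation_fn: str = "gelu"
    adam_betas: tuple = (0.9, 0.999)
    adam_eps: float = 1e-06

    @property
    def head_dim(self):
        # Each attention head works on an equal slice of the embedding.
        return self.embed_dim // self.attention_heads


cfg = BartParamsSketch()
print(cfg.head_dim, cfg.ffn_embed_dim // cfg.embed_dim)  # 64 4
```

So with these settings each of the 16 heads operates on 64 dimensions, and the hidden layer is four times as wide as the embedding.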