The [VisionEncoderDecoderModel] can be used to initialize an image-to-text model with any
pretrained Transformer-based vision model as the encoder (e.g. ViT, BEiT, DeiT, Swin)
and any pretrained language model as the decoder (e.g. RoBERTa, GPT2, BERT, DistilBERT).
The effectiveness of initializing image-to-text-sequence models with pretrained checkpoints has been shown in, for example, TrOCR: Transformer-based Optical Character Recognition with Pre-trained Models by Minghao Li, Tengchao Lv, Lei Cui, Yijuan Lu, Dinei Florencio, Cha Zhang, Zhoujun Li, Furu Wei.
After such a [VisionEncoderDecoderModel] has been trained/fine-tuned, it can be saved/loaded just like any other model (see the examples below
for more information).
An example application is image captioning, in which the encoder encodes the image and an autoregressive language model then generates
the caption. Another example is optical character recognition. Refer to TrOCR, which is an instance of [VisionEncoderDecoderModel].
[VisionEncoderDecoderModel] can be randomly initialized from an encoder and a decoder config. In the following example, we show how to do this using the default [ViTModel] configuration for the encoder
and the default [BertForCausalLM] configuration for the decoder.
>>> from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel
>>> config_encoder = ViTConfig()
>>> config_decoder = BertConfig()
>>> config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
>>> model = VisionEncoderDecoderModel(config=config)
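For quick experiments it can help to shrink both configs so that a randomly initialized model builds and runs in well under a second. The sketch below is only an illustration: the layer sizes, image size, and token ids are arbitrary assumptions, not recommended settings. It runs a single forward pass to show the shape of the decoder logits.

```python
import torch
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

# Deliberately tiny, illustrative sizes (assumptions, not recommended settings).
config_encoder = ViTConfig(
    hidden_size=32, num_hidden_layers=2, num_attention_heads=2,
    intermediate_size=64, image_size=32, patch_size=8,
)
config_decoder = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=2,
    num_attention_heads=2, intermediate_size=64,
)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(config_encoder, config_decoder)
model = VisionEncoderDecoderModel(config=config)
model.eval()

# from_encoder_decoder_configs switches the decoder config to decoder mode
# and enables cross-attention.
print(config.decoder.is_decoder, config.decoder.add_cross_attention)

# Random stand-ins for one preprocessed image and a short decoder prefix.
pixel_values = torch.randn(1, 3, 32, 32)
decoder_input_ids = torch.tensor([[1, 5, 6, 7]])
with torch.no_grad():
    logits = model(pixel_values=pixel_values, decoder_input_ids=decoder_input_ids).logits

# (batch size, decoder sequence length, decoder vocabulary size)
print(tuple(logits.shape))
```

Because the weights are random, the logits carry no meaning; the point is only that the encoder and decoder configs combine into a working model.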
[VisionEncoderDecoderModel] can be initialized from a pretrained encoder checkpoint and a pretrained decoder checkpoint. Any pretrained Transformer-based vision model, e.g. Swin, can serve as the encoder. The decoder can be a pretrained auto-encoding model, e.g. BERT, a pretrained causal language model, e.g. GPT2, or the pretrained decoder part of a sequence-to-sequence model, e.g. the decoder of BART.
Depending on which architecture you choose as the decoder, the cross-attention layers might be randomly initialized.
Initializing [VisionEncoderDecoderModel] from a pretrained encoder and decoder checkpoint requires the model to be fine-tuned on a downstream task, as has been shown in the Warm-starting-encoder-decoder blog post.
To do so, the VisionEncoderDecoderModel class provides a [VisionEncoderDecoderModel.from_encoder_decoder_pretrained] method.
>>> from transformers import VisionEncoderDecoderModel
>>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "microsoft/swin-base-patch4-window7-224-in22k", "google-bert/bert-base-uncased"
... )
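A standalone BERT checkpoint contains no cross-attention weights, which is why those layers start from random values when BERT is used as the decoder. The sketch below lists the parameters in question. It assumes a tiny, randomly initialized BERT decoder (illustrative sizes) as a stand-in for a pretrained one so that nothing needs to be downloaded, and it relies on the BERT implementation naming its cross-attention modules "crossattention".

```python
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

# Tiny illustrative configs (assumed sizes) so that no checkpoint is downloaded.
encoder_config = ViTConfig(
    hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
    intermediate_size=64, image_size=32, patch_size=8,
)
decoder_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64,
)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = VisionEncoderDecoderModel(config=config)

# Collect the decoder parameters that belong to cross-attention modules.
cross_attention_names = [
    name for name, _ in model.decoder.named_parameters() if "crossattention" in name
]
print(len(cross_attention_names))
for name in cross_attention_names[:3]:
    print(name)
```

With a real pretrained BERT decoder, these are the parameters that have no counterpart in the checkpoint and therefore need fine-tuning before the model produces useful text.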
To load fine-tuned checkpoints of the VisionEncoderDecoderModel class, [VisionEncoderDecoderModel] provides the from_pretrained(...)
method just like any other model architecture in Transformers.
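The save and load round trip can be sketched without any downloads by using a tiny, randomly initialized model. The config sizes and the temporary directory below are assumptions made for illustration; with a real fine-tuned model you would pass your own directory or Hub repository name instead.

```python
import tempfile

import torch
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

# Tiny illustrative model (assumed sizes) standing in for a fine-tuned one.
encoder_config = ViTConfig(
    hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
    intermediate_size=64, image_size=32, patch_size=8,
)
decoder_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64,
)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = VisionEncoderDecoderModel(config=config)

# Save the whole encoder-decoder model to disk, then load it back.
with tempfile.TemporaryDirectory() as save_dir:
    model.save_pretrained(save_dir)
    reloaded = VisionEncoderDecoderModel.from_pretrained(save_dir)

# Check that every saved weight survived the round trip unchanged.
original_state = model.state_dict()
reloaded_state = reloaded.state_dict()
weights_match = all(
    torch.equal(original_state[name], reloaded_state[name]) for name in original_state
)
print(weights_match)
```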
To perform inference, use the [generate] method, which generates text autoregressively. This method supports various forms of decoding, such as greedy, beam search and multinomial sampling.
>>> import requests
>>> from PIL import Image
>>> from transformers import GPT2TokenizerFast, ViTImageProcessor, VisionEncoderDecoderModel
>>> # load a fine-tuned image captioning model and corresponding tokenizer and image processor
>>> model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> tokenizer = GPT2TokenizerFast.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> image_processor = ViTImageProcessor.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> # let's perform inference on an image
>>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
>>> image = Image.open(requests.get(url, stream=True).raw)
>>> pixel_values = image_processor(image, return_tensors="pt").pixel_values
>>> # autoregressively generate caption (uses greedy decoding by default)
>>> generated_ids = model.generate(pixel_values)
>>> generated_text = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
>>> print(generated_text)
a cat laying on a blanket next to a cat laying on a bed
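The decoding strategies mentioned above are selected through arguments to generate. The sketch below switches between greedy decoding, beam search, and multinomial sampling. It uses a tiny randomly initialized model so that it runs offline, which means the generated ids are meaningless; the special token ids, the toy vocabulary, and the specific values for num_beams, top_k, and max_new_tokens are all illustrative assumptions.

```python
import torch
from transformers import BertConfig, ViTConfig, VisionEncoderDecoderConfig, VisionEncoderDecoderModel

torch.manual_seed(0)

# Tiny illustrative model (assumed sizes); its outputs are random token ids.
encoder_config = ViTConfig(
    hidden_size=32, num_hidden_layers=1, num_attention_heads=2,
    intermediate_size=64, image_size=32, patch_size=8,
)
decoder_config = BertConfig(
    vocab_size=100, hidden_size=32, num_hidden_layers=1,
    num_attention_heads=2, intermediate_size=64,
)
config = VisionEncoderDecoderConfig.from_encoder_decoder_configs(encoder_config, decoder_config)
model = VisionEncoderDecoderModel(config=config)
model.eval()

pixel_values = torch.randn(1, 3, 32, 32)

# Special token ids are assumptions for this toy vocabulary.
shared = dict(decoder_start_token_id=1, pad_token_id=0, eos_token_id=2, max_new_tokens=5)

# Greedy decoding: pick the most likely token at every step.
greedy_ids = model.generate(pixel_values, do_sample=False, num_beams=1, **shared)

# Beam search: keep the 3 best partial sequences at every step.
beam_ids = model.generate(pixel_values, do_sample=False, num_beams=3, **shared)

# Multinomial sampling, restricted to the 10 most likely tokens per step.
sampled_ids = model.generate(pixel_values, do_sample=True, top_k=10, **shared)

for name, ids in [("greedy", greedy_ids), ("beam", beam_ids), ("sampled", sampled_ids)]:
    print(name, ids.tolist())
```

With a fine-tuned checkpoint such as the one in the example above, the same arguments can be passed to generate alongside pixel_values, and the resulting ids decoded with the tokenizer.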
[TFVisionEncoderDecoderModel.from_pretrained] currently doesn't support initializing the model from a
PyTorch checkpoint. Passing from_pt=True to this method will throw an exception. If there are only PyTorch
checkpoints for a particular vision encoder-decoder model, a workaround is:
>>> from transformers import VisionEncoderDecoderModel, TFVisionEncoderDecoderModel
>>> _model = VisionEncoderDecoderModel.from_pretrained("nlpconnect/vit-gpt2-image-captioning")
>>> _model.encoder.save_pretrained("./encoder")
>>> _model.decoder.save_pretrained("./decoder")
>>> model = TFVisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "./encoder", "./decoder", encoder_from_pt=True, decoder_from_pt=True
... )
>>> # This is only for copying some specific attributes of this particular model.
>>> model.config = _model.config
Once the model is created, it can be fine-tuned similarly to BART, T5 or any other encoder-decoder model on a dataset of (image, text) pairs.
As the following example shows, only two inputs are required for the model to compute a loss: pixel_values (which are the
images) and labels (which are the input_ids of the encoded target sequence).
>>> from transformers import ViTImageProcessor, BertTokenizer, VisionEncoderDecoderModel
>>> from datasets import load_dataset
>>> image_processor = ViTImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
>>> tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
>>> model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
... "google/vit-base-patch16-224-in21k", "google-bert/bert-base-uncased"
... )
>>> model.config.decoder_start_token_id = tokenizer.cls_token_id
>>> model.config.pad_token_id = tokenizer.pad_token_id
>>> dataset = load_dataset("huggingface/cats-image")
>>> image = dataset["test"]["image"][0]
>>> pixel_values = image_processor(image, return_tensors="pt").pixel_values
>>> labels = tokenizer(
... "an image of two cats chilling on a couch",
... return_tensors="pt",
... ).input_ids
>>> # the forward function automatically creates the correct decoder_input_ids
>>> loss = model(pixel_values=pixel_values, labels=labels).loss
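The comment about decoder_input_ids refers to shifting the labels one position to the right and prepending the decoder start token. The following plain-Python sketch illustrates that idea on nested lists. It is not the library's code, which operates on tensors and may differ in detail, and the token ids are made up for illustration. The value -100 is the label value that the PyTorch cross-entropy loss ignores by default, so it is commonly used to mask padding positions in labels.

```python
IGNORE_INDEX = -100  # label value skipped by the loss (PyTorch's default ignore_index)


def shift_tokens_right(labels, pad_token_id, decoder_start_token_id):
    """Build decoder inputs by prepending the start token and dropping the last label."""
    shifted = []
    for row in labels:
        new_row = [decoder_start_token_id] + list(row[:-1])
        # Positions masked out of the loss still need a valid token id as input.
        new_row = [pad_token_id if token == IGNORE_INDEX else token for token in new_row]
        shifted.append(new_row)
    return shifted


# Two target sequences with made-up ids; the second one is padded with -100.
labels = [[7, 8, 9, 102], [5, 102, -100, -100]]
decoder_input_ids = shift_tokens_right(labels, pad_token_id=0, decoder_start_token_id=101)
print(decoder_input_ids)  # [[101, 7, 8, 9], [101, 5, 102, 0]]
```

At each position the decoder therefore sees the tokens before the one it has to predict, which is what teacher-forced training requires.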
This model was contributed by nielsr. This model's TensorFlow and Flax versions were contributed by ydshieh.
[[autodoc]] VisionEncoderDecoderConfig
[[autodoc]] VisionEncoderDecoderModel
    - forward
    - from_encoder_decoder_pretrained
[[autodoc]] TFVisionEncoderDecoderModel
    - call
    - from_encoder_decoder_pretrained
[[autodoc]] FlaxVisionEncoderDecoderModel
    - call
    - from_encoder_decoder_pretrained