Visual Language Model

Large Language Model
Multi-Modality
In this project, I will implement a Visual Language Model (VLM) using ‘pure’ PyTorch. A VLM is a model that can understand visual inputs and generate text based on them.
Author

Yuyang Zhang

1 Vision Transformer

The vision encoder is the classic Vision Transformer (ViT) architecture (Dosovitskiy et al. 2021).

Figure 1: The illustration of the Vision Transformer.

1.1 Patch Embedding

The first component of the ViT is the Patch Embedding layer, which converts the input image into a sequence of patches. Each patch is treated as a token, similar to how words are treated in NLP models. The patches are flattened and linearly projected into an embedding space of dimension D (vit_hidden_dim). We can combine those two steps into a single convolutional layer: a Conv2d whose kernel_size and stride both equal the patch size P applies the same linear map to every non-overlapping P × P patch, so an H × W image yields N = (H / P) × (W / P) patch tokens. After patching the image, we can prepend a learnable class token and add learnable positional embeddings to the sequence. The class token is used for classification tasks, while the positional embeddings tell the model where each patch is located, since self-attention on its own ignores token order.

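import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

# VLMConfig is the project's configuration object (defined elsewhere); it
# provides the vit_*, mp_* and lm_* hyperparameters used below.

class ViTPatchEmbedding(nn.Module):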
    def __init__(self, config: VLMConfig):
        super().__init__()

        self.img_size = config.vit_img_size
        self.patch_size = config.vit_patch_size

        assert (
            self.img_size % self.patch_size == 0
        ), "Image size must be divisible by patch size."
        self.num_patches = (self.img_size // self.patch_size) ** 2
        self.cls_flag = config.vit_cls_flag
        self.embd_dim = config.vit_hidden_dim

        self.conv = nn.Conv2d(
            in_channels=3,
            out_channels=self.embd_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
            padding="valid",
        )

        if self.cls_flag:
            self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embd_dim))  # (1, 1, D)
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, self.num_patches + 1, self.embd_dim)
            )  # (1, N+1, D), broadcast over the batch
        else:
            self.position_embeddings = nn.Parameter(
                torch.zeros(1, self.num_patches, self.embd_dim)
            )  # (1, N, D)

    def forward(self, imgs: torch.Tensor):
        # (B, C, H, W) -> (B, D, H // P, W // P), P = patch size
        x = self.conv(imgs)
        # (B, D, H // P, W // P) -> (B, N, D), N = (H // P) * (W // P)
        x = x.flatten(2).transpose(1, 2)

        if self.cls_flag:
            cls_tokens = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
            x = torch.cat((cls_tokens, x), dim=1)  # (B, N+1, D)

        x = x + self.position_embeddings  # (B, N+1, D) or (B, N, D)
        return x
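To see why a single convolution can replace "flatten, then project linearly", the sketch below (all sizes are hypothetical) builds both routes with the same weights and checks that they agree. torch.Tensor.unfold is used here only to cut the image into non-overlapping P × P patches.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
B, C, H, W, P, D = 2, 3, 32, 32, 8, 64  # hypothetical sizes
imgs = torch.randn(B, C, H, W)

# Route 1: one Conv2d with kernel_size == stride == P (non-overlapping windows)
conv = nn.Conv2d(C, D, kernel_size=P, stride=P)
via_conv = conv(imgs).flatten(2).transpose(1, 2)  # (B, N, D), N = (H // P) * (W // P)

# Route 2: cut out P x P patches, flatten each one, then apply a Linear layer
patches = imgs.unfold(2, P, P).unfold(3, P, P)  # (B, C, H // P, W // P, P, P)
patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, -1, C * P * P)  # (B, N, C*P*P)
linear = nn.Linear(C * P * P, D)
with torch.no_grad():
    linear.weight.copy_(conv.weight.reshape(D, -1))  # reuse the conv kernels as a matrix
    linear.bias.copy_(conv.bias)
via_linear = linear(patches)

print(via_conv.shape)  # torch.Size([2, 16, 64])
print(torch.allclose(via_conv, via_linear, atol=1e-5))
```

The convolution form is simply the more compact (and usually faster) way to express the same per-patch linear projection.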

1.2 Multi-Head Self-Attention

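After we get the tokens from the Patch Embedding layer, we can feed them into the transformer blocks. Each transformer block consists of a Multi-Head Self-Attention (MHSA) layer and a Feed Forward Network (FFN). In the MHSA layer, the tokens are projected into queries, keys, and values, which are split into several heads. Each head computes scaled dot-product attention, softmax(QK^T / sqrt(d_k)) V, over all tokens, so every patch can gather information from every other patch, and different heads can capture different kinds of relationships between patches. The FFN then processes the output of the MHSA layer token by token. The scale_dot_product_attention function below is a fallback for PyTorch versions that do not provide F.scaled_dot_product_attention.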

def scale_dot_product_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    mask: torch.Tensor | None = None,
    dropout: float = 0.0,
):

    d_k = q.shape[-1]

    # Compute the dot product attention scores
    attn_scores = torch.matmul(q, k.transpose(-2, -1)) / (
        d_k**0.5
    )  # Scale by the square root of the dimension
    if mask is not None:
        attn_scores = attn_scores.masked_fill(mask == 0, float("-inf"))
    attn_weights = torch.softmax(attn_scores, dim=-1)
    if dropout > 0.0:
        attn_weights = torch.nn.functional.dropout(
            attn_weights, p=dropout, training=True
        )
    # Compute the attention output
    attn_output = torch.matmul(attn_weights, v)  # Shape: (..., S_q, D_v), e.g. (B, H, T, head_dim)

    return attn_output, attn_weights
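The fallback follows the same mask convention as PyTorch's built-in F.scaled_dot_product_attention with a boolean attn_mask: True (or nonzero) means "may attend". The standalone sketch below restates the formula in a hypothetical reference_attention helper and compares it with the built-in, using a causal mask purely as an example. Inside the ViT no mask is passed, so every patch attends to every other patch.

```python
import math
import torch
import torch.nn.functional as F

def reference_attention(q, k, v, mask=None):
    # softmax(Q K^T / sqrt(d_k)) V; positions where mask == 0 get -inf (no attention)
    scores = q @ k.transpose(-2, -1) / math.sqrt(q.shape[-1])
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
B, H, T, Dh = 2, 4, 6, 8  # hypothetical sizes: batch, heads, tokens, head_dim
q, k, v = (torch.randn(B, H, T, Dh) for _ in range(3))

# Boolean causal mask: True = may attend, False = blocked
causal = torch.tril(torch.ones(T, T, dtype=torch.bool))

ours = reference_attention(q, k, v, mask=causal)
builtin_mask = F.scaled_dot_product_attention(q, k, v, attn_mask=causal)
builtin_causal = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(ours, builtin_mask, atol=1e-5))
print(torch.allclose(ours, builtin_causal, atol=1e-5))
```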



class ViTMultiHeadAttention(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()

        self.n_heads = config.vit_n_heads
        self.embd_dim = config.vit_hidden_dim

        assert (
            self.embd_dim % self.n_heads == 0
        ), "embd_dim must be divisible by num_heads"
        self.head_dim = self.embd_dim // self.n_heads

        self.dropout = config.vit_dropout

        self.qkv_proj = nn.Linear(self.embd_dim, 3 * self.embd_dim)
        self.out_proj = nn.Linear(self.embd_dim, self.embd_dim)

        # Dropout layer
        self.attn_dropout = nn.Dropout(self.dropout)
        self.resid_dropout = nn.Dropout(self.dropout)

        # Use scaled dot product attention
        self.sdpa = hasattr(F, "scaled_dot_product_attention")
        if not self.sdpa:
            print(
                "Warning: Scaled Dot Product Attention not available. Using custom implementation."
            )

    def forward(self, x: torch.Tensor):
        B, T, C = x.size()

        q, k, v = map(
            lambda t: t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2),
            self.qkv_proj(x).chunk(3, dim=-1),
        )

        if self.sdpa:
            y = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0
            )
        else:
            y, _ = scale_dot_product_attention(
                q=q, k=k, v=v, dropout=self.dropout if self.training else 0.0
            )

        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.out_proj(y)

        return self.resid_dropout(y)
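The fused qkv_proj layout can be sanity-checked against PyTorch's nn.MultiheadAttention, which also stores the query, key, and value projections stacked in a single in_proj_weight. The sketch below re-implements the forward pass above (eval mode, no dropout) in a hypothetical fused_qkv_attention function, copies the same weights into nn.MultiheadAttention, and compares the outputs.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def fused_qkv_attention(x, qkv_proj, out_proj, n_heads):
    # Same steps as ViTMultiHeadAttention.forward, without dropout
    B, T, C = x.shape
    q, k, v = (
        t.view(B, T, n_heads, C // n_heads).transpose(1, 2)  # (B, H, T, head_dim)
        for t in qkv_proj(x).chunk(3, dim=-1)
    )
    y = F.scaled_dot_product_attention(q, k, v)  # (B, H, T, head_dim)
    y = y.transpose(1, 2).contiguous().view(B, T, C)  # merge the heads again
    return out_proj(y)

torch.manual_seed(0)
B, T, C, H = 2, 5, 32, 4  # hypothetical sizes
qkv_proj, out_proj = nn.Linear(C, 3 * C), nn.Linear(C, C)

# in_proj_weight is [W_q; W_k; W_v] stacked along rows, which is the same order
# in which .chunk(3, dim=-1) splits the output of qkv_proj.
mha = nn.MultiheadAttention(C, H, batch_first=True).eval()
with torch.no_grad():
    mha.in_proj_weight.copy_(qkv_proj.weight)
    mha.in_proj_bias.copy_(qkv_proj.bias)
    mha.out_proj.weight.copy_(out_proj.weight)
    mha.out_proj.bias.copy_(out_proj.bias)

    x = torch.randn(B, T, C)
    ours = fused_qkv_attention(x, qkv_proj, out_proj, H)
    ref, _ = mha(x, x, x, need_weights=False)

print(torch.allclose(ours, ref, atol=1e-5))
```

One fused projection followed by chunk(3) gives a single large matrix multiply instead of three smaller ones, while remaining mathematically identical to separate Q, K, and V projections.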

1.3 Feed Forward Network

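The Feed Forward Network (FFN) is a simple two-layer fully connected network: fc1 expands each token from vit_hidden_dim to vit_inter_dim, a GELU activation (here its tanh approximation) is applied, and fc2 projects back to vit_hidden_dim. It is applied to each token independently. The residual connection around it is added in the transformer block, not inside the MLP itself.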

class MLP(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()

        self.activation_fn = nn.GELU(approximate="tanh")
        self.fc1 = nn.Linear(config.vit_hidden_dim, config.vit_inter_dim)
        self.fc2 = nn.Linear(config.vit_inter_dim, config.vit_hidden_dim)
        self.dropout = nn.Dropout(config.vit_dropout)

    def forward(self, x: torch.Tensor):
        x = self.fc1(x)
        x = self.activation_fn(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x
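The MLP uses nn.GELU(approximate="tanh"), the tanh-based approximation of GELU whose formula is given in the PyTorch documentation. The short check below writes that formula out in a hypothetical gelu_tanh helper and measures how far the approximation is from the exact, erf-based GELU on a range of inputs.

```python
import math
import torch
import torch.nn as nn

def gelu_tanh(x):
    # 0.5 * x * (1 + tanh(sqrt(2 / pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))

x = torch.linspace(-6.0, 6.0, steps=1001)
approx_builtin = nn.GELU(approximate="tanh")(x)
exact = nn.GELU()(x)  # default: exact GELU, x * Phi(x), computed with erf

max_err = (approx_builtin - exact).abs().max().item()
print(torch.allclose(gelu_tanh(x), approx_builtin, atol=1e-5))
print(max_err)  # maximum deviation of the approximation on [-6, 6]
```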

1.4 Transformer Block

After defining the MHSA and FFN layers, we can combine them into a single transformer block. The block applies layer normalization before the MHSA and FFN layers (pre-norm), and wraps each of them in a residual connection, which helps with training stability.

class ViTBlock(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()

        self.attn = ViTMultiHeadAttention(config)
        self.mlp = MLP(config)
        self.ln1 = nn.LayerNorm(config.vit_hidden_dim, eps=config.vit_ln_eps)
        self.ln2 = nn.LayerNorm(config.vit_hidden_dim, eps=config.vit_ln_eps)

    def forward(self, x: torch.Tensor):
        # Layer normalization and multi-head attention
        x = x + self.attn(self.ln1(x))
        # Layer normalization and MLP
        x = x + self.mlp(self.ln2(x))
        return x
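Written as equations, the forward pass of ViTBlock is the following, where LN_1 and LN_2 correspond to ln1 and ln2. For comparison, the original Transformer applied the normalization after the residual addition (post-norm), i.e. x' = LN(x + MHSA(x)). In the pre-norm form the residual path carries x through the block unchanged, which is commonly credited for more stable training of deep stacks.

```latex
\begin{aligned}
x' &= x + \mathrm{MHSA}\big(\mathrm{LN}_1(x)\big) \\
y  &= x' + \mathrm{FFN}\big(\mathrm{LN}_2(x')\big)
\end{aligned}
```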

1.5 Vision Transformer

Finally, we can combine the Patch Embedding layer and the transformer blocks to create the Vision Transformer: the patch embeddings pass through a stack of vit_n_blocks transformer blocks. If vit_cls_flag is set, only the class token is returned, with shape (B, D), which is the usual input for classification; otherwise the whole token sequence of shape (B, N, D) is returned for further processing, such as the modality projection below.

class ViT(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()

        self.config = config

        self.patch_embedding = ViTPatchEmbedding(config)

        self.cls_flag = config.vit_cls_flag
        self.dropout = nn.Dropout(config.vit_dropout)

        self.blocks = nn.ModuleList(
            [ViTBlock(config) for _ in range(config.vit_n_blocks)]
        )

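        self.layer_norm = nn.LayerNorm(config.vit_hidden_dim, eps=config.vit_ln_eps)

        self.apply(self._init_weights)

    def _init_weights(self, module: nn.Module):
        # Assumed init scheme: N(0, 0.02) for linear/conv weights (the same std
        # as ModalityProjector), zero biases, identity affine for LayerNorm.
        if isinstance(module, (nn.Linear, nn.Conv2d)):
            nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def forward(self, imgs: torch.Tensor):
        # (B, C, H, W) -> (B, N, D), or (B, N+1, D) with the class token
        x = self.patch_embedding(imgs)
        x = self.dropout(x)

        for block in self.blocks:
            x = block(x)

        if self.cls_flag:
            # Keep only the class token and normalize it, as in the ViT paper: (B, D)
            x = self.layer_norm(x[:, 0])
        else:
            # Normalize every token: (B, N, D)
            x = self.layer_norm(x)
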
        return x
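As a rough cross-check of the shapes, a similar architecture can be assembled from PyTorch's built-in nn.TransformerEncoderLayer with norm_first=True (pre-norm). This is a swapped-in sketch, not the implementation above: the built-in layer uses the exact GELU rather than the tanh approximation, and TinyViT and all of its sizes are hypothetical.

```python
import torch
import torch.nn as nn

class TinyViT(nn.Module):
    # Hypothetical sizes, for shape-checking only
    def __init__(self, img_size=32, patch_size=8, dim=64, depth=2, heads=4, mlp_dim=128):
        super().__init__()
        num_patches = (img_size // patch_size) ** 2
        self.patchify = nn.Conv2d(3, dim, kernel_size=patch_size, stride=patch_size)
        self.cls_token = nn.Parameter(torch.zeros(1, 1, dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, dim))
        layer = nn.TransformerEncoderLayer(
            d_model=dim,
            nhead=heads,
            dim_feedforward=mlp_dim,
            activation="gelu",  # exact GELU, not the tanh approximation
            batch_first=True,
            norm_first=True,  # pre-norm, like ViTBlock
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=depth, enable_nested_tensor=False)
        self.norm = nn.LayerNorm(dim)

    def forward(self, imgs: torch.Tensor) -> torch.Tensor:
        x = self.patchify(imgs).flatten(2).transpose(1, 2)  # (B, N, D)
        cls = self.cls_token.expand(x.shape[0], -1, -1)  # (B, 1, D)
        x = torch.cat((cls, x), dim=1) + self.pos_embed  # (B, N+1, D)
        return self.norm(self.encoder(x))  # (B, N+1, D)

model = TinyViT().eval()
with torch.no_grad():
    out = model(torch.randn(2, 3, 32, 32))
print(out.shape)  # torch.Size([2, 17, 64])
```

With a 32 × 32 image and 8 × 8 patches there are (32 / 8)² = 16 patch tokens, plus one class token.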

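So, that's all we need for the Vision Encoder.

1.6 Modality Projection

After we get the output of the Vision Encoder, we need to project it into the language model's embedding space so that the visual tokens match the text embeddings (dimension lm_hidden_dim). This is done with a single linear projection layer. One small trick used here is pixel shuffle (Shi et al. 2016), which reduces the number of visual tokens. Strictly speaking, the operation applied is the inverse of the sub-pixel pixel shuffle from that paper (often called pixel unshuffle or space-to-depth): each sf × sf block of neighboring patch tokens, where sf is mp_pixel_shuffle_factor, is concatenated along the channel dimension into a single token. The sequence therefore becomes sf² times shorter and each token sf² times wider (vit_hidden_dim × sf²), which is exactly the input_dim of the projection.

Figure 2: Illustration of the pixel shuffle operation and un-shuffling process. The pixel shuffle operation rearranges the elements of a tensor to increase the spatial resolution, while the un-shuffle process reverses this operation.
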
class ModalityProjector(nn.Module):
    def __init__(self, config: VLMConfig):
        super().__init__()
        
        self.config = config 
        
        self.input_dim = config.vit_hidden_dim * (config.mp_pixel_shuffle_factor**2)    
        self.output_dim = config.lm_hidden_dim
        self.scale_factor = config.mp_pixel_shuffle_factor
        
        
        self.proj = nn.Linear(self.input_dim, self.output_dim, bias=False)
        self._init_weight()

    def _init_weight(self):
        nn.init.normal_(self.proj.weight, mean=0.0, std=0.02)
    
    def pixel_shuffle(self, x: torch.Tensor) -> torch.Tensor:
        # (B, S, D) -> (B, S / sf**2, D * sf**2): fold every sf x sf block of
        # neighboring patch tokens into one wider token (sf = scale_factor).
        B, S, D = x.size()
        side = int(S ** 0.5)
        assert side * side == S, "Input sequence length must be a perfect square for pixel shuffle."
        assert side % self.scale_factor == 0, "Grid side length must be divisible by the pixel shuffle factor."

        H, W = side, side
        x = x.view(B, H, W, D)  # Convert the flattened sequence into a 2D grid

        x = einops.rearrange(
            x,
            "b (h sf1) (w sf2) d -> b (h w) (d sf1 sf2)",
            sf1=self.scale_factor,
            sf2=self.scale_factor,
        )

        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.pixel_shuffle(x)
        assert x.size(-1) == self.input_dim, f"Input dimension mismatch: expected {self.input_dim}, got {x.size(-1)}"
        
        x = self.proj(x)
        return x
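Because the einops pattern only moves values around, it can be cross-checked against PyTorch's F.pixel_unshuffle, the built-in inverse of the sub-pixel pixel shuffle. The sketch below re-expresses the pattern with plain view/permute/reshape in a hypothetical unshuffle_tokens function (so it runs without einops) and, for a hypothetical 8 × 8 grid of 16-dimensional tokens with a factor of 2, shows 64 tokens becoming 16 tokens of width 64.

```python
import torch
import torch.nn.functional as F

def unshuffle_tokens(x: torch.Tensor, sf: int) -> torch.Tensor:
    # Plain-PyTorch equivalent of
    # "b (h sf1) (w sf2) d -> b (h w) (d sf1 sf2)" on a (B, S, D) token sequence
    B, S, D = x.shape
    side = int(S ** 0.5)
    h = w = side // sf
    x = x.view(B, h, sf, w, sf, D)  # split rows/cols into (block index, offset)
    x = x.permute(0, 1, 3, 5, 2, 4)  # (B, h, w, D, sf1, sf2)
    return x.reshape(B, h * w, D * sf * sf)  # (B, S / sf^2, D * sf^2)

B, side, D, sf = 2, 8, 16, 2  # hypothetical: 8 x 8 = 64 patch tokens, factor 2
tokens = torch.randn(B, side * side, D)
ours = unshuffle_tokens(tokens, sf)

# Same result via F.pixel_unshuffle on a channels-first (B, D, H, W) layout
grid = tokens.view(B, side, side, D).permute(0, 3, 1, 2).contiguous()
ref = F.pixel_unshuffle(grid, sf).flatten(2).transpose(1, 2)  # (B, h*w, D*sf^2)

print(ours.shape)  # torch.Size([2, 16, 64])
print(torch.equal(ours, ref))
```

Since both sides are pure data movement, the comparison can use exact equality rather than a tolerance.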

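After the pixel shuffle, the number of tokens is reduced by a factor of mp_pixel_shuffle_factor**2. Since the operation only rearranges values, no visual features are discarded; they are packed into fewer, wider tokens, which lowers the computational cost of processing the image tokens in the language model. The projector expects the full square grid of patch tokens with shape (B, N, D), so the vision encoder is used with vit_cls_flag disabled (with the class token, the ViT returns only a (B, D) tensor). The output of the Modality Projector is then ready to be fed into the Language Model (LM) for further processing.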

2 Language Model

Dosovitskiy, Alexey, Lucas Beyer, Alexander Kolesnikov, Dirk Weissenborn, Xiaohua Zhai, Thomas Unterthiner, Mostafa Dehghani, et al. 2021. “An Image Is Worth 16x16 Words: Transformers for Image Recognition at Scale.” June 3, 2021. https://doi.org/10.48550/arXiv.2010.11929.
Shi, Wenzhe, Jose Caballero, Ferenc Huszár, Johannes Totz, Andrew P. Aitken, Rob Bishop, Daniel Rueckert, and Zehan Wang. 2016. “Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional Neural Network.” September 23, 2016. https://doi.org/10.48550/arXiv.1609.05158.