# ------------------------------------------
# TextDiffuser: Diffusion Models as Text Painters
# Paper Link: https://arxiv.org/abs/2305.10855
# Code Link: https://github.com/microsoft/unilm/tree/master/textdiffuser
# Copyright (c) Microsoft Corporation.
# This file defines the Layout Transformer for predicting the layout of keywords.
# ------------------------------------------
import torch
import torch.nn as nn
from transformers import CLIPTokenizer, CLIPTextModel
class TextConditioner(nn.Module):
    def __init__(self):
        super(TextConditioner, self).__init__()
        self.transformer = CLIPTextModel.from_pretrained('openai/clip-vit-large-patch14')
        self.tokenizer = CLIPTokenizer.from_pretrained('openai/clip-vit-large-patch14')

        # freeze the CLIP text encoder
        self.transformer.eval()
        for param in self.transformer.parameters():
            param.requires_grad = False

    def forward(self, prompt_list):
        batch_encoding = self.tokenizer(prompt_list, truncation=True, max_length=77, return_length=True, return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
        text_embedding = self.transformer(batch_encoding["input_ids"].cuda())
        return text_embedding.last_hidden_state.cuda(), batch_encoding["attention_mask"].cuda()  # (B, 77, 768) / (B, 77)
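
# Usage sketch (illustrative, not part of the original file; assumes a CUDA device):
#   conditioner = TextConditioner().cuda()
#   embeds, attn_mask = conditioner(["a poster of the text 'hello'"])
#   embeds.shape    -> torch.Size([1, 77, 768])
#   attn_mask.shape -> torch.Size([1, 77])
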
class LayoutTransformer(nn.Module):
    def __init__(self, layer_number=2):
        super(LayoutTransformer, self).__init__()

        # transformer encoder over the 77 CLIP token slots
        self.encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8)
        self.transformer = torch.nn.TransformerEncoder(
            self.encoder_layer, num_layers=layer_number
        )

        # autoregressive transformer decoder over the box sequence
        self.decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8)
        self.decoder_transformer = torch.nn.TransformerDecoder(
            self.decoder_layer, num_layers=layer_number
        )

        # token-attribute embedding tables
        self.mask_embedding = nn.Embedding(2, 512)
        self.length_embedding = nn.Embedding(256, 512)
        self.width_embedding = nn.Embedding(256, 512)
        self.position_embedding = nn.Embedding(256, 512)
        self.state_embedding = nn.Embedding(256, 512)
        self.match_embedding = nn.Embedding(256, 512)

        # box-coordinate embeddings, one bin per pixel of a 512x512 canvas
        self.x_embedding = nn.Embedding(512, 512)
        self.y_embedding = nn.Embedding(512, 512)
        self.w_embedding = nn.Embedding(512, 512)
        self.h_embedding = nn.Embedding(512, 512)
        self.encoder_target_embedding = nn.Embedding(256, 512)

        # project 768-dim CLIP features down to the 512-dim model width
        self.input_layer = nn.Sequential(
            nn.Linear(768, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
        )

        # regress the four box coordinates from each decoder state
        self.output_layer = nn.Sequential(
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, 4),
        )
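
    # The decoder consumes boxes normalized to [0, 1], quantized into the 512
    # integer bins of the coordinate embeddings above, e.g. (a sketch, not in
    # the original file):
    #   bins = torch.clamp((boxes * 512).long(), 0, 511)   # (B, 8, 4)
    #   emb = self.x_embedding(bins[:, :, 0]) + self.y_embedding(bins[:, :, 1]) \
    #       + self.w_embedding(bins[:, :, 2]) + self.h_embedding(bins[:, :, 3])
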
    def forward(self, x, length, width, mask, state, match, target, right_shifted_boxes, train=False, encoder_embedding=None):
        # compute the encoder embedding only if a cached one is not provided
        if encoder_embedding is None:
            # data augmentation: jitter the token widths by a few pixels
            # (randint's upper bound is exclusive, so the jitter range is [-3, 2])
            if train:
                width = width + torch.randint(-3, 3, (width.shape[0], width.shape[1])).cuda()
            x = self.input_layer(x)  # (1, 77, 512)
            width_embedding = self.width_embedding(torch.clamp(width, 0, 255).long())  # (1, 77, 512)
            encoder_target_embedding = self.encoder_target_embedding(target[:, :, 0].long())  # (1, 77, 512)
            pe_embedding = self.position_embedding(torch.arange(77).cuda()).unsqueeze(0)  # (1, 77, 512)
            total_embedding = x + width_embedding + pe_embedding + encoder_target_embedding  # combine all the embeddings (1, 77, 512)
            total_embedding = total_embedding.permute(1, 0, 2)  # (77, 1, 512)
            encoder_embedding = self.transformer(total_embedding)  # (77, 1, 512)
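
        # Note (not in the original file): the encoder output is returned to the
        # caller below, so autoregressive inference can pass it back in and skip
        # the encoder on every decoding step after the first.
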
        right_shifted_boxes_resize = (right_shifted_boxes * 512).long()  # (1, 8, 4)
        right_shifted_boxes_resize = torch.clamp(right_shifted_boxes_resize, 0, 511)  # (1, 8, 4)

        # decoder positional embedding
        pe_decoder = torch.arange(8).cuda()  # (8, )
        pe_embedding_decoder = self.position_embedding(pe_decoder).unsqueeze(0)  # (1, 8, 512)
        decoder_input = (
            pe_embedding_decoder
            + self.x_embedding(right_shifted_boxes_resize[:, :, 0])
            + self.y_embedding(right_shifted_boxes_resize[:, :, 1])
            + self.w_embedding(right_shifted_boxes_resize[:, :, 2])
            + self.h_embedding(right_shifted_boxes_resize[:, :, 3])
        )  # (1, 8, 512)
        decoder_input = decoder_input.permute(1, 0, 2)  # (8, 1, 512)

        # generate a causal (triangular) mask; note this overwrites the `mask` argument
        mask = nn.Transformer.generate_square_subsequent_mask(8)  # (8, 8)
        mask = mask.cuda()  # (8, 8)
        decoder_result = self.decoder_transformer(decoder_input, encoder_embedding, tgt_mask=mask)  # (8, 1, 512)
        decoder_result = decoder_result.permute(1, 0, 2)  # (1, 8, 512)
        box_prediction = self.output_layer(decoder_result)  # (1, 8, 4)

        return box_prediction, encoder_embedding
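

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original file; assumes a
    # CUDA device and dummy inputs shaped like the real pipeline's tensors).
    model = LayoutTransformer().cuda()
    x = torch.randn(1, 77, 768).cuda()             # CLIP text embeddings
    width = torch.randint(0, 256, (1, 77)).cuda()  # per-token rendered widths
    target = torch.zeros(1, 77, 1).cuda()          # keyword-indicator channel
    boxes = torch.rand(1, 8, 4).cuda()             # right-shifted boxes in [0, 1]
    # length/mask/state/match are unused in forward, so None placeholders suffice
    pred, enc = model(x, None, width, None, None, None, target, boxes)
    print(pred.shape, enc.shape)                   # (1, 8, 4), (77, 1, 512)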