import torch
import torch.nn as nn


class MeanPooling(nn.Module):
    def __init__(self):
        super(MeanPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        # Expand the attention mask to the hidden dimension: (batch, seq_len, hidden)
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        # Sum the token embeddings, zeroing out padded positions
        sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
        # Count the real tokens; clamp to avoid division by zero on empty masks
        sum_mask = input_mask_expanded.sum(1)
        sum_mask = torch.clamp(sum_mask, min=1e-9)
        mean_embeddings = sum_embeddings / sum_mask
        return mean_embeddings
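
# Usage sketch (an assumption, not part of the original code): pooling the token
# embeddings of a Hugging Face AutoModel into a single sentence embedding. The
# checkpoint name and the input sentence are illustrative placeholders.
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
backbone = AutoModel.from_pretrained("bert-base-uncased")

encoded = tokenizer(["A short example sentence."], return_tensors="pt")
with torch.no_grad():
    outputs = backbone(**encoded)

mean_pooler = MeanPooling()
sentence_embedding = mean_pooler(outputs.last_hidden_state, encoded["attention_mask"])
# sentence_embedding: (batch_size, hidden_size). MaxPooling and MinPooling below
# are drop-in replacements with the same forward(last_hidden_state, attention_mask) signature.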

class MaxPooling(nn.Module):
    def __init__(self):
        super(MaxPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        embeddings = last_hidden_state.clone()
        # Push padded positions to a large negative value so they never win the max
        embeddings[input_mask_expanded == 0] = -1e4
        max_embeddings, _ = torch.max(embeddings, dim=1)
        return max_embeddings

class MinPooling(nn.Module):
    def __init__(self):
        super(MinPooling, self).__init__()

    def forward(self, last_hidden_state, attention_mask):
        input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
        embeddings = last_hidden_state.clone()
        # Push padded positions to a large positive value so they never win the min,
        # mirroring the -1e4 used in MaxPooling
        embeddings[input_mask_expanded == 0] = 1e4
        min_embeddings, _ = torch.min(embeddings, dim=1)
        return min_embeddings

class AttentionPooling(nn.Module):
    def __init__(self, in_dim):
        super().__init__()
        # Small MLP that scores every token; a softmax over the sequence turns
        # the scores into attention weights
        self.attention = nn.Sequential(
            nn.Linear(in_dim, in_dim),
            nn.LayerNorm(in_dim),
            nn.GELU(),
            nn.Linear(in_dim, 1),
        )

    def forward(self, last_hidden_state, attention_mask):
        w = self.attention(last_hidden_state).float()  # (batch, seq_len, 1)
        # Padded positions get -inf so they receive zero weight after the softmax
        w[attention_mask == 0] = float('-inf')
        w = torch.softmax(w, 1)
        attention_embeddings = torch.sum(w * last_hidden_state, dim=1)
        return attention_embeddings
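
# Usage sketch (an assumption): AttentionPooling needs the backbone's hidden size
# at construction time. This reuses the `backbone` and `encoded` objects from the
# MeanPooling example above; the scoring MLP is randomly initialised here and
# would normally be trained jointly with the downstream head.
attention_pooler = AttentionPooling(backbone.config.hidden_size)
with torch.no_grad():
    outputs = backbone(**encoded)
attn_embedding = attention_pooler(outputs.last_hidden_state, encoded["attention_mask"])
# attn_embedding: (batch_size, hidden_size)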

class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights=None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        # One learnable weight per retained layer; the hidden states hold
        # num_hidden_layers + 1 tensors (embedding output plus each encoder layer)
        self.layer_weights = layer_weights if layer_weights is not None \
            else nn.Parameter(
                torch.tensor([1] * (num_hidden_layers + 1 - layer_start), dtype=torch.float)
            )

    def forward(self, ft_all_layers):
        # ft_all_layers: tuple or list of per-layer tensors, each (batch, seq_len, hidden)
        all_layer_embedding = torch.stack(ft_all_layers)
        all_layer_embedding = all_layer_embedding[self.layer_start:, :, :, :]
        # Broadcast the layer weights over (batch, seq_len, hidden) and take the weighted mean
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())
        weighted_average = (weight_factor * all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
        return weighted_average
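
# Usage sketch (an assumption): WeightedLayerPooling consumes the per-layer hidden
# states, so the backbone must be loaded with output_hidden_states=True. The
# checkpoint name is again an illustrative placeholder; `encoded` comes from the
# MeanPooling example above.
backbone_all_layers = AutoModel.from_pretrained("bert-base-uncased", output_hidden_states=True)
with torch.no_grad():
    outputs = backbone_all_layers(**encoded)

# outputs.hidden_states is a tuple of num_hidden_layers + 1 tensors
# (embedding output plus one per encoder layer), each (batch, seq_len, hidden)
layer_pooler = WeightedLayerPooling(
    num_hidden_layers=backbone_all_layers.config.num_hidden_layers, layer_start=4
)
weighted_hidden = layer_pooler(outputs.hidden_states)  # (batch, seq_len, hidden)
# The result is still per-token; a common follow-up is to take the first ([CLS])
# position or apply MeanPooling to reduce it to (batch_size, hidden_size).
cls_embedding = weighted_hidden[:, 0]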