Source code for onmt.modules.structured_attention

import torch.nn as nn
import torch
import torch.cuda


class MatrixTree(nn.Module):
    """Implementation of the matrix-tree theorem for computing marginals
    of non-projective dependency parsing. This attention layer is used
    in the paper "Learning Structured Text Representations"
    :cite:`DBLP:journals/corr/LiuL17d`.

    Args:
        eps (float): constant added to every exponentiated score before
            the Laplacian is built (presumably to keep the matrix
            well-conditioned for inversion -- confirm with the paper).
    """

    def __init__(self, eps=1e-5):
        # Initialise nn.Module's internal state before attaching attributes.
        super(MatrixTree, self).__init__()
        self.eps = eps

    def forward(self, input):
        """Turn a batch of log-scores into matrix-tree marginals.

        Args:
            input (Tensor): log-scores. Each ``input[b]`` must be a square
                matrix because it is inverted below, i.e. the expected shape
                is ``(batch, n, n)``. The diagonal of each matrix is read as
                the root log-scores. NOTE(review): off-diagonal entry
                ``[h, m]`` looks like the score of head ``h`` for modifier
                ``m`` -- confirm against the caller.

        Returns:
            Tensor: same shape as ``input``. For every batch element the
            off-diagonal entries hold the edge terms and the diagonal holds
            the root terms. ``input`` itself is not modified.
        """
        # exp() is needed for every batch element; compute it once up front.
        scores = input.exp()
        laplacian = scores + self.eps
        output = input.clone()
        # Boolean mask of the diagonal; identical for every batch element.
        diag_mask = torch.eye(input.size(1), device=input.device).ne(0)

        for b in range(input.size(0)):
            edge_scores = scores[b]
            root_scores = edge_scores.diag()

            # Laplacian: negated off-diagonal scores, column sums on the
            # diagonal.
            lap = laplacian[b].masked_fill(diag_mask, 0)
            lap = -lap + torch.diag(lap.sum(0))
            # Replace the first row with the root scores (which live on the
            # diagonal of the input).
            lap[0] = root_scores

            inv_laplacian = lap.inverse()
            inv_laplacian_t = inv_laplacian.transpose(0, 1)

            # term1[h, m] = scores[h, m] * inv[m, m]  (broadcast over rows)
            term1 = edge_scores.mul(inv_laplacian.diag().unsqueeze(0))
            # term2[h, m] = scores[h, m] * inv[m, h]
            term2 = edge_scores.mul(inv_laplacian_t)
            # The first column of term1 and the first row of term2 are
            # excluded from the edge terms.
            term1[:, 0] = 0
            term2[0] = 0

            # roots[m] = root_scores[m] * inv[m, 0]; placed on the diagonal,
            # where term1 - term2 is zero.
            roots_output = root_scores.mul(inv_laplacian_t[0])
            output[b] = term1 - term2 + torch.diag(roots_output)

        return output