Commit a0b3ed3a authored by szr712's avatar szr712

Merge branch 'master' of http://gitlab.uiiai.com/szr/pinyin2hanzi

parents e33198ea 49413df3
...@@ -4,6 +4,11 @@ from Layers import EncoderLayer, DecoderLayer ...@@ -4,6 +4,11 @@ from Layers import EncoderLayer, DecoderLayer
from Embed import Embedder, PositionalEncoder from Embed import Embedder, PositionalEncoder
from Sublayers import Norm from Sublayers import Norm
import copy import copy
# from kaldi.matrix import _matrix_ext, Vector, DoubleMatrix
# import kaldi.fstext as fst
# from pika_WFSTdecoder.sorted_matcher import SortedMatcher
# from pika_WFSTdecoder.beam_transducer import BeamMergeTransducer
import k2
def get_clones(module, N): def get_clones(module, N):
return nn.ModuleList([copy.deepcopy(module) for i in range(N)]) return nn.ModuleList([copy.deepcopy(module) for i in range(N)])
...@@ -64,6 +69,42 @@ class TransformerForTokenClassification(nn.Module): ...@@ -64,6 +69,42 @@ class TransformerForTokenClassification(nn.Module):
output = self.out(outputs) output = self.out(outputs)
return output return output
class ClassificationWithCLAS(TransformerForTokenClassification):
    """Token classifier with CLAS-style contextual biasing.

    Runs the inherited Transformer encoder over the source sequence, encodes
    one or more lists of contextual bias phrases (hanzi, optionally phoneme
    and glyph) with BiLSTMs, fuses them into a single bias vector via
    cross-attention, and concatenates that vector with the encoder output
    before the final projection.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout, contextual_bias_num):
        super(ClassificationWithCLAS, self).__init__(src_vocab, trg_vocab, d_model, N, heads, dropout)
        self.hanzi_embed = Embedder(trg_vocab, d_model)
        # BiLSTM phrase encoders: each direction is d_model // 2 wide so the
        # concatenated bidirectional output is d_model wide.
        # Fixed: the original passed the misspelled kwarg `batch_fisrt`,
        # which makes nn.LSTM raise TypeError on construction.
        self.phoneme_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.hanzi_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.glyph_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        # Projects the concatenation of up to `contextual_bias_num` phrase
        # encodings back down to d_model.
        self.cond_proj = nn.Linear(d_model * contextual_bias_num, d_model)
        self.out = nn.Linear(d_model * 2, trg_vocab)
        # NOTE(review): self.phoneme_embed / self.glyph_embed are used in
        # forward() but never defined anywhere visible -- passing 'phoneme'
        # or 'glyph' bias phrases will raise AttributeError until embedding
        # tables for those vocabularies are assigned. TODO: define them.
        # self.phoneme_embed
        # self.glyph_embed

    def forward(self, src, src_mask, contextual_bias_phrases):
        '''
        src: [B, T],
        src_mask: [B, T],
        contextual_bias_phrases: dict {"hanzi":, "phoneme": Optional, "glyph": Optional,}
        '''
        encoder_outputs = self.encoder(src, src_mask)
        biased_cond = []
        # nn.LSTM returns (output, (h_n, c_n)); only the output sequence is
        # wanted here. Fixed: the original appended the whole tuple, which
        # breaks the torch.cat below.
        biased_hanzi_phrases = self.hanzi_lstm(self.hanzi_embed(contextual_bias_phrases['hanzi']))[0]
        biased_cond.append(biased_hanzi_phrases)
        if 'phoneme' in contextual_bias_phrases:
            biased_phoneme_phrases = self.phoneme_lstm(self.phoneme_embed(contextual_bias_phrases['phoneme']))[0]
            biased_cond.append(biased_phoneme_phrases)
        if 'glyph' in contextual_bias_phrases:
            biased_glyph_phrases = self.glyph_lstm(self.glyph_embed(contextual_bias_phrases['glyph']))[0]
            biased_cond.append(biased_glyph_phrases)
        biased_cond = self.cond_proj(torch.cat(biased_cond, dim=-1))
        # cross_attend is expected to come from the parent class -- not
        # visible in this chunk, TODO confirm it exists.
        biased_cond = self.cross_attend(biased_cond, encoder_outputs)
        # Fixed: torch.cat takes a sequence of tensors; the original passed
        # the two tensors as separate positional arguments.
        outputs = torch.cat([encoder_outputs, biased_cond], dim=-1)
        output = self.out(outputs)
        return output
def get_model(opt, src_vocab, trg_vocab): def get_model(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0 assert opt.d_model % opt.heads == 0
...@@ -84,6 +125,89 @@ def get_model(opt, src_vocab, trg_vocab): ...@@ -84,6 +125,89 @@ def get_model(opt, src_vocab, trg_vocab):
return model return model
class HanziTransducerRescorer_k2:
    """Rescores per-frame hanzi classification logits with a k2 WFST LM.

    The logits are interleaved with epsilon frames (one eps frame after every
    real frame), turned into a dense FSA, intersected with the compiled
    grammar FST, and the shortest path through the resulting lattice is
    returned as the rescored token sequence.
    """

    def __init__(self, vocab, lm_dir='k2_WFSTdecoder/lm', clas_rescorer_model=None, beam_size=5,
                 lm=None, lm_scale=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
                                    1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)):
        # Deferred import so the k2_WFSTdecoder package is only required when
        # a rescorer is actually constructed.
        from k2_WFSTdecoder.graph_compiler import compile_GG
        self.G_general_fst = compile_GG(lm_dir)
        assert self.G_general_fst.requires_grad is False
        # if not hasattr(self.G_general_fst, "lm_scores"):
        #     self.G_general_fst.lm_scores = self.G_general_fst.scores.clone()
        # self.G_special_fst = compile_G(lm_dir, 'special')
        self.search_beam = beam_size * 4
        self.output_beam = beam_size
        self.min_active_states = 1
        self.max_active_states = 50
        self.num_paths = 1000
        # Fixed: lm_scale default was a mutable list; the default is now a
        # tuple and a fresh list is stored per instance.
        self.lm_scale_list = list(lm_scale)
        self.nbest_scale = 0.5
        self.clas_rescorer = clas_rescorer_model
        self.vocab = vocab

    def prepare_batch_logits(self, batch_logits, topk_filter=False):
        """Convert raw logits into an eps-interleaved log-prob matrix for k2.

        batch_logits: [B, T, num_classes] raw model outputs.
        topk_filter: if True, keep only the top-10 classes per frame (assumes
            B == 1 in that branch -- it indexes batch 0 only).
        Returns a [1, 2*T, num_classes + 1] tensor: real frames alternating
        with epsilon-only frames, scaled up to sharpen score differences.
        """
        batch_size = batch_logits.shape[0]
        # batch_logits = torch.log10(batch_logits.detach())
        # Temperature-smoothed (T=8) softmax, then log10 scores.
        batch_logits = torch.log10(torch.nn.functional.softmax(batch_logits.detach() / 8, dim=-1))
        # Class 0 (presumably a pad/blank symbol -- TODO confirm) is masked out.
        batch_logits[:, :, 0] = -999
        # Extra epsilon column appended to every real frame, also masked.
        eps_logits = torch.zeros(batch_size, batch_logits.shape[1], 1).to(batch_logits.device)
        eps_logits[:, :, :] = -999
        if topk_filter:
            values, indexes = batch_logits.topk(k=10, dim=-1, largest=True, sorted=True)
            batch_logits[:, :, :] = -999
            for i in range(batch_logits.shape[1]):
                not_min_v_indexes = (values[0][i] != values.min()).nonzero().view(-1)
                batch_logits[0, i, indexes[0, i, not_min_v_indexes]] = values[0, i, not_min_v_indexes]
        batch_logits = torch.cat([batch_logits, eps_logits], dim=-1)
        # Build epsilon-only frames (score 0 on the eps symbol, -999 elsewhere)
        # and interleave them with the real frames: [f0, eps, f1, eps, ...].
        eps_logits = torch.zeros_like(batch_logits)
        eps_logits[:, :, :] = -999
        eps_logits[:, :, len(self.vocab.itos)] = 0
        batch_logits = torch.cat([batch_logits, eps_logits], dim=0).transpose(0, 1).reshape(1, batch_logits.shape[1] * 2, -1)
        # This factor widens the numeric gaps in batch_logits; the larger the
        # gap, the more the acoustic scores dominate the final result.
        batch_logits = batch_logits * 9
        return batch_logits

    def rescore_batch(self, batch_logits):
        '''
        assert B == 1
        batch_logits: [B, seq_length, hanzi_classes], device:cuda
        # batch_logits: [B, seq_length, hidden_size], device:cuda
        '''
        # Fixed: removed a leftover `import ipdb; ipdb.set_trace()` debugger
        # breakpoint that halted every call, and a dead `batch_text` variable
        # that was computed but never used.
        from k2_WFSTdecoder.icefall.decode import get_lattice, rescore_with_n_best_list
        from k2_WFSTdecoder.icefall.utils import get_texts
        batch_size = batch_logits.shape[0]
        batch_logits = self.prepare_batch_logits(batch_logits, topk_filter=True)
        # One supervision per utterance covering the full (interleaved) length.
        supervision_segments = torch.tensor(
            [[i, 0, batch_logits.shape[1]] for i in range(batch_size)],
            dtype=torch.int32,
        )
        # NOTE(review): search/output beams are hard-coded here rather than
        # using self.search_beam / self.output_beam -- TODO confirm intended.
        batch_lattice = get_lattice(
            batch_logits,
            self.G_general_fst,
            supervision_segments,
            search_beam=10000,
            output_beam=1000,
            min_active_states=self.min_active_states,
            max_active_states=self.max_active_states,
        )
        return get_texts(k2.shortest_path(batch_lattice, use_double_scores=True))
def get_model_token_classification(opt, src_vocab, trg_vocab): def get_model_token_classification(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0 assert opt.d_model % opt.heads == 0
......
...@@ -218,9 +218,11 @@ if __name__ == "__main__": ...@@ -218,9 +218,11 @@ if __name__ == "__main__":
ori_dir="./data/Chinese/train/ori" ori_dir="./data/Chinese/train/ori"
hanzi_dir="./data/Chinese/train/hanzi" hanzi_dir="./data/Chinese/train/hanzi"
pinyin_dir="./data/Chinese/train/pinyin" pinyin_dir="./data/Chinese/train/pinyin"
os.makedirs(hanzi_dir, exist_ok=True)
os.makedirs(pinyin_dir, exist_ok=True)
for file in os.listdir(ori_dir): for file in os.listdir(ori_dir):
build_corpus(os.path.join(ori_dir,file), build_corpus(os.path.join(ori_dir,file),
os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file)) os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
print("Done") print(f"{file} Done")
# build_corpus("./data/dev/dev_hanzi.txt", # build_corpus("./data/dev/dev_hanzi.txt",
# "./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt") # "./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt")
This source diff could not be displayed because it is too large. You can view the blob instead.
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -14,7 +14,6 @@ from Beam import beam_search ...@@ -14,7 +14,6 @@ from Beam import beam_search
# from nltk.corpus import wordnet # from nltk.corpus import wordnet
from torch.autograd import Variable from torch.autograd import Variable
import re import re
import time
import random import random
import distance import distance
......
...@@ -17,14 +17,15 @@ def compile_GG(lm_dir: str) -> k2.Fsa: ...@@ -17,14 +17,15 @@ def compile_GG(lm_dir: str) -> k2.Fsa:
logging.info(f"Loading G_general.fst.txt") logging.info(f"Loading G_general.fst.txt")
with open(f"{lm_dir}/G_general.fst.txt") as f: with open(f"{lm_dir}/G_general.fst.txt") as f:
G = k2.Fsa.from_openfst(f.read(), acceptor=False) G = k2.Fsa.from_openfst(f.read(), acceptor=False)
G = k2.arc_sort(G)
logging.info(f"Loading G_special.fst.txt") logging.info(f"Loading G_special.fst.txt")
with open(f"{lm_dir}/G_special.fst.txt") as f: with open(f"{lm_dir}/G_special.fst.txt") as f:
G_ = k2.Fsa.from_openfst(f.read(), acceptor=False) G_ = k2.Fsa.from_openfst(f.read(), acceptor=False)
G_ = k2.arc_sort(G) # G_ = k2.arc_sort(G_)
G = k2.invert(G) G = k2.invert(G)
G = k2.arc_sort(G)
G_ = k2.remove_epsilon_and_add_self_loops(G_) G_ = k2.remove_epsilon_and_add_self_loops(G_)
del G.aux_labels del G.aux_labels
GG = k2.intersect(G_, G) GG = k2.intersect(G_, G)
...@@ -35,6 +36,7 @@ def compile_GG(lm_dir: str) -> k2.Fsa: ...@@ -35,6 +36,7 @@ def compile_GG(lm_dir: str) -> k2.Fsa:
if i == 0: if i == 0:
continue continue
backward_GG_arcs[i, 0] = i + 1 backward_GG_arcs[i, 0] = i + 1
backward_GG_arcs[-1, 0] = 0 backward_GG_arcs[-1, 0] = 0
backward_GG_arcs[-1, 1] = max_V + 1 backward_GG_arcs[-1, 1] = max_V + 1
backward_GG_arcs[-1, 2] = -1 backward_GG_arcs[-1, 2] = -1
...@@ -47,9 +49,10 @@ def compile_GG(lm_dir: str) -> k2.Fsa: ...@@ -47,9 +49,10 @@ def compile_GG(lm_dir: str) -> k2.Fsa:
temp_index = arcs[:, 0].sort()[1] temp_index = arcs[:, 0].sort()[1]
arcs = arcs[temp_index] arcs = arcs[temp_index]
temp_aux = torch.cat([GG.aux_labels, torch.zeros(backward_GG_arcs.shape[0])], dim=-0)[temp_index] temp_aux = torch.cat([GG.aux_labels, torch.zeros(backward_GG_arcs.shape[0])], dim=-0)[temp_index]
# import ipdb
# ipdb.set_trace() GG = k2.determinize(GG, k2.DeterminizeWeightPushingType.kTropicalWeightPushing)
GG = k2.Fsa(arcs=arcs, aux_labels=temp_aux).to(device) GG = k2.Fsa(arcs=arcs, aux_labels=temp_aux).to(device)
return GG return GG
def compile_G(lm_dir: str, key: str) -> k2.Fsa: def compile_G(lm_dir: str, key: str) -> k2.Fsa:
......
No preview for this file type
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment