Commit a0b3ed3a authored by szr712's avatar szr712

Merge branch 'master' of http://gitlab.uiiai.com/szr/pinyin2hanzi

parents e33198ea 49413df3
......@@ -4,6 +4,11 @@ from Layers import EncoderLayer, DecoderLayer
from Embed import Embedder, PositionalEncoder
from Sublayers import Norm
import copy
# from kaldi.matrix import _matrix_ext, Vector, DoubleMatrix
# import kaldi.fstext as fst
# from pika_WFSTdecoder.sorted_matcher import SortedMatcher
# from pika_WFSTdecoder.beam_transducer import BeamMergeTransducer
import k2
def get_clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of *module*.

    Each clone has its own parameters (deepcopy), so the copies can be
    trained independently, e.g. as stacked encoder/decoder layers.
    """
    clones = [copy.deepcopy(module) for _ in range(N)]
    return nn.ModuleList(clones)
......@@ -64,6 +69,42 @@ class TransformerForTokenClassification(nn.Module):
output = self.out(outputs)
return output
class ClassificationWithCLAS(TransformerForTokenClassification):
    """Token classifier with CLAS-style contextual biasing.

    Encodes the source sequence with the inherited Transformer encoder,
    encodes contextual bias phrases (hanzi, optionally phoneme and glyph)
    with bidirectional LSTMs, fuses the bias representations through a
    linear projection followed by cross attention against the encoder
    output, and classifies each source position over the target vocab.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout, contextual_bias_num):
        super(ClassificationWithCLAS, self).__init__(src_vocab, trg_vocab, d_model, N, heads, dropout)
        self.hanzi_embed = Embedder(trg_vocab, d_model)
        # BUGFIX: the nn.LSTM keyword is `batch_first`, not `batch_fisrt`
        # (the typo made construction raise TypeError).
        # Bidirectional with hidden size d_model // 2 so the concatenated
        # directions yield d_model features per step.
        self.phoneme_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.hanzi_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.glyph_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        # Projects the concatenation of up to `contextual_bias_num` bias
        # streams (each d_model wide) back down to d_model.
        self.cond_proj = nn.Linear(d_model * contextual_bias_num, d_model)
        # Final classifier over [encoder_output ; biased_cond].
        self.out = nn.Linear(d_model * 2, trg_vocab)
        # NOTE(review): forward() references self.phoneme_embed and
        # self.glyph_embed, which are never defined here — presumably they
        # should be Embedder instances like hanzi_embed. Confirm the
        # intended vocab sizes before enabling the phoneme/glyph branches.
        # self.phoneme_embed
        # self.glyph_embed

    def forward(self, src, src_mask, contextual_bias_phrases):
        '''
        src: [B, T],
        src_mask: [B, T],
        contextual_bias_phrases: dict {"hanzi":, "phoneme": Optional, "glyph": Optional,}
        '''
        encoder_outputs = self.encoder(src, src_mask)
        biased_cond = []
        # BUGFIX: nn.LSTM returns (output, (h_n, c_n)); keep only the
        # per-step output tensor — the original appended the whole tuple,
        # which breaks the torch.cat below.
        biased_hanzi_phrases, _ = self.hanzi_lstm(self.hanzi_embed(contextual_bias_phrases['hanzi']))
        biased_cond.append(biased_hanzi_phrases)
        if 'phoneme' in contextual_bias_phrases:
            biased_phoneme_phrases, _ = self.phoneme_lstm(self.phoneme_embed(contextual_bias_phrases['phoneme']))
            biased_cond.append(biased_phoneme_phrases)
        if 'glyph' in contextual_bias_phrases:
            biased_glyph_phrases, _ = self.glyph_lstm(self.glyph_embed(contextual_bias_phrases['glyph']))
            biased_cond.append(biased_glyph_phrases)
        biased_cond = self.cond_proj(torch.cat(biased_cond, dim=-1))
        # NOTE(review): self.cross_attend is not defined in this class or
        # the visible parent — confirm it is provided elsewhere.
        biased_cond = self.cross_attend(biased_cond, encoder_outputs)
        # BUGFIX: torch.cat takes a sequence of tensors; the original
        # passed two positional tensors (a TypeError at runtime).
        outputs = torch.cat([encoder_outputs, biased_cond], dim=-1)
        output = self.out(outputs)
        return output
def get_model(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0
......@@ -84,6 +125,89 @@ def get_model(opt, src_vocab, trg_vocab):
return model
class HanziTransducerRescorer_k2:
    """Rescores per-frame hanzi logits with a k2 WFST language model.

    Interleaves epsilon frames with the acoustic frames so the logits can
    be intersected with the compiled G FST, builds a decoding lattice, and
    returns the shortest-path token sequence.
    """

    def __init__(self, vocab, lm_dir='k2_WFSTdecoder/lm', clas_rescorer_model=None, beam_size=5,
                 lm=None, lm_scale=None):
        """
        vocab: vocabulary with an `itos` index->token list.
        lm_dir: directory containing the LM to compile into a G FST.
        clas_rescorer_model: optional CLAS model for additional rescoring.
        beam_size: output beam; search beam is derived as 4x this value.
        lm / lm_scale: optional LM and list of scales to sweep
            (defaults to 0.1 .. 2.0 in steps of 0.1).
        """
        # BUGFIX: lm_scale previously used a mutable list literal as its
        # default argument; use a None sentinel instead.
        if lm_scale is None:
            lm_scale = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0,
                        1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0]
        from k2_WFSTdecoder.graph_compiler import compile_GG
        self.G_general_fst = compile_GG(lm_dir)
        # Decoding must not backprop through the graph.
        assert self.G_general_fst.requires_grad is False
        # if not hasattr(self.G_general_fst, "lm_scores"):
        #     self.G_general_fst.lm_scores = self.G_general_fst.scores.clone()
        # self.G_special_fst = compile_G(lm_dir, 'special')
        self.search_beam = beam_size * 4
        self.output_beam = beam_size
        self.min_active_states = 1
        self.max_active_states = 50
        self.num_paths = 1000
        self.lm_scale_list = lm_scale
        self.nbest_scale = 0.5
        self.clas_rescorer = clas_rescorer_model
        self.vocab = vocab

    def prepare_batch_logits(self, batch_logits, topk_filter=False):
        """Convert raw logits to temperature-scaled log-probs with
        interleaved epsilon frames, ready for lattice intersection.

        batch_logits: [B, T, C] tensor; returns [1, 2*T, C+1] where every
        odd frame is an epsilon-only frame and the extra last class is the
        epsilon symbol. When topk_filter is set, only the top-10 classes
        per frame keep their scores (assumes B == 1 — the filter indexes
        batch 0 only).
        """
        batch_size = batch_logits.shape[0]
        # batch_logits = torch.log10(batch_logits.detach())
        # Temperature 8 softens the distribution before taking log10.
        batch_logits = torch.log10(torch.nn.functional.softmax(batch_logits.detach() / 8, dim=-1))
        # Class 0 is forced to (near) -inf; -999 acts as "impossible".
        batch_logits[:, :, 0] = -999
        eps_logits = torch.zeros(batch_size, batch_logits.shape[1], 1).to(batch_logits.device)
        eps_logits[:, :, :] = -999
        if topk_filter:
            values, indexes = batch_logits.topk(k=10, dim=-1, largest=True, sorted=True)
            batch_logits[:, :, :] = -999
            for i in range(batch_logits.shape[1]):
                not_min_v_indexes = (values[0][i] != values.min()).nonzero().view(-1)
                batch_logits[0, i, indexes[0, i, not_min_v_indexes]] = values[0, i, not_min_v_indexes]
        # Append the epsilon class as index C (== len(vocab.itos)).
        batch_logits = torch.cat([batch_logits, eps_logits], dim=-1)
        # Build epsilon-only frames and interleave them with the real
        # frames: [real_0, eps_0, real_1, eps_1, ...].
        eps_logits = torch.zeros_like(batch_logits)
        eps_logits[:, :, :] = -999
        eps_logits[:, :, len(self.vocab.itos)] = 0
        batch_logits = torch.cat([batch_logits, eps_logits], dim=0).transpose(0, 1).reshape(1, batch_logits.shape[1] * 2, -1)
        # This factor widens the numeric gaps between logits; the larger
        # the gap, the more the acoustic scores dominate the final result.
        batch_logits = batch_logits * 9
        return batch_logits

    def rescore_batch(self, batch_logits):
        '''
        assert B == 1
        batch_logits: [B, seq_length, hanzi_classes], device:cuda
        # batch_logits: [B, seq_length, hidden_size], device:cuda
        '''
        # BUGFIX: removed leftover `import ipdb; ipdb.set_trace()` debug
        # breakpoint (it halted every call) and the `batch_text` local
        # that existed only for inspection in the debugger.
        from k2_WFSTdecoder.icefall.decode import get_lattice, rescore_with_n_best_list
        from k2_WFSTdecoder.icefall.utils import get_texts
        batch_size = batch_logits.shape[0]
        batch_logits = self.prepare_batch_logits(batch_logits, topk_filter=True)
        # One supervision per batch item covering the full (interleaved)
        # frame sequence.
        supervision_segments = torch.tensor(
            [[i, 0, batch_logits.shape[1]] for i in range(batch_size)],
            dtype=torch.int32,
        )
        # NOTE(review): search_beam/output_beam are hardcoded here and
        # ignore self.search_beam / self.output_beam set in __init__ —
        # confirm whether these wide beams are intentional.
        batch_lattice = get_lattice(
            batch_logits,
            self.G_general_fst,
            supervision_segments,
            search_beam=10000,
            output_beam=1000,
            min_active_states=self.min_active_states,
            max_active_states=self.max_active_states,
        )
        return get_texts(k2.shortest_path(batch_lattice, use_double_scores=True))
def get_model_token_classification(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0
......
......@@ -218,9 +218,11 @@ if __name__ == "__main__":
ori_dir="./data/Chinese/train/ori"
hanzi_dir="./data/Chinese/train/hanzi"
pinyin_dir="./data/Chinese/train/pinyin"
os.makedirs(hanzi_dir, exist_ok=True)
os.makedirs(pinyin_dir, exist_ok=True)
for file in os.listdir(ori_dir):
build_corpus(os.path.join(ori_dir,file),
os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
print("Done")
print(f"{file} Done")
# build_corpus("./data/dev/dev_hanzi.txt",
# "./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt")
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment