Commit 909da653 authored by 李承曦(20硕)'s avatar 李承曦(20硕)

update WFST

parent f6f3e5dc
...@@ -4,6 +4,11 @@ from Layers import EncoderLayer, DecoderLayer ...@@ -4,6 +4,11 @@ from Layers import EncoderLayer, DecoderLayer
from Embed import Embedder, PositionalEncoder
from Sublayers import Norm
import copy
# from kaldi.matrix import _matrix_ext, Vector, DoubleMatrix
# import kaldi.fstext as fst
# from pika_WFSTdecoder.sorted_matcher import SortedMatcher
# from pika_WFSTdecoder.beam_transducer import BeamMergeTransducer
import k2
def get_clones(module, N):
    """Return an nn.ModuleList holding N independent deep copies of *module*.

    Each clone has its own parameters (deep copy), so they train independently.
    """
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))
...@@ -64,6 +69,42 @@ class TransformerForTokenClassification(nn.Module): ...@@ -64,6 +69,42 @@ class TransformerForTokenClassification(nn.Module):
output = self.out(outputs) output = self.out(outputs)
return output return output
class ClassificationWithCLAS(TransformerForTokenClassification):
    """Token classifier extended with CLAS-style contextual biasing.

    Contextual bias phrases (hanzi required; phoneme/glyph optional) are
    embedded, encoded with bidirectional LSTMs, projected back to d_model,
    cross-attended against the encoder output, and concatenated onto the
    encoder output before the final classification projection.
    """

    def __init__(self, src_vocab, trg_vocab, d_model, N, heads, dropout, contextual_bias_num):
        super(ClassificationWithCLAS, self).__init__(src_vocab, trg_vocab, d_model, N, heads, dropout)
        self.hanzi_embed = Embedder(trg_vocab, d_model)
        # Fixed typo: the original passed `batch_fisrt=True`, which nn.LSTM
        # rejects with a TypeError; the intended keyword is `batch_first`.
        # NOTE(review): `dropout` on a single-layer nn.LSTM has no effect and
        # only emits a warning — confirm whether num_layers > 1 was intended.
        self.phoneme_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.hanzi_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        self.glyph_lstm = nn.LSTM(d_model, d_model // 2, bidirectional=True, batch_first=True, dropout=dropout)
        # Projects the concatenation of all bias encodings back to d_model.
        self.cond_proj = nn.Linear(d_model * contextual_bias_num, d_model)
        self.out = nn.Linear(d_model * 2, trg_vocab)
        # NOTE(review): self.phoneme_embed / self.glyph_embed are referenced in
        # forward() but never defined here — they must be added before the
        # phoneme/glyph branches can be used.
        # self.phoneme_embed
        # self.glyph_embed

    def forward(self, src, src_mask, contextual_bias_phrases):
        '''
        src: [B, T] source token ids
        src_mask: [B, T]
        contextual_bias_phrases: dict {"hanzi": required, "phoneme": Optional, "glyph": Optional}
        '''
        encoder_outputs = self.encoder(src, src_mask)
        biased_cond = []
        # nn.LSTM returns (output, (h_n, c_n)); the original appended the raw
        # tuple, which breaks the torch.cat below — keep only the output.
        biased_hanzi_phrases = self.hanzi_lstm(self.hanzi_embed(contextual_bias_phrases['hanzi']))[0]
        biased_cond.append(biased_hanzi_phrases)
        if 'phoneme' in contextual_bias_phrases:
            # NOTE(review): raises AttributeError until self.phoneme_embed exists.
            biased_phoneme_phrases = self.phoneme_lstm(self.phoneme_embed(contextual_bias_phrases['phoneme']))[0]
            biased_cond.append(biased_phoneme_phrases)
        if 'glyph' in contextual_bias_phrases:
            # NOTE(review): raises AttributeError until self.glyph_embed exists.
            biased_glyph_phrases = self.glyph_lstm(self.glyph_embed(contextual_bias_phrases['glyph']))[0]
            biased_cond.append(biased_glyph_phrases)
        # NOTE(review): cond_proj expects d_model * contextual_bias_num inputs,
        # so the number of bias kinds supplied must equal contextual_bias_num.
        biased_cond = self.cond_proj(torch.cat(biased_cond, dim=-1))
        biased_cond = self.cross_attend(biased_cond, encoder_outputs)
        # torch.cat takes a sequence of tensors; the original called
        # torch.cat(encoder_outputs, biased_cond, dim=-1), a TypeError.
        outputs = torch.cat([encoder_outputs, biased_cond], dim=-1)
        output = self.out(outputs)
        return output
def get_model(opt, src_vocab, trg_vocab): def get_model(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0 assert opt.d_model % opt.heads == 0
...@@ -84,6 +125,89 @@ def get_model(opt, src_vocab, trg_vocab): ...@@ -84,6 +125,89 @@ def get_model(opt, src_vocab, trg_vocab):
return model return model
class HanziTransducerRescorer_k2:
    """WFST-based rescorer for per-frame hanzi logits using k2.

    Intersects temperature-smoothed log-probabilities with a grammar FST (G)
    compiled from an n-gram LM, then returns the lattice's best path.
    """

    def __init__(self, vocab, lm_dir='k2_WFSTdecoder/lm', clas_rescorer_model=None, beam_size=5,
                 lm=None, lm_scale=None):
        # Local import keeps the k2 toolchain optional for callers that never
        # construct a rescorer.
        from k2_WFSTdecoder.graph_compiler import compile_GG
        self.G_general_fst = compile_GG(lm_dir)
        # The grammar FST is fixed; rescoring must never backprop into it.
        assert self.G_general_fst.requires_grad is False
        # if not hasattr(self.G_general_fst, "lm_scores"):
        #     self.G_general_fst.lm_scores = self.G_general_fst.scores.clone()
        # self.G_special_fst = compile_G(lm_dir, 'special')
        self.search_beam = beam_size * 4
        self.output_beam = beam_size
        self.min_active_states = 1
        self.max_active_states = 50
        self.num_paths = 1000
        # Mutable-default fix: the original used lm_scale=[0.1, ..., 2.0] as a
        # default argument; None now stands in for that same 0.1..2.0 grid.
        self.lm_scale_list = list(lm_scale) if lm_scale is not None else [i / 10 for i in range(1, 21)]
        self.nbest_scale = 0.5
        self.clas_rescorer = clas_rescorer_model
        # NOTE(review): the `lm` parameter is accepted but never stored/used.
        self.vocab = vocab

    def prepare_batch_logits(self, batch_logits, topk_filter=False):
        """Convert raw logits into k2-ready log-probabilities.

        batch_logits: [B, T, num_classes] raw scores (detached here).
        Returns [1, 2*T, num_classes + 1]: an epsilon class is appended and
        each real frame is interleaved with a pure-epsilon frame.
        """
        batch_size = batch_logits.shape[0]
        # Temperature-smoothed softmax (/8 flattens the distribution), then
        # log10. NOTE(review): log10 rather than natural log is unusual for
        # k2 scores — confirm the grammar FST scores use the same base.
        batch_logits = torch.log10(torch.nn.functional.softmax(batch_logits.detach() / 8, dim=-1))
        # -999 acts as "effectively -inf"; class 0 is banned outright.
        batch_logits[:, :, 0] = -999
        # Extra column for the epsilon symbol, disabled on real frames.
        eps_logits = torch.full((batch_size, batch_logits.shape[1], 1), -999.0, device=batch_logits.device)
        if topk_filter:
            # Keep only the 10 best classes per frame; everything else is
            # forced to -999. NOTE(review): the [0] indexing assumes B == 1.
            values, indexes = batch_logits.topk(k=10, dim=-1, largest=True, sorted=True)
            batch_logits[:, :, :] = -999
            for i in range(batch_logits.shape[1]):
                not_min_v_indexes = (values[0][i] != values.min()).nonzero().view(-1)
                batch_logits[0, i, indexes[0, i, not_min_v_indexes]] = values[0, i, not_min_v_indexes]
        batch_logits = torch.cat([batch_logits, eps_logits], dim=-1)
        # Build pure-epsilon frames: only the epsilon id (== len(vocab.itos))
        # is allowed, then interleave real/epsilon frames along time.
        eps_frames = torch.full_like(batch_logits, -999.0)
        eps_frames[:, :, len(self.vocab.itos)] = 0
        batch_logits = torch.cat([batch_logits, eps_frames], dim=0).transpose(0, 1).reshape(1, batch_logits.shape[1] * 2, -1)
        # Scale factor widens the score gaps; the larger the factor, the more
        # the acoustic scores dominate the final decoding result.
        batch_logits = batch_logits * 9
        return batch_logits

    def rescore_batch(self, batch_logits):
        '''
        assert B == 1
        batch_logits: [B, seq_length, hanzi_classes], device:cuda
        Returns the token-id sequences of the lattice's shortest (best) path.
        '''
        # Removed a leftover `import ipdb; ipdb.set_trace()` debugger
        # breakpoint, the unused greedy-argmax debug string, and the unused
        # `rescore_with_n_best_list` import.
        from k2_WFSTdecoder.icefall.decode import get_lattice
        from k2_WFSTdecoder.icefall.utils import get_texts
        batch_size = batch_logits.shape[0]
        batch_logits = self.prepare_batch_logits(batch_logits, topk_filter=True)
        # One supervision per utterance covering all interleaved frames:
        # (sequence_index, start_frame, num_frames).
        supervision_segments = torch.tensor(
            [[i, 0, batch_logits.shape[1]] for i in range(batch_size)],
            dtype=torch.int32,
        )
        # NOTE(review): search_beam/output_beam are hard-coded here and ignore
        # the self.search_beam/self.output_beam configured in __init__ —
        # confirm whether the instance settings were meant to apply.
        batch_lattice = get_lattice(
            batch_logits,
            self.G_general_fst,
            supervision_segments,
            search_beam=10000,
            output_beam=1000,
            min_active_states=self.min_active_states,
            max_active_states=self.max_active_states,
        )
        return get_texts(k2.shortest_path(batch_lattice, use_double_scores=True))
def get_model_token_classification(opt, src_vocab, trg_vocab): def get_model_token_classification(opt, src_vocab, trg_vocab):
assert opt.d_model % opt.heads == 0 assert opt.d_model % opt.heads == 0
......
...@@ -218,9 +218,11 @@ if __name__ == "__main__": ...@@ -218,9 +218,11 @@ if __name__ == "__main__":
ori_dir="./data/Chinese/train/ori" ori_dir="./data/Chinese/train/ori"
hanzi_dir="./data/Chinese/train/hanzi" hanzi_dir="./data/Chinese/train/hanzi"
pinyin_dir="./data/Chinese/train/pinyin" pinyin_dir="./data/Chinese/train/pinyin"
os.makedirs(hanzi_dir, exist_ok=True)
os.makedirs(pinyin_dir, exist_ok=True)
for file in os.listdir(ori_dir): for file in os.listdir(ori_dir):
build_corpus(os.path.join(ori_dir,file), build_corpus(os.path.join(ori_dir,file),
os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file)) os.path.join(pinyin_dir,file), os.path.join(hanzi_dir,file))
print("Done") print(f"{file} Done")
# build_corpus("./data/dev/dev_hanzi.txt", # build_corpus("./data/dev/dev_hanzi.txt",
# "./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt") # "./data/dev/dev_pinyin_split.txt", "./data/dev/dev_hanzi_split.txt")
This source diff could not be displayed because it is too large. You can view the blob instead.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment