mirror of
https://github.com/labring/FastGPT.git
synced 2025-07-28 00:56:26 +00:00
add sensevoice & cosevoice (#2562)
Signed-off-by: EthanD <EthanD4869@gmail.com> Co-authored-by: EthanD <EthanD4869@gmail.com>
This commit is contained in:
895
python/sensevoice/app/model.py
Normal file
895
python/sensevoice/app/model.py
Normal file
@@ -0,0 +1,895 @@
|
||||
|
||||
import time
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from typing import Iterable, Optional
|
||||
|
||||
from funasr.register import tables
|
||||
from funasr.models.ctc.ctc import CTC
|
||||
from funasr.utils.datadir_writer import DatadirWriter
|
||||
from funasr.models.paraformer.search import Hypothesis
|
||||
from funasr.train_utils.device_funcs import force_gatherable
|
||||
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
|
||||
from funasr.metrics.compute_acc import compute_accuracy, th_accuracy
|
||||
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
|
||||
|
||||
|
||||
class SinusoidalPositionEncoder(torch.nn.Module):
|
||||
""" """
|
||||
|
||||
def __int__(self, d_model=80, dropout_rate=0.1):
|
||||
pass
|
||||
|
||||
def encode(
|
||||
self, positions: torch.Tensor = None, depth: int = None, dtype: torch.dtype = torch.float32
|
||||
):
|
||||
batch_size = positions.size(0)
|
||||
positions = positions.type(dtype)
|
||||
device = positions.device
|
||||
log_timescale_increment = torch.log(torch.tensor([10000], dtype=dtype, device=device)) / (
|
||||
depth / 2 - 1
|
||||
)
|
||||
inv_timescales = torch.exp(
|
||||
torch.arange(depth / 2, device=device).type(dtype) * (-log_timescale_increment)
|
||||
)
|
||||
inv_timescales = torch.reshape(inv_timescales, [batch_size, -1])
|
||||
scaled_time = torch.reshape(positions, [1, -1, 1]) * torch.reshape(
|
||||
inv_timescales, [1, 1, -1]
|
||||
)
|
||||
encoding = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=2)
|
||||
return encoding.type(dtype)
|
||||
|
||||
def forward(self, x):
|
||||
batch_size, timesteps, input_dim = x.size()
|
||||
positions = torch.arange(1, timesteps + 1, device=x.device)[None, :]
|
||||
position_encoding = self.encode(positions, input_dim, x.dtype).to(x.device)
|
||||
|
||||
return x + position_encoding
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
"""Positionwise feed forward layer.
|
||||
|
||||
Args:
|
||||
idim (int): Input dimenstion.
|
||||
hidden_units (int): The number of hidden units.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim, hidden_units, dropout_rate, activation=torch.nn.ReLU()):
|
||||
"""Construct an PositionwiseFeedForward object."""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.activation = activation
|
||||
|
||||
def forward(self, x):
|
||||
"""Forward function."""
|
||||
return self.w_2(self.dropout(self.activation(self.w_1(x))))
|
||||
|
||||
|
||||
class MultiHeadedAttentionSANM(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_head,
|
||||
in_feat,
|
||||
n_feat,
|
||||
dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit=0,
|
||||
lora_list=None,
|
||||
lora_rank=8,
|
||||
lora_alpha=16,
|
||||
lora_dropout=0.1,
|
||||
):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super().__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
# self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
# self.linear_k = nn.Linear(n_feat, n_feat)
|
||||
# self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.linear_q_k_v = nn.Linear(in_feat, n_feat * 3)
|
||||
self.attn = None
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
self.fsmn_block = nn.Conv1d(
|
||||
n_feat, n_feat, kernel_size, stride=1, padding=0, groups=n_feat, bias=False
|
||||
)
|
||||
# padding
|
||||
left_padding = (kernel_size - 1) // 2
|
||||
if sanm_shfit > 0:
|
||||
left_padding = left_padding + sanm_shfit
|
||||
right_padding = kernel_size - 1 - left_padding
|
||||
self.pad_fn = nn.ConstantPad1d((left_padding, right_padding), 0.0)
|
||||
|
||||
def forward_fsmn(self, inputs, mask, mask_shfit_chunk=None):
|
||||
b, t, d = inputs.size()
|
||||
if mask is not None:
|
||||
mask = torch.reshape(mask, (b, -1, 1))
|
||||
if mask_shfit_chunk is not None:
|
||||
mask = mask * mask_shfit_chunk
|
||||
inputs = inputs * mask
|
||||
|
||||
x = inputs.transpose(1, 2)
|
||||
x = self.pad_fn(x)
|
||||
x = self.fsmn_block(x)
|
||||
x = x.transpose(1, 2)
|
||||
x += inputs
|
||||
x = self.dropout(x)
|
||||
if mask is not None:
|
||||
x = x * mask
|
||||
return x
|
||||
|
||||
def forward_qkv(self, x):
|
||||
"""Transform query, key and value.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
|
||||
torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
|
||||
torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
|
||||
|
||||
"""
|
||||
b, t, d = x.size()
|
||||
q_k_v = self.linear_q_k_v(x)
|
||||
q, k, v = torch.split(q_k_v, int(self.h * self.d_k), dim=-1)
|
||||
q_h = torch.reshape(q, (b, t, self.h, self.d_k)).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time1, d_k)
|
||||
k_h = torch.reshape(k, (b, t, self.h, self.d_k)).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time2, d_k)
|
||||
v_h = torch.reshape(v, (b, t, self.h, self.d_k)).transpose(
|
||||
1, 2
|
||||
) # (batch, head, time2, d_k)
|
||||
|
||||
return q_h, k_h, v_h, v
|
||||
|
||||
def forward_attention(self, value, scores, mask, mask_att_chunk_encoder=None):
|
||||
"""Compute attention context vector.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
|
||||
scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
|
||||
mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed value (#batch, time1, d_model)
|
||||
weighted by the attention score (#batch, time1, time2).
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
if mask is not None:
|
||||
if mask_att_chunk_encoder is not None:
|
||||
mask = mask * mask_att_chunk_encoder
|
||||
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
|
||||
min_value = -float(
|
||||
"inf"
|
||||
) # float(numpy.finfo(torch.tensor(0, dtype=scores.dtype).numpy().dtype).min)
|
||||
scores = scores.masked_fill(mask, min_value)
|
||||
self.attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||
mask, 0.0
|
||||
) # (batch, head, time1, time2)
|
||||
else:
|
||||
self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(self.attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (
|
||||
x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(self, x, mask, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
q_h, k_h, v_h, v = self.forward_qkv(x)
|
||||
fsmn_memory = self.forward_fsmn(v, mask, mask_shfit_chunk)
|
||||
q_h = q_h * self.d_k ** (-0.5)
|
||||
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
|
||||
att_outs = self.forward_attention(v_h, scores, mask, mask_att_chunk_encoder)
|
||||
return att_outs + fsmn_memory
|
||||
|
||||
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
|
||||
"""
|
||||
q_h, k_h, v_h, v = self.forward_qkv(x)
|
||||
if chunk_size is not None and look_back > 0 or look_back == -1:
|
||||
if cache is not None:
|
||||
k_h_stride = k_h[:, :, : -(chunk_size[2]), :]
|
||||
v_h_stride = v_h[:, :, : -(chunk_size[2]), :]
|
||||
k_h = torch.cat((cache["k"], k_h), dim=2)
|
||||
v_h = torch.cat((cache["v"], v_h), dim=2)
|
||||
|
||||
cache["k"] = torch.cat((cache["k"], k_h_stride), dim=2)
|
||||
cache["v"] = torch.cat((cache["v"], v_h_stride), dim=2)
|
||||
if look_back != -1:
|
||||
cache["k"] = cache["k"][:, :, -(look_back * chunk_size[1]) :, :]
|
||||
cache["v"] = cache["v"][:, :, -(look_back * chunk_size[1]) :, :]
|
||||
else:
|
||||
cache_tmp = {
|
||||
"k": k_h[:, :, : -(chunk_size[2]), :],
|
||||
"v": v_h[:, :, : -(chunk_size[2]), :],
|
||||
}
|
||||
cache = cache_tmp
|
||||
fsmn_memory = self.forward_fsmn(v, None)
|
||||
q_h = q_h * self.d_k ** (-0.5)
|
||||
scores = torch.matmul(q_h, k_h.transpose(-2, -1))
|
||||
att_outs = self.forward_attention(v_h, scores, None)
|
||||
return att_outs + fsmn_memory, cache
|
||||
|
||||
|
||||
class LayerNorm(nn.LayerNorm):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def forward(self, input):
|
||||
output = F.layer_norm(
|
||||
input.float(),
|
||||
self.normalized_shape,
|
||||
self.weight.float() if self.weight is not None else None,
|
||||
self.bias.float() if self.bias is not None else None,
|
||||
self.eps,
|
||||
)
|
||||
return output.type_as(input)
|
||||
|
||||
|
||||
def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None):
|
||||
if maxlen is None:
|
||||
maxlen = lengths.max()
|
||||
row_vector = torch.arange(0, maxlen, 1).to(lengths.device)
|
||||
matrix = torch.unsqueeze(lengths, dim=-1)
|
||||
mask = row_vector < matrix
|
||||
mask = mask.detach()
|
||||
|
||||
return mask.type(dtype).to(device) if device is not None else mask.type(dtype)
|
||||
|
||||
|
||||
class EncoderLayerSANM(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_size,
|
||||
size,
|
||||
self_attn,
|
||||
feed_forward,
|
||||
dropout_rate,
|
||||
normalize_before=True,
|
||||
concat_after=False,
|
||||
stochastic_depth_rate=0.0,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super(EncoderLayerSANM, self).__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = LayerNorm(in_size)
|
||||
self.norm2 = LayerNorm(size)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.in_size = in_size
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
self.concat_after = concat_after
|
||||
if self.concat_after:
|
||||
self.concat_linear = nn.Linear(size + size, size)
|
||||
self.stochastic_depth_rate = stochastic_depth_rate
|
||||
self.dropout_rate = dropout_rate
|
||||
|
||||
def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
skip_layer = False
|
||||
# with stochastic depth, residual connection `x + f(x)` becomes
|
||||
# `x <- x + 1 / (1 - p) * f(x)` at training time.
|
||||
stoch_layer_coeff = 1.0
|
||||
if self.training and self.stochastic_depth_rate > 0:
|
||||
skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
|
||||
stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
|
||||
|
||||
if skip_layer:
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
return x, mask
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if self.concat_after:
|
||||
x_concat = torch.cat(
|
||||
(
|
||||
x,
|
||||
self.self_attn(
|
||||
x,
|
||||
mask,
|
||||
mask_shfit_chunk=mask_shfit_chunk,
|
||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
||||
),
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.concat_linear(x_concat)
|
||||
else:
|
||||
if self.in_size == self.size:
|
||||
x = residual + stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(
|
||||
x,
|
||||
mask,
|
||||
mask_shfit_chunk=mask_shfit_chunk,
|
||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
||||
)
|
||||
)
|
||||
else:
|
||||
x = stoch_layer_coeff * self.dropout(
|
||||
self.self_attn(
|
||||
x,
|
||||
mask,
|
||||
mask_shfit_chunk=mask_shfit_chunk,
|
||||
mask_att_chunk_encoder=mask_att_chunk_encoder,
|
||||
)
|
||||
)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
|
||||
|
||||
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x_input (torch.Tensor): Input tensor (#batch, time, size).
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time).
|
||||
cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time).
|
||||
|
||||
"""
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if self.in_size == self.size:
|
||||
attn, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
||||
x = residual + attn
|
||||
else:
|
||||
x, cache = self.self_attn.forward_chunk(x, cache, chunk_size, look_back)
|
||||
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + self.feed_forward(x)
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
return x, cache
|
||||
|
||||
|
||||
@tables.register("encoder_classes", "SenseVoiceEncoderSmall")
|
||||
class SenseVoiceEncoderSmall(nn.Module):
|
||||
"""
|
||||
Author: Speech Lab of DAMO Academy, Alibaba Group
|
||||
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
|
||||
https://arxiv.org/abs/2006.01713
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
tp_blocks: int = 0,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
stochastic_depth_rate: float = 0.0,
|
||||
input_layer: Optional[str] = "conv2d",
|
||||
pos_enc_class=SinusoidalPositionEncoder,
|
||||
normalize_before: bool = True,
|
||||
concat_after: bool = False,
|
||||
positionwise_layer_type: str = "linear",
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
padding_idx: int = -1,
|
||||
kernel_size: int = 11,
|
||||
sanm_shfit: int = 0,
|
||||
selfattention_layer_type: str = "sanm",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
self.embed = SinusoidalPositionEncoder()
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
positionwise_layer = PositionwiseFeedForward
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
)
|
||||
|
||||
encoder_selfattn_layer = MultiHeadedAttentionSANM
|
||||
encoder_selfattn_layer_args0 = (
|
||||
attention_heads,
|
||||
input_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
kernel_size,
|
||||
sanm_shfit,
|
||||
)
|
||||
|
||||
self.encoders0 = nn.ModuleList(
|
||||
[
|
||||
EncoderLayerSANM(
|
||||
input_size,
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args0),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
)
|
||||
for i in range(1)
|
||||
]
|
||||
)
|
||||
self.encoders = nn.ModuleList(
|
||||
[
|
||||
EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
)
|
||||
for i in range(num_blocks - 1)
|
||||
]
|
||||
)
|
||||
|
||||
self.tp_encoders = nn.ModuleList(
|
||||
[
|
||||
EncoderLayerSANM(
|
||||
output_size,
|
||||
output_size,
|
||||
encoder_selfattn_layer(*encoder_selfattn_layer_args),
|
||||
positionwise_layer(*positionwise_layer_args),
|
||||
dropout_rate,
|
||||
)
|
||||
for i in range(tp_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.after_norm = LayerNorm(output_size)
|
||||
|
||||
self.tp_norm = LayerNorm(output_size)
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs_pad: torch.Tensor,
|
||||
ilens: torch.Tensor,
|
||||
):
|
||||
"""Embed positions in tensor."""
|
||||
masks = sequence_mask(ilens, device=ilens.device)[:, None, :]
|
||||
|
||||
xs_pad *= self.output_size() ** 0.5
|
||||
|
||||
xs_pad = self.embed(xs_pad)
|
||||
|
||||
# forward encoder1
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders0):
|
||||
encoder_outs = encoder_layer(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
for layer_idx, encoder_layer in enumerate(self.encoders):
|
||||
encoder_outs = encoder_layer(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
xs_pad = self.after_norm(xs_pad)
|
||||
|
||||
# forward encoder2
|
||||
olens = masks.squeeze(1).sum(1).int()
|
||||
|
||||
for layer_idx, encoder_layer in enumerate(self.tp_encoders):
|
||||
encoder_outs = encoder_layer(xs_pad, masks)
|
||||
xs_pad, masks = encoder_outs[0], encoder_outs[1]
|
||||
|
||||
xs_pad = self.tp_norm(xs_pad)
|
||||
return xs_pad, olens
|
||||
|
||||
|
||||
@tables.register("model_classes", "SenseVoiceSmall")
|
||||
class SenseVoiceSmall(nn.Module):
|
||||
"""CTC-attention hybrid Encoder-Decoder model"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
specaug: str = None,
|
||||
specaug_conf: dict = None,
|
||||
normalize: str = None,
|
||||
normalize_conf: dict = None,
|
||||
encoder: str = None,
|
||||
encoder_conf: dict = None,
|
||||
ctc_conf: dict = None,
|
||||
input_size: int = 80,
|
||||
vocab_size: int = -1,
|
||||
ignore_id: int = -1,
|
||||
blank_id: int = 0,
|
||||
sos: int = 1,
|
||||
eos: int = 2,
|
||||
length_normalized_loss: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
|
||||
if specaug is not None:
|
||||
specaug_class = tables.specaug_classes.get(specaug)
|
||||
specaug = specaug_class(**specaug_conf)
|
||||
if normalize is not None:
|
||||
normalize_class = tables.normalize_classes.get(normalize)
|
||||
normalize = normalize_class(**normalize_conf)
|
||||
encoder_class = tables.encoder_classes.get(encoder)
|
||||
encoder = encoder_class(input_size=input_size, **encoder_conf)
|
||||
encoder_output_size = encoder.output_size()
|
||||
|
||||
if ctc_conf is None:
|
||||
ctc_conf = {}
|
||||
ctc = CTC(odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf)
|
||||
|
||||
self.blank_id = blank_id
|
||||
self.sos = sos if sos is not None else vocab_size - 1
|
||||
self.eos = eos if eos is not None else vocab_size - 1
|
||||
self.vocab_size = vocab_size
|
||||
self.ignore_id = ignore_id
|
||||
self.specaug = specaug
|
||||
self.normalize = normalize
|
||||
self.encoder = encoder
|
||||
self.error_calculator = None
|
||||
|
||||
self.ctc = ctc
|
||||
|
||||
self.length_normalized_loss = length_normalized_loss
|
||||
self.encoder_output_size = encoder_output_size
|
||||
|
||||
self.lid_dict = {"auto": 0, "zh": 3, "en": 4, "yue": 7, "ja": 11, "ko": 12, "nospeech": 13}
|
||||
self.lid_int_dict = {24884: 3, 24885: 4, 24888: 7, 24892: 11, 24896: 12, 24992: 13}
|
||||
self.textnorm_dict = {"withitn": 14, "woitn": 15}
|
||||
self.textnorm_int_dict = {25016: 14, 25017: 15}
|
||||
self.embed = torch.nn.Embedding(7 + len(self.lid_dict) + len(self.textnorm_dict), input_size)
|
||||
self.emo_dict = {"unk": 25009, "happy": 25001, "sad": 25002, "angry": 25003, "neutral": 25004}
|
||||
|
||||
self.criterion_att = LabelSmoothingLoss(
|
||||
size=self.vocab_size,
|
||||
padding_idx=self.ignore_id,
|
||||
smoothing=kwargs.get("lsm_weight", 0.0),
|
||||
normalize_length=self.length_normalized_loss,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def from_pretrained(model:str=None, **kwargs):
|
||||
from funasr import AutoModel
|
||||
model, kwargs = AutoModel.build_model(model=model, trust_remote_code=True, **kwargs)
|
||||
|
||||
return model, kwargs
|
||||
|
||||
def forward(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""Encoder + Decoder + Calc loss
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
text: (Batch, Length)
|
||||
text_lengths: (Batch,)
|
||||
"""
|
||||
# import pdb;
|
||||
# pdb.set_trace()
|
||||
if len(text_lengths.size()) > 1:
|
||||
text_lengths = text_lengths[:, 0]
|
||||
if len(speech_lengths.size()) > 1:
|
||||
speech_lengths = speech_lengths[:, 0]
|
||||
|
||||
batch_size = speech.shape[0]
|
||||
|
||||
# 1. Encoder
|
||||
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text)
|
||||
|
||||
loss_ctc, cer_ctc = None, None
|
||||
loss_rich, acc_rich = None, None
|
||||
stats = dict()
|
||||
|
||||
loss_ctc, cer_ctc = self._calc_ctc_loss(
|
||||
encoder_out[:, 4:, :], encoder_out_lens - 4, text[:, 4:], text_lengths - 4
|
||||
)
|
||||
|
||||
loss_rich, acc_rich = self._calc_rich_ce_loss(
|
||||
encoder_out[:, :4, :], text[:, :4]
|
||||
)
|
||||
|
||||
loss = loss_ctc
|
||||
# Collect total loss stats
|
||||
stats["loss"] = torch.clone(loss.detach()) if loss_ctc is not None else None
|
||||
stats["loss_rich"] = torch.clone(loss_rich.detach()) if loss_rich is not None else None
|
||||
stats["acc_rich"] = acc_rich
|
||||
|
||||
# force_gatherable: to-device and to-tensor if scalar for DataParallel
|
||||
if self.length_normalized_loss:
|
||||
batch_size = int((text_lengths + 1).sum())
|
||||
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
|
||||
return loss, stats, weight
|
||||
|
||||
def encode(
|
||||
self,
|
||||
speech: torch.Tensor,
|
||||
speech_lengths: torch.Tensor,
|
||||
text: torch.Tensor,
|
||||
**kwargs,
|
||||
):
|
||||
"""Frontend + Encoder. Note that this method is used by asr_inference.py
|
||||
Args:
|
||||
speech: (Batch, Length, ...)
|
||||
speech_lengths: (Batch, )
|
||||
ind: int
|
||||
"""
|
||||
|
||||
# Data augmentation
|
||||
if self.specaug is not None and self.training:
|
||||
speech, speech_lengths = self.specaug(speech, speech_lengths)
|
||||
|
||||
# Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
|
||||
if self.normalize is not None:
|
||||
speech, speech_lengths = self.normalize(speech, speech_lengths)
|
||||
|
||||
|
||||
lids = torch.LongTensor([[self.lid_int_dict[int(lid)] if torch.rand(1) > 0.2 and int(lid) in self.lid_int_dict else 0 ] for lid in text[:, 0]]).to(speech.device)
|
||||
language_query = self.embed(lids)
|
||||
|
||||
styles = torch.LongTensor([[self.textnorm_int_dict[int(style)]] for style in text[:, 3]]).to(speech.device)
|
||||
style_query = self.embed(styles)
|
||||
speech = torch.cat((style_query, speech), dim=1)
|
||||
speech_lengths += 1
|
||||
|
||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(speech.size(0), 1, 1)
|
||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
||||
speech = torch.cat((input_query, speech), dim=1)
|
||||
speech_lengths += 3
|
||||
|
||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
|
||||
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def _calc_ctc_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
encoder_out_lens: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
ys_pad_lens: torch.Tensor,
|
||||
):
|
||||
# Calc CTC loss
|
||||
loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
|
||||
|
||||
# Calc CER using CTC
|
||||
cer_ctc = None
|
||||
if not self.training and self.error_calculator is not None:
|
||||
ys_hat = self.ctc.argmax(encoder_out).data
|
||||
cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
|
||||
return loss_ctc, cer_ctc
|
||||
|
||||
def _calc_rich_ce_loss(
|
||||
self,
|
||||
encoder_out: torch.Tensor,
|
||||
ys_pad: torch.Tensor,
|
||||
):
|
||||
decoder_out = self.ctc.ctc_lo(encoder_out)
|
||||
# 2. Compute attention loss
|
||||
loss_rich = self.criterion_att(decoder_out, ys_pad.contiguous())
|
||||
acc_rich = th_accuracy(
|
||||
decoder_out.view(-1, self.vocab_size),
|
||||
ys_pad.contiguous(),
|
||||
ignore_label=self.ignore_id,
|
||||
)
|
||||
|
||||
return loss_rich, acc_rich
|
||||
|
||||
|
||||
def inference(
|
||||
self,
|
||||
data_in,
|
||||
data_lengths=None,
|
||||
key: list = ["wav_file_tmp_name"],
|
||||
tokenizer=None,
|
||||
frontend=None,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
|
||||
meta_data = {}
|
||||
if (
|
||||
isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
|
||||
): # fbank
|
||||
speech, speech_lengths = data_in, data_lengths
|
||||
if len(speech.shape) < 3:
|
||||
speech = speech[None, :, :]
|
||||
if speech_lengths is None:
|
||||
speech_lengths = speech.shape[1]
|
||||
else:
|
||||
# extract fbank feats
|
||||
time1 = time.perf_counter()
|
||||
audio_sample_list = load_audio_text_image_video(
|
||||
data_in,
|
||||
fs=frontend.fs,
|
||||
audio_fs=kwargs.get("fs", 16000),
|
||||
data_type=kwargs.get("data_type", "sound"),
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
time2 = time.perf_counter()
|
||||
meta_data["load_data"] = f"{time2 - time1:0.3f}"
|
||||
speech, speech_lengths = extract_fbank(
|
||||
audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
|
||||
)
|
||||
time3 = time.perf_counter()
|
||||
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
|
||||
meta_data["batch_data_time"] = (
|
||||
speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
|
||||
)
|
||||
|
||||
speech = speech.to(device=kwargs["device"])
|
||||
speech_lengths = speech_lengths.to(device=kwargs["device"])
|
||||
|
||||
language = kwargs.get("language", "auto")
|
||||
language_query = self.embed(
|
||||
torch.LongTensor(
|
||||
[[self.lid_dict[language] if language in self.lid_dict else 0]]
|
||||
).to(speech.device)
|
||||
).repeat(speech.size(0), 1, 1)
|
||||
|
||||
use_itn = kwargs.get("use_itn", False)
|
||||
textnorm = kwargs.get("text_norm", None)
|
||||
if textnorm is None:
|
||||
textnorm = "withitn" if use_itn else "woitn"
|
||||
textnorm_query = self.embed(
|
||||
torch.LongTensor([[self.textnorm_dict[textnorm]]]).to(speech.device)
|
||||
).repeat(speech.size(0), 1, 1)
|
||||
speech = torch.cat((textnorm_query, speech), dim=1)
|
||||
speech_lengths += 1
|
||||
|
||||
event_emo_query = self.embed(torch.LongTensor([[1, 2]]).to(speech.device)).repeat(
|
||||
speech.size(0), 1, 1
|
||||
)
|
||||
input_query = torch.cat((language_query, event_emo_query), dim=1)
|
||||
speech = torch.cat((input_query, speech), dim=1)
|
||||
speech_lengths += 3
|
||||
|
||||
# Encoder
|
||||
encoder_out, encoder_out_lens = self.encoder(speech, speech_lengths)
|
||||
if isinstance(encoder_out, tuple):
|
||||
encoder_out = encoder_out[0]
|
||||
|
||||
# c. Passed the encoder result and the beam search
|
||||
ctc_logits = self.ctc.log_softmax(encoder_out)
|
||||
if kwargs.get("ban_emo_unk", False):
|
||||
ctc_logits[:, :, self.emo_dict["unk"]] = -float("inf")
|
||||
|
||||
results = []
|
||||
b, n, d = encoder_out.size()
|
||||
if isinstance(key[0], (list, tuple)):
|
||||
key = key[0]
|
||||
if len(key) < b:
|
||||
key = key * b
|
||||
for i in range(b):
|
||||
x = ctc_logits[i, : encoder_out_lens[i].item(), :]
|
||||
yseq = x.argmax(dim=-1)
|
||||
yseq = torch.unique_consecutive(yseq, dim=-1)
|
||||
|
||||
ibest_writer = None
|
||||
if kwargs.get("output_dir") is not None:
|
||||
if not hasattr(self, "writer"):
|
||||
self.writer = DatadirWriter(kwargs.get("output_dir"))
|
||||
ibest_writer = self.writer[f"1best_recog"]
|
||||
|
||||
mask = yseq != self.blank_id
|
||||
token_int = yseq[mask].tolist()
|
||||
|
||||
# Change integer-ids to tokens
|
||||
text = tokenizer.decode(token_int)
|
||||
|
||||
result_i = {"key": key[i], "text": text}
|
||||
results.append(result_i)
|
||||
|
||||
if ibest_writer is not None:
|
||||
ibest_writer["text"][key[i]] = text
|
||||
|
||||
return results, meta_data
|
||||
|
||||
def export(self, **kwargs):
|
||||
from export_meta import export_rebuild_model
|
||||
|
||||
if "max_seq_len" not in kwargs:
|
||||
kwargs["max_seq_len"] = 512
|
||||
models = export_rebuild_model(model=self, **kwargs)
|
||||
return models
|
Reference in New Issue
Block a user