
    hH~                     (   d dl Z d dlmZ d dlmZmZmZ d dlZd dlmZ ddl	m
Z
 ddlmZ ddlmZmZ dd	lmZmZ dd
lmZ ddlmZmZmZmZ ddlmZ ddlmZ ddlmZm Z   G d dejB                        Z" G d dejB                        Z# G d dejB                        Z$dejJ                  de&dejJ                  fdZ'	 d5dejB                  dejJ                  dejJ                  dejJ                  deejJ                     de(d e(d!ee   fd"Z) G d# d$ejB                        Z* G d% d&ejB                        Z+ G d' d(e      Z,e G d) d*e             Z- ed+,       G d- d.e-             Z.e G d/ d0e             Z/ ed1,       G d2 d3e-             Z0g d4Z1y)6    N)	dataclass)CallableOptionalUnion)nn   )ACT2FN)GradientCheckpointingLayer)BaseModelOutputCausalLMOutput)ALL_ATTENTION_FUNCTIONSPreTrainedModel)Unpack)ModelOutputTransformersKwargsauto_docstringcan_return_tuple)deprecate_kwarg)check_model_inputs   )ParakeetCTCConfigParakeetEncoderConfigc                        e Zd ZU dZej
                  ed<   ddef fdZ ej                         dej
                  fd       Z
 xZS )$ParakeetEncoderRelPositionalEncodingz*Relative positional encoding for Parakeet.inv_freqconfigc                 6   t         |           |j                  | _        d}d|t        j                  d|j
                  dt        j                        j                  |t        j                        |j
                  z  z  z  }| j                  d|d	       y )
Ng     @      ?r      dtype)devicer!   r   F)
persistent)
super__init__max_position_embeddingstorcharangehidden_sizeint64tofloatregister_buffer)selfr   r"   baser   	__class__s        m/var/www/html/aiagenthome/venv/lib/python3.12/site-packages/transformers/models/parakeet/modeling_parakeet.pyr%   z-ParakeetEncoderRelPositionalEncoding.__init__-   s    '-'E'E$Q 2 2AU[[ILLTZbgbmbmLn$$%
 	ZeD    hidden_statesc                    |j                   d   }|| j                  kD  rt        d| d| j                   d      t        j                  |dz
  | d|j
                        }| j                  d d d d f   j                         j                  |j                   d   dd      j                  |j
                        }|d d d d f   j                         }t        |j
                  j                  t              r/|j
                  j                  dk7  r|j
                  j                  nd	}t        j                  |d
      5  |j                         |j                         z  j                  dd      }|j                         }|j!                         }	t        j"                  ||	gd      }
 |
j$                  g |
j                   d d d }
d d d        
j                  |j&                        S # 1 sw Y   %xY w)Nr   zSequence Length: z= has to be less or equal than config.max_position_embeddings .r"   r   mpscpuF)device_typeenabledr   dimr    )shaper&   
ValueErrorr'   r(   r"   r   r,   expandr+   
isinstancetypestrautocast	transposesincosstackreshaper!   )r.   r3   
seq_lengthposition_idsinv_freq_expandedposition_ids_expandedr:   freqsrG   rH   	pos_embeds              r1   forwardz,ParakeetEncoderRelPositionalEncoding.forward;   s   "((+
444#J< 02262N2N1OqR 
 ||JNZKML`L`aMM$4-(..0778K8KA8NPRTUVYYZgZnZno 	 !-T4] ; A A C -..33S9m>R>R>W>W[`>`   %% 	
 ^^UC&,,.1F1L1L1NNYYZ[]^_E))+C))+CS#JB7I)	))D9??3B+?DDI D ||-"5"5|66 DCs   ?BG//G8N)__name__
__module____qualname____doc__r'   Tensor__annotations__r   r%   no_gradrQ   __classcell__r0   s   @r1   r   r   (   sF    4llE4 E U]]_7U\\ 7 7r2   r   c                   *     e Zd Zdef fdZd Z xZS )ParakeetEncoderFeedForwardr   c                 `   t         |           t        j                  |j                  |j
                  |j                        | _        t        |j                     | _
        t        j                  |j
                  |j                  |j                        | _        |j                  | _        y )Nbias)r$   r%   r   Linearr)   intermediate_sizeattention_biaslinear1r	   
hidden_act
activationlinear2activation_dropoutr.   r   r0   s     r1   r%   z#ParakeetEncoderFeedForward.__init__[   s|    yy!3!3V5M5MTZTiTij !2!23yy!9!96;M;MTZTiTij"(";";r2   c                     | j                  | j                  |            }t        j                  j	                  || j
                  | j                        }| j                  |      }|S )Nptraining)rf   rd   r   
functionaldropoutrh   rm   rg   )r.   r3   s     r1   rQ   z"ParakeetEncoderFeedForward.forwardb   sU    ](CD--mt?V?Vaeanan-o]3r2   rS   rT   rU   r   r%   rQ   rZ   r[   s   @r1   r]   r]   Z   s    <4 <r2   r]   c                   .     e Zd Zddef fdZddZ xZS ) ParakeetEncoderConvolutionModuler   c           	         t         |           |j                  }|&|j                  }t        t        |dd         | _        n#|d   }t        |j                  dd         | _        |dz
  dz  | _        t        j                  |d|z  dddd	
      | _        t        j                  |||d| j                  |d	      | _        t        j                  |      | _        t        j                  ||dddd	
      | _        y)z
        Args:
            config (ParakeetEncoderConfig): Configuration for the model.
            module_config (dict): Configuration for the module (e.g., encoder or decoder).
        Nre   silukernel_sizerf   r   r   r   T)ru   stridepaddingr`   )rv   rw   groupsr`   )r$   r%   r)   conv_kernel_sizer	   getattrrf   getrw   r   Conv1dpointwise_conv1depthwise_convBatchNorm1dnormpointwise_conv2)r.   r   module_configchannelsru   r0   s        r1   r%   z)ParakeetEncoderConvolutionModule.__init__j   s     	%%  11K$WV\6%JKDO'6K$]%6%6|V%LMDO#aA-!yy1x<QWXbcjno iihAt||T\cg
 NN8,	!yy8ST^_fjkr2   c                    |j                  dd      }| j                  |      }t        j                  j	                  |d      }|*t        j                  | d      }|j                  |d      }| j                  |      }| j                  |      }| j                  |      }| j                  |      }|j                  dd      S )aS  
        Compute convolution module.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.

        Returns:
            `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.

        r   r   r<   r6           )rF   r}   r   rn   glur'   allmasked_fillr~   r   rf   r   )r.   r3   attention_maskall_masked_rowss       r1   rQ   z(ParakeetEncoderConvolutionModule.forward   s     &//15 ,,];))-Q)? %#iiR@O)55osKM ++M:		-06,,];&&q!,,r2   rR   rp   r[   s   @r1   rr   rr   i   s    l4 l0-r2   rr   r3   n_repreturnc                     | j                   \  }}}}|dk(  r| S | dddddddddf   j                  |||||      } | j                  |||z  ||      S )z
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    r   N)r?   rA   rJ   )r3   r   batchnum_key_value_headsslenhead_dims         r1   	repeat_kvr      so    
 2?1D1D.Ehz!!Qa"23::5BUW\^bdlmM  (;e(CT8TTr2   modulequerykeyvaluer   scalingro   kwargsc                 T   t        || j                        }t        || j                        }	t        j                  ||j	                  dd            |z  }
|#|d d d d d d d |j
                  d   f   }|
|z   }
t        j                  j                  |
dt        j                        j                  |j                        }
t        j                  j                  |
|| j                        }
t        j                  |
|	      }|j	                  dd      j                         }||
fS )Nr   r   r>   r6   r=   r!   rk   r   )r   num_key_value_groupsr'   matmulrF   r?   r   rn   softmaxfloat32r+   r!   ro   rm   
contiguous)r   r   r   r   r   r   ro   r   
key_statesvalue_statesattn_weightscausal_maskattn_outputs                r1   eager_attention_forwardr      s    3 ; ;<JUF$?$?@L<<z';';Aq'ABWLL!$Q1.D
0@0@0D.D%DE#k1==((2U]](SVVW\WbWbcL==((6??([L,,|\:K''1-88:K$$r2   c                        e Zd ZdZdedef fdZ eddd      	 dd	ej                  d
e
ej                     de
ej                     dee   deej                  ej                  f   f
d       Zd Z xZS )ParakeetEncoderAttentionztMulti-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860.r   	layer_idxc                    t         |           || _        || _        t	        |d|j
                  |j                  z        | _        |j                  |j                  z  | _	        | j                  dz  | _
        |j                  | _        d| _        t        j                  |j
                  |j                  | j                  z  |j                        | _        t        j                  |j
                  |j                  | j                  z  |j                        | _        t        j                  |j
                  |j                  | j                  z  |j                        | _        t        j                  |j                  | j                  z  |j
                  |j                        | _        t        j                  |j
                  |j                  | j                  z  d      | _        t        j*                  t-        j.                  |j                  | j                              | _        t        j*                  t-        j.                  |j                  | j                              | _        y )Nr   g      Fr_   )r$   r%   r   r   rz   r)   num_attention_headsr   r   r   r   attention_dropout	is_causalr   ra   rc   q_projk_projv_projo_projrelative_k_proj	Parameterr'   zerosbias_ubias_vr.   r   r   r0   s      r1   r%   z!ParakeetEncoderAttention.__init__   s   "
F4F4F&JdJd4de$*$>$>&B\B\$\!}}d*!'!9!9ii : :T]] JQWQfQf
 ii : :T]] JQWQfQf
 ii : :T]] JQWQfQf
 ii&&68J8JQWQfQf
  "yy););V=W=WZ^ZgZg=gnstll5;;v/I/I4==#YZll5;;v/I/I4==#YZr2   past_key_valuepast_key_valuesz4.58)new_nameversionr3   position_embeddingsr   r   r   c           
         |j                   d d }|\  }}||d| j                  f}| j                  |      j                  |      j	                  dd      }	| j                  |      j                  |      j	                  dd      }
| j                  |      j                  |      j	                  dd      }t        }| j                  j                  dk7  rt        | j                  j                     }|	| j                  j                  d| j                  j                  d| j                        z   }|	| j                  j                  d| j                  j                  d| j                        z   }| j                  |      }|j                  |d| j                  j                  | j                        }||j                  dddd      z  }| j!                  |      }|dd |f   }|| j"                  z  }|)|j%                  |j'                         t)        d            } || f||
||| j*                  sd	n| j,                  | j"                  d
|\  }} |j.                  g |d j1                         }| j3                  |      }||fS )Nr6   r   r   eagerr   r   .z-infr   )r   r   r   r   ro   r   )r?   r   r   viewrF   r   r   r   r   _attn_implementationr   r   r   r   r   permute
_rel_shiftr   masked_fill_logical_notr,   rm   r   rJ   r   r   )r.   r3   r   r   r   input_shape
batch_sizerK   hidden_shapequery_statesr   r   attention_interfacequery_states_with_bias_uquery_states_with_bias_vrelative_key_states	matrix_bdr   r   s                      r1   rQ   z ParakeetEncoderAttention.forward   st    $))#2.!,
J"JDMMB{{=166|DNNqRST[[/44\BLLQPQR
{{=166|DNNqRST(?;;++w6"9$++:Z:Z"[#/$++2B2Bt{{..4==3
 $
  $0$++2B2Bt{{..4==3
 $
  #223FG166z2t{{GfGfhlhuhuv -/B/J/J1aQRTU/VV	OOI.	c;J;./	,	% "..~/I/I/KUSY][I %8	%
*$#}}C$2H2HLL	%
 	%
!\ *k));;;;FFHkk+.L((r2   c                     |j                   \  }}}}t        j                  j                  |d      }|j	                  ||d|      }|ddddddf   j	                  ||||      }|S )ztRelative position shift for Shaw et al. style attention. See appendix B of https://huggingface.co/papers/1901.02860.)r   r   )padr6   Nr   )r?   r   rn   r   r   )r.   attention_scoresr   	num_headsquery_lengthposition_lengths         r1   r   z#ParakeetEncoderAttention._rel_shift$  st    ?O?U?U<
I|_==,,-=6,J+00YLY+Aq!"H5:::yR^`opr2   rR   )rS   rT   rU   rV   r   intr%   r   r'   rW   r   r   r   tuplerQ   r   rZ   r[   s   @r1   r   r      s    ~[4 [ [: %0A6R
 26	7)||7) &ell37) !.	7)
 +,7) 
u||U\\)	*7) S7)r r2   r   c                        e Zd Zdef fdZdej                  dej                  fdZ	d	dej                  dej                  fdZ
 xZS )
 ParakeetEncoderSubsamplingConv2Dr   c                    t         |           |j                  | _        |j                  | _        |j                  | _        | j                  dz
  dz  | _        t        t        j                  |j                              | _        t        j                         | _        | j                   j#                  t        j$                  d| j                  | j                  | j
                  | j                               | j                   j#                  t        j&                                t)        | j                  dz
        D ]  }| j                   j#                  t        j$                  | j                  | j                  | j                  | j
                  | j                  | j                               | j                   j#                  t        j$                  | j                  | j                  d             | j                   j#                  t        j&                                 |j*                  | j
                  | j                  z  z  }t        j,                  |j                  |z  |j.                  d      | _        y )Nr   r   )ru   rv   rw   )ru   rv   rw   rx   ru   Tr_   )r$   r%   subsampling_conv_kernel_sizeru   subsampling_conv_striderv   subsampling_conv_channelsr   rw   r   mathlog2subsampling_factor
num_layersr   
ModuleListlayersappendConv2dReLUrangenum_mel_binsra   r)   linear)r.   r   i
out_lengthr0   s       r1   r%   z)ParakeetEncoderSubsamplingConv2D.__init__.  s   !>>4488((1,2dii(A(ABC mmoIIaD4D4DT[[bfbnbno	
 	2779%t*+AKK		MMMM $ 0 0;; LL==	 KKryySTUVKKrwwy) ," ((T[[$//-IJ
ii @ @: MvOaOahlmr2   input_lengths
conv_layerc                     t        |d      rR|j                  dk7  rC|j                  }|j                  d   }|j                  d   }||d   z   |d   z   |z
  |z  dz   }|S |S )Nrv   )r   r   r   r   )hasattrrv   rw   ru   )r.   r   r   rw   ru   rv   output_lengthss          r1   _get_output_lengthz3ParakeetEncoderSubsamplingConv2D._get_output_lengthQ  sx    :x(Z->->&-H ((G$003K&&q)F+gaj871:ESX^^abbN!!r2   input_featuresr   c                    |j                  d      }||j                  d      nd }| j                  D ]  } ||      }t        |t        j
                        s&|)| j                  ||      }|j                  d   }t        j                  ||j                        |d d d f   k  }||d d d d d d f   z  } |j                  dd      j                  |j                  d   |j                  d   d      }| j                  |      }|S )Nr   r6   r   r7   r   )	unsqueezesumr   rB   r   r   r   r?   r'   r(   r"   rF   rJ   r   )r.   r   r   r3   current_lengthslayercurrent_seq_lengthchannel_masks           r1   rQ   z(ParakeetEncoderSubsamplingConv2D.forward\  s   &0034B4N.,,R0TX[[E!-0M %+0J"&"9"9/5"Q%2%8%8%;"LL!3N<Q<QRUdefhlelUmm  aq$.>!?? ! &//15==m>Q>QRS>TVcViVijkVlnpqM2r2   rR   )rS   rT   rU   r   r%   r'   rW   r   r   r   rQ   rZ   r[   s   @r1   r   r   -  sI    !n4 !nF	 	")) 	ell ELL r2   r   c                        e Zd Zd
dedee   f fdZ	 	 ddej                  deej                     deej                     de	e
   dej                  f
d	Z xZS )ParakeetEncoderBlockr   r   c                    t         |           d| _        t        |      | _        t        ||      | _        t        |      | _        t        |      | _	        t        j                  |j                        | _        t        j                  |j                        | _        t        j                  |j                        | _        t        j                  |j                        | _        t        j                  |j                        | _        y NF)r$   r%   gradient_checkpointingr]   feed_forward1r   	self_attnrr   convfeed_forward2r   	LayerNormr)   norm_feed_forward1norm_self_att	norm_convnorm_feed_forward2norm_outr   s      r1   r%   zParakeetEncoderBlock.__init__s  s    &+#7?1&)D4V<	7?"$,,v/A/A"B\\&*<*<=f&8&89"$,,v/A/A"BV%7%78r2   r3   r   r   r   r   c                 x   |}| j                  | j                  |            }|d|z  z   }| j                  |      } | j                  d|||d|\  }}||z   }| j	                  | j                  |      |      }	||	z   }| j                  | j                  |            }
|d|
z  z   }| j                  |      }|S )Ng      ?)r3   r   r   )r    )	r   r  r  r   r   r  r   r  r  )r.   r3   r   r   r   residualnormalized_hidden_statesr   _conv_output
ff2_outputs              r1   rQ   zParakeetEncoderBlock.forward  s     !**4+B+B=+QR 3#66#'#5#5m#D ' 
2) 3
 	
Q &3ii} =ni]%3''(?(?(NO
%j(88m4r2   rR   NN)rS   rT   rU   r   r   r   r%   r'   rW   r   r   rQ   rZ   r[   s   @r1   r   r   r  sx    94 9# 9$ 266:	|| !. &ell3	
 +, 
r2   r   c                        e Zd ZU eed<   dZdZdZdgZdZ	dZ
dZdZdZdZeedZ fdZd	ej(                  fd
Zddej(                  dee   fdZ xZS )ParakeetPreTrainedModelr   modelr   Tr   F)r3   
attentionsc                    t         |   |       t        | j                  d      r| j                  j                  }n%t        | j                  j                         dd      }t        |t              rO|j                  j                  j                  d|       |j                  j                  j                  d|       y y )Ninitializer_rangeg{Gz?r   )meanstd)r$   _init_weightsr   r   r  rz   get_text_configrB   r   r   datanormal_r   )r.   r   r  r0   s      r1   r  z%ParakeetPreTrainedModel._init_weights  s    f%4;; 34++//C $++5579LdSCf67MM&&CS&9MM&&CS&9 8r2   r   c                    t        | j                  t              r| j                  j                  n| j                  }|j                  }|j
                  }t        t        j                  |j                              }|dz
  dz  dz  }||z
  }|}t        |      D ]Q  }	t        j                  |j                  t        j                        |z   |      dz   }t        j                  |      }S |j                  t        j                        S )Nr   r   r    r   )rB   r   r   encoder_configr   r   r   r   r   r   r   r'   divr+   r,   floor)
r.   r   r  ru   rv   r   all_paddingsadd_padlengthsr  s
             r1   _get_subsampling_output_lengthz6ParakeetPreTrainedModel._get_subsampling_output_length  s    7A$++O`7a33gkgrgr$AA77>#D#DEF
#aA-1,z"Aii


 = GPSVVGkk'*G # zz		z**r2   r   target_lengthc                     | j                  |j                  d            }||n|j                         }t        j                  ||j
                        |dddf   k  }|S )z
        Convert the input attention mask to its subsampled form. `target_length` sets the desired output length, useful
        when the attention mask length differs from `sum(-1).max()` (i.e., when the longest sequence in the batch is padded)
        r6   Nr7   )r"  r   maxr'   r(   r"   )r.   r   r#  r   
max_lengths        r1   _get_output_attention_maskz2ParakeetPreTrainedModel._get_output_attention_mask  sc    
 <<^=O=OPR=ST&3&?]^EWEWEY
j9N9NOR`abdhahRiir2   rR   )rS   rT   rU   r   rX   base_model_prefixmain_input_namesupports_gradient_checkpointing_no_split_modules_supports_flat_attention_mask_supports_sdpa_supports_flex_attn_supports_flash_attn_can_compile_fullgraph_supports_attention_backendr   r   _can_record_outputsr  r'   rW   r"  r   r   r'  rZ   r[   s   @r1   r  r    s    &O&*#/0$(!N !!"&-.
:+ELL +"	 	V^_bVc 	r2   r  z{
    The Parakeet Encoder model, based on the [Fast Conformer architecture](https://huggingface.co/papers/2305.05084).
    )custom_introc                        e Zd ZU eed<   dZdef fdZeee		 d	de
j                  dee
j                     dee   defd                     Z xZS )
ParakeetEncoderr   encoderc           	         t         |   |       || _        d| _        |j                  | _        |j
                  | _        |j                  | _        |j                  rt        j                  |j                        nd| _        t        |      | _        t        |      | _        t!        j"                  t%        |j&                        D cg c]  }t)        ||       c}      | _        | j-                          y c c}w )NFr   )r$   r%   r   r   ro   dropout_positions	layerdropscale_inputr   sqrtr)   input_scaler   subsamplingr   encode_positionsr   r   r   num_hidden_layersr   r   	post_initr   s      r1   r%   zParakeetEncoder.__init__  s     &+#~~!'!9!9))<B<N<N499V%7%78TW;FC DV LmmFKFLdLdFefFe!&)4Fef
 	 gs   
C:r   r   r   r   c                    | j                  ||      }|| j                  z  }| j                  |      }t        j                  j                  || j
                  | j                        }t        j                  j                  || j                  | j                        }|u| j                  ||j                  d         }|j                  d      j                  d|j                  d   d      }||j                  dd      z  }|j                  d      }| j                  D ]E  }d}| j                  r&t        j                  g       }|| j                   k  rd}|r: ||f||d|}G t#        |	      S )
a  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetEncoder
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> encoder = ParakeetEncoder.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"])
        >>> encoder_outputs = encoder(**inputs)

        >>> print(encoder_outputs.last_hidden_state.shape)
        ```
        rk   r   r#  r6   r   FT)r   r   )last_hidden_state)r=  r<  r>  r   rn   ro   rm   r8  r'  r?   r   rA   rF   r   r'   randr9  r   )	r.   r   r   r   r3   r   encoder_layerto_dropdropout_probabilitys	            r1   rQ   zParakeetEncoder.forward  sp   < ((H%(8(88"33MB--mt||VZVcVc-d mm334#9#9DMM 4 
 %!<<^[h[n[nop[q<rN+55a8??MDWDWXYDZ\^_N+n.F.Fq!.LLN+55a8N![[MG}}&+jjn#&7"G -!!#1(;! 	! )  ??r2   rR   )rS   rT   rU   r   rX   r(  r%   r   r   r   r'   rW   r   r   r   r   rQ   rZ   r[   s   @r1   r5  r5    s     "!!4 &  26:@:@ !.:@ +,	:@
 
:@   :@r2   r5  c                       e Zd ZU dZej
                  ed<   dZee	ej                        ed<   dZee	e	ej                           ed<   dZee	e	ej                           ed<   y)ParakeetGenerateOutputal  
    Outputs of Parakeet models.

    Args:
        sequences (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            The generated sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter
            if all batches finished early due to the `eos_token_id`.
        logits (`tuple(torch.FloatTensor)` *optional*, returned when `output_logits=True`):
            Unprocessed prediction scores of the language modeling head (scores for each vocabulary token before SoftMax)
            at each generation step. Tuple of `torch.FloatTensor` with up to `max_new_tokens` elements (one element for
            each generated token), with each tensor of shape `(batch_size, config.vocab_size)`.
        attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, num_heads, generated_length, sequence_length)`.
        hidden_states (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_hidden_states=True`):
            Tuple (one element for each generated token) of tuples (one element for each layer of the decoder) of
            `torch.FloatTensor` of shape `(batch_size, generated_length, hidden_size)`.
    	sequencesNlogitsr  r3   )rS   rT   rU   rV   r'   
LongTensorrX   rK  r   r   FloatTensorr  r3   r  r2   r1   rI  rI  =  sm    & 15FHU5,,-.5<@JuU%6%6789@?CM8E%(9(9":;<Cr2   rI  zS
    Parakeet Encoder with a Connectionist Temporal Classification (CTC) head.
    c                   P    e Zd ZU eed<   def fdZee	 	 ddej                  de
ej                     de
ej                     dee   def
d              Z ej                         	 	 ddej                  de
ej                     d	edee   deeej&                  f   f
d
       Z xZS )ParakeetForCTCr   c                     t         |   |       t        |j                        | _        t        j                  |j                  j                  |j                  d      | _	        | j                          y )Nr   r   )r$   r%   r5  r  r6  r   r|   r)   
vocab_sizectc_headr@  ri   s     r1   r%   zParakeetForCTC.__init__`  sS     &v'<'<=		&"7"7"C"CVEVEVdefr2   r   r   labelsr   r   c           
          | j                   d||d|}|j                  }| j                  |j                  dd            j                  dd      }d}|Y||n$t	        j
                  |t        j                        }| j                  |j                  d            }	|| j                  j                  k7  }
|
j                  d      }|j                  |
      }t        j                  j                  |dt        j                        j                  dd      }t        j                   j"                  j%                  d	
      5  t        j                  j'                  |||	|| j                  j                  | j                  j(                  | j                  j*                        }ddd       t-        |||j.                  |j0                        S # 1 sw Y   ,xY w)a  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> outputs = model(**inputs)

        >>> print(outputs.loss)
        ```r   r   r   r   Nr    r6   r   r   F)r;   )blank	reductionzero_infinity)lossrK  r3   r  r  )r6  rC  rR  rF   r'   	ones_likelongr"  r   r   pad_token_idmasked_selectr   rn   log_softmaxr   backendscudnnflagsctc_lossctc_loss_reductionctc_zero_infinityr   r3   r  )r.   r   r   rS  r   encoder_outputsr3   rK  rY  r   labels_masktarget_lengthsflattened_targets	log_probss                 r1   rQ   zParakeetForCTC.forwardh  s   : '$,, 
))
 
 (99}66q!<=GG1M #1"<%//R`hmhrhrBs  !??@R@RSU@VWM !DKK$<$<<K(__R0N & 4 4[ A 11&b1V``abdefI%%++E+:}}--%!"++22"kk<<"&++"?"? .  ; )77&11	
 	
 ;:s   A#GGreturn_dict_in_generatec                 H   d|d<    | j                   d	||d|}|j                  j                  d      }|:| j                  ||j                  d         }| j
                  j                  || <   |r-t        ||j                  |j                  |j                        S |S )
a3  
        Example:

        ```python
        >>> from transformers import AutoProcessor, ParakeetForCTC
        >>> from datasets import load_dataset, Audio

        >>> model_id = "nvidia/parakeet-ctc-1.1b"
        >>> processor = AutoProcessor.from_pretrained(model_id)
        >>> model = ParakeetForCTC.from_pretrained(model_id)

        >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
        >>> ds = ds.cast_column("audio", Audio(sampling_rate=processor.feature_extractor.sampling_rate))

        >>> inputs = processor(ds[0]["audio"]["array"], text=ds[0]["text"])
        >>> predicted_ids = model.generate(**inputs)
        >>> transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)

        >>> print(transcription)
        ```
        Treturn_dictrU  r6   r<   r   rB  )rJ  rK  r  r3   r  )
rQ   rK  argmaxr'  r?   r   r\  rI  r  r3   )r.   r   r   rj  r   outputsrJ  s          r1   generatezParakeetForCTC.generate  s    : !%}".$,, #
))#
 #
 NN))b)1	 %!<<^[d[j[jkl[m<nN)-)A)AI~o&")#~~"--%33	  r2   r  r   )rS   rT   rU   r   rX   r%   r   r   r'   rW   r   r   r   r   rQ   rY   boolr   rI  rL  ro  rZ   r[   s   @r1   rO  rO  X  s    0   26)-	E
E
 !.E
 &	E

 +,E
 
E
  E
N U]]_ 26(-	33 !.3 "&	3
 +,3 
%u'7'77	83 3r2   rO  )rO  r5  r  )r   )2r   dataclassesr   typingr   r   r   r'   r   activationsr	   modeling_layersr
   modeling_outputsr   r   modeling_utilsr   r   processing_utilsr   utilsr   r   r   r   utils.deprecationr   utils.genericr   configuration_parakeetr   r   Moduler   r]   rr   rW   r   r   r,   r   r   r   r   r  r5  rI  rO  __all__r  r2   r1   <module>r~     s  ,  ! , ,   ! 9 ? F & V V 0 / L/7299 /7d 8-ryy 8-v	UU\\ 	U# 	U%,, 	U& %II%<<% 
% <<	%
 U\\*% % % '(%4` ryy ` FBryy BJ,5 ,^ <o < <~ 
T@- T@
T@n D[ D D4 
H, H
HV Kr2   