
    hXC                        d dl Z d dlmZmZ d dlZd dlmZ ddlmZ ddlm	Z	 ddl
mZ ddlmZ dd	lmZ dd
lmZmZmZ ddlmZmZ ddlmZ ddlmZmZ ddlmZmZ ddlm Z m!Z! ddl"m#Z#m$Z$ ddl%m&Z&m'Z' dejP                  de)fdZ* G d de!      Z+ G d de       Z, G d de      Z- G d de      Z. G d d e$      Z/ G d! d"e      Z0 G d# d$e      Z1 G d% d&e#      Z2 G d' d(ejf                        Z4 G d) d*ejj                        Z6 G d+ d,e      Z7 G d- d.e7      Z8 G d/ d0e      Z9e G d1 d2ee	             Z:g d3Z;y)4    N)OptionalUnion)nn   )Cache)GenerationMixin)BaseModelOutput)PreTrainedModel)Unpack)TransformersKwargsauto_docstringcan_return_tuple   )Aimv2AttentionAimv2EncoderLayer)	AutoModel)LlamaMLPLlamaRMSNorm)LlavaForConditionalGeneration
LlavaModel)LlavaNextCausalLMOutputWithPastLlavaNextModelOutputWithPast)SiglipEncoderSiglipVisionEmbeddings   )Ovis2ConfigOvis2VisionConfiglogitsdimc                     | j                  |      }|j                  |d      d   }t        j                  | t        j                        j                  ||d      }||j                         z
  |z   }|S )NT)keepdimr   )memory_formatg      ?)softmaxmaxtorch
zeros_likelegacy_contiguous_formatscatter_detach)r   r   y_softindexy_hardrets         f/var/www/html/aiagenthome/venv/lib/python3.12/site-packages/transformers/models/ovis2/modular_ovis2.pyhard_softmaxr/   %   sk    ^^C FJJsDJ)!,EfE4R4RS\\]`bgilmF
6==?
"V
+CJ    c                       e Zd Zy)Ovis2ModelOutputWithPastN__name__
__module____qualname__ r0   r.   r2   r2   /       r0   r2   c                       e Zd Zy)Ovis2CausalLMOutputWithPastNr3   r7   r0   r.   r:   r:   3   r8   r0   r:   c                       e Zd Zy)Ovis2RMSNormNr3   r7   r0   r.   r<   r<   7   r8   r0   r<   c                       e Zd Zy)Ovis2VisionMLPNr3   r7   r0   r.   r>   r>   ;   r8   r0   r>   c                   b     e Zd Zdef fdZd Zdej                  dej                  fdZ	 xZ
S )Ovis2VisionEmbeddingsconfigc                 n    t         |   |       t        |j                  |j                        | _        y N)super__init__r<   hidden_sizerms_norm_epsrms_normselfrA   	__class__s     r.   rE   zOvis2VisionEmbeddings.__init__@   s*     $V%7%79L9LMr0   c                     t        d      NzNot needed for Ovis2)NotImplementedErrorrJ   s    r.   interpolate_pos_encodingz.Ovis2VisionEmbeddings.interpolate_pos_encodingD   s    !"899r0   pixel_valuesreturnc                 (   | j                   j                  j                  }| j                  |j                  |            }|j	                  d      j                  dd      }| j                  |      }|| j                  | j                        z   }|S )Ndtyper   r   )	patch_embeddingweightrU   toflatten	transposerH   position_embeddingposition_ids)rJ   rQ   target_dtypepatch_embeds
embeddingss        r.   forwardzOvis2VisionEmbeddings.forwardG   s    ++2288++LOO,O,OP!))!,66q!<
]]:.
$"9"9$:K:K"LL
r0   )r4   r5   r6   r   rE   rP   r%   FloatTensorTensorr`   __classcell__rK   s   @r.   r@   r@   ?   s4    N0 N:E$5$5 %,, r0   r@   c                       e Zd Zy)Ovis2VisionAttentionNr3   r7   r0   r.   rf   rf   R   r8   r0   rf   c                       e Zd Zy)Ovis2VisionEncoderLayerNr3   r7   r0   r.   rh   rh   V   r8   r0   rh   c            	       p     e Zd Zdef fdZee	 ddeej                     de
e   defd              Z xZS )Ovis2VisionEncoderrA   c                     t         |   |       t        j                  t	        |j
                        D cg c]  }t        |       c}      | _        y c c}w rC   )rD   rE   r   
ModuleListrangenum_hidden_layersrh   layers)rJ   rA   _rK   s      r.   rE   zOvis2VisionEncoder.__init__[   sF     mmeTZTlTlNm$nNm%<V%DNm$no$ns   Aattention_maskkwargsrR   c                 T    |}| j                   D ]  } |||fi |} t        |      S )Nlast_hidden_state)ro   r	   )rJ   inputs_embedsrq   rr   hidden_statesencoder_layers         r.   r`   zOvis2VisionEncoder.forward_   s5     &![[M)-R6RM ) ??r0   rC   )r4   r5   r6   r   rE   r   r   r   r%   rb   r   r   r	   r`   rc   rd   s   @r.   rj   rj   Z   se    p0 p  26
@ !.
@ +,	
@
 

@  
@r0   rj   c                   X     e Zd Zdef fdZe	 ddeej                     fd       Z	 xZ
S )Ovis2VisionTransformerrA   c                     t         |           || _        t        |      | _        t        |      | _        t        |j                  |j                        | _
        d| _        y )NF)rD   rE   rA   r@   r_   rj   encoderr<   rF   rG   rH   gradient_checkpointingrI   s     r.   rE   zOvis2VisionTransformer.__init__o   sO    /7)&1$V%7%79L9LM&+#r0   rq   c                     | j                  |      } | j                  d||d|}|j                  }| j                  |      }t	        |      S )N)rv   rq   rt   r7   )r_   r|   ru   rH   r	   )rJ   rQ   rq   rr   rw   encoder_outputsru   s          r.   r`   zOvis2VisionTransformer.forwardw   sa     5+74<< ,
'),
 ,
 ,== MM*;<1BCCr0   rC   )r4   r5   r6   r   rE   r   r   r%   rb   r`   rc   rd   s   @r.   rz   rz   n   s?    ,0 ,  26D !.D Dr0   rz   c                   P     e Zd Zdej                  dej                  f fdZ xZS )Ovis2VisualEmbeddingTablevisual_tokensrR   c                    |j                   t        j                  t        j                  t        j                  t        j
                  t        j                  fv rt        | !  |      S t        j                  || j                        S rC   )rU   r%   int8int16int32int64longrD   r`   matmulrW   )rJ   r   rK   s     r.   r`   z!Ovis2VisualEmbeddingTable.forward   sW    5::u{{EKKV[V`V`"aa7?=11||M4;;77r0   )r4   r5   r6   r%   rb   r`   rc   rd   s   @r.   r   r      s#    8U\\ 8ell 8 8r0   r   c                   B    e Zd ZU eed<   dZdZdgZdZdZ	dZ
dZdZdZdZy)Ovis2PreTrainedModelrA   modelTrf   past_key_valuesN)r4   r5   r6   r   __annotations__base_model_prefixsupports_gradient_checkpointing_no_split_modules_skip_keys_device_placement_supports_cache_class_supports_flash_attn_supports_flex_attn_supports_sdpa_can_compile_fullgraph_supports_attention_backendr7   r0   r.   r   r      sF    &*#/0"3 N!"&r0   r   c                        e Zd ZU eed<   def fdZdej                  deej                  ej                  f   fdZ
 xZS )Ovis2VisionModelrA   c                    t         |   |       || _        t        |      | _        |j
                  | _        |j                  | _        t        j                  |j                  |j                  z  |j                  z  | j                  | j
                  z
  d      | _        t        j                  | j                  | j
                  z
        | _        y NF)bias)rD   rE   rA   rz   transformernum_visual_indicator_tokens
vocab_sizer   LinearrF   hidden_stridehead_linear	LayerNorm	head_normrI   s     r.   rE   zOvis2VisionModel.__init__   s     1&9+1+M+M( ++99!5!558L8LLOOd>>>

 doo8X8X&XYr0   rQ   rR   c           	          | j                   |fi |}|d   }| j                  j                  dkD  r|j                  \  }}}| j                  j                  }t	        t        j                  |            }	|	|	z  |k7  rt        d      ||	|z  z
  |z  }
t        j                  j                  |ddd|
d|
fdd      }|	|
z  }	|j                  ||	|z  ||	|z  ||      }|j                  dddddd      }|j                  |d	||z  |z        }| j                  |      }| j                  |      }| j                  j                  d
k(  r$t        j                  j!                  |d	d      }|S | j                  j                  dk(  rt#        |d	      }|S | j                  j                  dk(  r!t        j                  j%                  |d	      }S )Nr   r   z.Token sequence length must be a perfect squareconstantr   r         gumbel_argmaxT)r   hard	st_argmaxr   r#   )r   rA   r   shapeintmathsqrt
ValueErrorr   
functionalpadreshapepermuter   r   tokenize_functiongumbel_softmaxr/   r#   )rJ   rQ   rr   outputsru   
num_imagesseq_len
hidden_dimr   sqrt_lpad_sizer   
prob_tokens                r.   r`   zOvis2VisionModel.forward   s   "$""<:6:#AJ;;$$q(.?.E.E+J KK55M7+,F') !QRR%-)?@MQH " 1 12CaAxYZ\dEegqst uhF 1 9 9Fm3]FmD[]jlv! !2 9 9!Q1a K 1 9 9B =
 J! !!"34';;((O;55f"45PJ  [[**k9%f"5J  [[**i7..v2.>Jr0   )r4   r5   r6   r   r   rE   r%   ra   tuplerb   r`   rc   rd   s   @r.   r   r      sF    Z0 Z!E$5$5 !E%,,X]XdXdJdDe !r0   r   c            !           e Zd Zi Zdef fdZdej                  dej                  fdZe	e
	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     d	ee   d
eej                     deej                     dee   dee   dee   dee   deej                     deeej                  f   deeef   fd              Z xZS )
Ovis2ModelrA   c                 |   t         |   |       t        |j                        | _        t        |j                  j                  |j                        | _        |j                  j                  | _	        |j                  | _        |j                  | _
        t        j                  |j                        | _        | `y rC   )rD   rE   r   vision_configvision_towerr   r   rF   visual_embeddings_tablevisual_vocab_sizevisual_indicator_token_idsr   from_configtext_configlanguage_modelmulti_modal_projectorrI   s     r.   rE   zOvis2Model.__init__   s     ,V-A-AB'@AUAUA`A`bhbtbt'u$!'!5!5!@!@ ++*0*K*K''33F4F4FG&r0   rQ   rR   c                 4   | j                  |      }|j                  \  }}}t        j                  ||| j                   j                  f|j
                  |j                  d|j                        }t        j                  ||gd      }| j                  |      }t        j                  | j                  | j                   j                  z
  | j                  t        j                        j                  |j                        }| j                  |      }||fS )NF)rU   devicerequires_gradlayoutr   r   rT   )r   r   r%   zerosr   rU   r   r   catr   aranger   r   rX   )	rJ   rQ   image_features
batch_sizeimg_seq_lenrp   padding_tensorvisual_indicatorvisual_indicator_featuress	            r.   get_image_featureszOvis2Model.get_image_features   s     **<8%3%9%9"
Kd&7&7&S&ST &&!((!((
 NN#CK55nE <<""T%6%6%R%RR""**
 "^""
#	 	
 %)$@$@AQ$R!888r0   	input_idsrq   r\   r   rv   labels	use_cacheoutput_attentionsoutput_hidden_statesreturn_dictcache_positionlogits_to_keepc                    |	|	n| j                   j                  }	|
|
n| j                   j                  }
|d u |d uz  rt        d      | | j	                         |      }| | j                  |      \  }}| j                  |||      }|j                  ||      }t        | j                        D ]  \  }}|Y| | j	                         t        j                  |t        j                  |j                              k(  }|j                  d      }n||k(  j                  |j                        }|j!                         s||   j#                  ||         j                  |j                  |j$                        ||<     | j&                  d	||||||	|
d||d
|}t)        |j*                  |j,                  |j.                  |j0                  |      S d       S )
Nz:You must specify exactly one of input_ids or inputs_embedsrQ   )rv   r   )rU   r   r   T)
rq   r\   r   rv   r   r   r   r   r   r   )ru   r   rw   
attentionsimage_hidden_statesr7   )rA   r   r   r   get_input_embeddingsr   get_placeholder_maskmasked_scatter	enumerater   r%   tensorr   r   allrX   any	expand_asrU   r   r2   ru   r   rw   r   )rJ   r   rQ   rq   r\   r   rv   r   r   r   r   r   r   r   rr   r   r   special_image_maskivisual_indicator_idmaskr   s                         r.   r`   zOvis2Model.forward   s   & 2C1N-TXT_T_TqTq$8$D $++JjJj 	 -t";<YZZ 7D557	BM#8<8O8O]i8O8j5N5!%!:!:+- "; "
 *889K^\M*3D4S4S*T&&$(,GD,E,E,G%8

S`SgSgh- D  88B<D%)<<@@AUAUVD88:1!4"=#67M00-2E2EF "$' +U  &$%% 
)%+'/!5))
 
 (%77#33!//))2>2J
 	

 QU
 	
r0   NNNNNNNNNNNNr   )r4   r5   r6   _checkpoint_conversion_mappingr   rE   r%   ra   r   r   r   r   
LongTensorrb   r   boolr   r   r   r2   r`   rc   rd   s   @r.   r   r      s   %'"	'{ 	'9''9 
		92  15481537+/59-1$(,0/3&*5934J
E,,-J
 u001J
 !.	J

 u//0J
 "%J
   1 12J
 ))*J
 D>J
 $D>J
 'tnJ
 d^J
 !!1!12J
 c5<</0J
  
u..	/!J
  J
r0   r   c            !           e Zd Zi Zdef fdZed        Zdej                  fdZ
ee	 	 	 	 	 	 	 	 	 	 	 	 	 ddeej                     deej                     deej                     deej                     d	ee   d
eej                     deej                     dee   dee   dee   dee   deej                     deeej                  f   deeef   fd              Z xZS )Ovis2ForConditionalGenerationrA   c                     t         |   |       t        j                  |j                  |j
                  d      | _        y r   )rD   rE   r   r   rF   r   lm_headrI   s     r.   rE   z&Ovis2ForConditionalGeneration.__init__P  s0     yy!3!3V5F5FUSr0   c                     t        d      rM   )AttributeErrorrO   s    r.   r   z3Ovis2ForConditionalGeneration.multi_modal_projectorT  s    344r0   rQ   c                 :    | j                   j                  |      S )Nr   )r   r   )rJ   rQ   s     r.   r   z0Ovis2ForConditionalGeneration.get_image_featuresX  s    zz,,,,GGr0   r   rq   r\   r   rv   r   r   r   r   r   r   r   rR   c                    |	|	n| j                   j                  }	|
|
n| j                   j                  }
 | j                  d||||||||	|
d|d|}|d   }t	        |t
              rt        | d      n|}| j                  |dd|ddf         }d}|4 | j                  d||| j                   j                  j                  d|}t        |||j                  |j                  |j                  |j                        S )a  
        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Example:

        ```python
        >>> from PIL import Image
        >>> import requests
        >>> from transformers import AutoProcessor, Ovis2ForConditionalGeneration

        >>> model = Ovis2ForConditionalGeneration.from_pretrained("thisisiron/Ovis2-2B-hf")
        >>> processor = AutoProcessor.from_pretrained("thisisiron/Ovis2-2B-hf")

        >>> prompt = "<|im_start|>user\n<image>\nDescribe the image.<|im_end|>\n<|im_start|>assistant\n"
        >>> url = "http://images.cocodataset.org/val2014/COCO_val2014_000000537955.jpg"
        >>> image = Image.open(requests.get(url, stream=True).raw)

        >>> inputs = processor(images=image, text=prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(**inputs, max_new_tokens=15)
        >>> processor.batch_decode(generate_ids, skip_special_tokens=True)[0]
        "user\n\nDescribe the image.\nassistant\nThe image features a brown dog standing on a wooden floor, looking up with"
        ```NT)r   rQ   rq   r\   r   rv   r   r   r   r   r   r   )r   r   r   )lossr   r   rw   r   r   r7   )rA   r   r   r   
isinstancer   slicer   loss_functionr   r   r:   r   rw   r   r   )rJ   r   rQ   rq   r\   r   rv   r   r   r   r   r   r   r   rr   r   rw   slice_indicesr   r   s                       r.   r`   z%Ovis2ForConditionalGeneration.forward[  s7   \ 2C1N-TXT_T_TqTq$8$D $++JjJj 	 $** 
%)%+'/!5)
 
  
8B>SV8W~ot4]kmA}a,?@A%4%% f9P9P9[9[_eD +#33!//)) ' ; ;
 	
r0   r   )r4   r5   r6   r   r   rE   propertyr   r%   ra   r   r   r   r   r   rb   r   r   r   r   r   r:   r`   rc   rd   s   @r.   r   r   L  s   %'"T{ T 5 5Hu/@/@ H  15481537+/59-1$(,0/3&*5934R
E,,-R
 u001R
 !.	R

 u//0R
 "%R
   1 12R
 ))*R
 D>R
 $D>R
 'tnR
 d^R
 !!1!12R
 c5<</0R
  
u11	2!R
  R
r0   r   )r   r   r   )<r   typingr   r   r%   r   cache_utilsr   
generationr   modeling_outputsr	   modeling_utilsr
   processing_utilsr   utilsr   r   r   aimv2.modeling_aimv2r   r   autor   llama.modeling_llamar   r   llava.modeling_llavar   r   llava_next.modeling_llava_nextr   r   siglip.modeling_siglipr   r   configuration_ovis2r   r   rb   r   r/   r2   r:   r<   r>   r@   rf   rh   rj   Modulerz   	Embeddingr   r   r   r   r   __all__r7   r0   r.   <module>r     s/     "     ) / - & I I D  9 L j J ? C 	; 		"A 		< 		X 	2 &	> 		/ 	@ @(DRYY D<8 8'? '1+ 1hs
 s
l b
$A? b
 b
J Rr0   