o
    
sh|                     @   s&  d dl Z d dlZd dlZd dlZd dlZd dlZd dlZd dlZd dlm	Z	m
Z
 d dlmZ d dlmZmZ d dlmZ d dlmZ d dlZd dlmZmZ d dlmZmZmZ d d	lmZ d d
lmZm Z  d dl!m"Z"m#Z# e$ dkrxd dl%Z%e" rd dl&m'Z' d dl(m)Z) d dl*m+Z+ e# rd dl,Z,d dlm-Z-mZm.Z.mZ e/ej0ej1 Z2e/ej0ej3 ej1 d Z4ddiddiddiddiddiddiddidZ5dZ6dd7e58  dZ9G dd  d Z:eG d!d" d"Z;d#e
fd$d%Z<G d&d' d'eZ=e>d(kre; Z?d)e?_@d*e?_@e=e?ZAeAB  dS dS )+    N)ArgumentParser	Namespace)AsyncIterator)	dataclassfield)Thread)Optional)AsyncInferenceClientChatCompletionStreamOutput)AutoTokenizerGenerationConfigPreTrainedTokenizer)BaseTransformersCLICommand)ServeArgumentsServeCommand)is_rich_availableis_torch_availableWindows)Console)Live)Markdown)AutoModelForCausalLMr   BitsAndBytesConfigr   z .!\"#$%&'()*+,\-/:<=>?@[]^_`{|}~textz5There is a Llama in my lawn, how can I get rid of it?zyWrite a Python function that integrates any Python function f(x) numerically over an arbitrary interval [x_start, x_end].z4How many helicopters can a human eat in one sitting?z4Count to 10 but skip every number ending with an 'e'zWhy aren't birds real?z2Why is it important to eat socks after meditating?z$Which number is larger, 9.9 or 9.11?)llamacode
helicopternumbersbirdssocksnumbers2a  

**TRANSFORMERS CHAT INTERFACE**

Chat interface to try out a model. Besides chatting with the model, here are some basic commands:
- **!help**: shows all available commands (set generation settings, save chat, etc.)
- **!status**: shows the current status of the model and generation settings
- **!clear**: clears the current conversation and starts a new one
- **!exit**: closes the interface
am  

**TRANSFORMERS CHAT INTERFACE HELP**

Full command list:
- **!help**: shows this help message
- **!clear**: clears the current conversation and starts a new one
- **!status**: shows the current status of the model and generation settings
- **!example {NAME}**: loads example named `{NAME}` from the config and uses it as the user input.
Available example names: `z`, `a%  `
- **!set {ARG_1}={VALUE_1} {ARG_2}={VALUE_2}** ...: changes the system prompt or generation settings (multiple
settings are separated by a space). Accepts the same flags and format as the `generate_flags` CLI argument.
If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options
- **!save {SAVE_NAME} (optional)**: saves the current chat and settings to file by default to
`./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml` or `{SAVE_NAME}` if provided
- **!exit**: closes the interface
c                   @   s   e Zd Zddee dee fddZdee deee	f fdd	Z
defd
dZdd ZdefddZdedefddZddefddZdededefddZdS )RichInterfaceN
model_name	user_namec                 C   s:   t  | _|d u rd| _n|| _|d u rd| _d S || _d S )N	assistantuser)r   _consoler"   r#   )selfr"   r#    r(   X/var/www/html/alpaca_bot/venv/lib/python3.10/site-packages/transformers/commands/chat.py__init__s   s   

zRichInterface.__init__streamreturnc           	         s   | j d| j d t| j dd\}d}|I d H 2 zK3 d H W }|jd jj}|s+qtdd|}||7 }g }|	 D ]}|
| |d	rN|
d
 q<|
d q<td| dd}|j|dd q6 W d    n1 srw   Y  | j   |S )Nz[bold blue]<z>:   )consolerefresh_per_second r   z<(/*)(\w*)>z\<\1\2\>z```
z  
zgithub-dark)
code_themeT)refresh)r&   printr"   r   choicesdeltacontentresub
splitlinesappend
startswithr   joinstripupdate)	r'   r+   liver   tokenoutputslineslinemarkdownr(   r(   r)   stream_output~   s,   


(zRichInterface.stream_outputc                 C   s$   | j d| j d}| j   |S )z!Gets user input from the console.[bold red]<z>:
)r&   inputr#   r4   )r'   rH   r(   r(   r)   rH      s   
zRichInterface.inputc                 C   s   | j   dS )zClears the console.N)r&   clearr'   r(   r(   r)   rI      s   zRichInterface.clearr   c                 C   s(   | j d| j d|  | j   dS )z%Prints a user message to the console.rG   z>:[/ bold red]
N)r&   r4   r#   )r'   r   r(   r(   r)   print_user_message   s   z RichInterface.print_user_messagecolorc                 C   s&   | j d| d|  | j   dS )z,Prints text in a given color to the console.z[bold ]Nr&   r4   )r'   r   rL   r(   r(   r)   print_color      zRichInterface.print_colorFminimalc                 C   s&   | j t|rtnt | j   dS )z'Prints the help message to the console.N)r&   r4   r   HELP_STRING_MINIMALHELP_STRING)r'   rQ   r(   r(   r)   
print_help   rP   zRichInterface.print_helpgeneration_configmodel_kwargsc                 C   sJ   | j d| d |r| j d|  | j d|  | j   dS )zFPrints the status of the model and generation settings to the console.z[bold blue]Model: r1   z[bold blue]Model kwargs: z[bold blue]NrN   )r'   r"   rU   rV   r(   r(   r)   print_status   s
   zRichInterface.print_status)NN)F)__name__
__module____qualname__r   strr*   r   r
   tupleintrF   rH   rI   rK   rO   boolrT   r   dictrW   r(   r(   r(   r)   r!   r   s    .r!   c                   @   s  e Zd ZU dZedddidZee ed< edddidZ	ee ed< eddd	idZ
ee ed
< edddidZeed< edddidZee ed< edddidZeed< edddidZee ed< edddidZeed< edddidZeed< eddg dddZee ed< eddg dddZee ed < eddd!idZeed"< eddd#idZee ed$< eddd%idZeed&< eddd'idZeed(< ed)d*d+d)gddZeed,< eddd-idZeed.< ed/dd0idZeed1< ed2dd3idZeed4< d5d6 ZdS )7ChatArgumentsz
    Arguments for the chat CLI.

    See the metadata arg for each argument's description -- the medatata will be printed with
    `transformers chat --help`
    Nhelpz_Name of the pre-trained model. The positional argument will take precedence if both are passed.)defaultmetadatamodel_name_or_pathzKUsername to display in chat interface. Defaults to the current user's name.r%   zSystem prompt.system_promptz./chat_history/zFolder to save chat history.save_folderz"Path to a yaml file with examples.examples_pathFz7Whether to show runtime warnings in the chat interface.verbosezPath to a local generation config file or to a HuggingFace repo containing a `generation_config.json` file. Other generation settings passed as CLI arguments will be applied on top of this generation config.rU   mainzLSpecific model version to use (can be a branch name, tag name or commit id).model_revisionautozDevice to use for inference.devicezA`torch_dtype` is deprecated! Please use `dtype` argument instead.)rk   bfloat16float16float32)ra   r5   torch_dtypezOverride the default `torch.dtype` and load the model under this dtype. If `'auto'` is passed, the dtype will be automatically derived from the model's weights.dtypez2Whether to trust remote code when loading a model.trust_remote_codezWhich attention implementation to use; you can run --attn_implementation=flash_attention_2, in which case you must install this manually by running `pip install flash-attn --no-build-isolation`.attn_implementationzIWhether to use 8 bit precision for the base model - works only with LoRA.load_in_8bitzIWhether to use 4 bit precision for the base model - works only with LoRA.load_in_4bitnf4zQuantization type.fp4bnb_4bit_quant_typez#Whether to use nested quantization.use_bnb_nested_quant	localhostz%Interface the server will listen to..host@  zPort the server will listen to.portc                 C   s(   | j dur| jdkr| j | _dS dS dS )z(Only used for BC `torch_dtype` argument.Nrk   )rp   rq   rJ   r(   r(   r)   __post_init__!  s   zChatArguments.__post_init__)rX   rY   rZ   __doc__r   rd   r   r[   __annotations__r%   re   rf   rg   rh   r^   rU   rj   rl   rp   rq   rr   rs   rt   ru   rx   ry   r{   r}   r]   r~   r(   r(   r(   r)   r`      sv   
 r`   argsc                 C   s   t | S )z;
    Factory function used to chat with a local model.
    )ChatCommand)r   r(   r(   r)   chat_command_factory(  s   r   c                   @   sr  e Zd ZedefddZdd ZedefddZed/d
e	de
e defddZed/de
e dee fddZdee defddZd
e	dedeeef fddZededede
e de
e deeee f f
ddZede	de
d fdd Zd
e	ded!ef fd"d#Zd$ed
e	d%ed&eeeeef f ded'ed(ee deee eef fd)d*Zd+d, Zd-d. Zd	S )0r   parserc                 C   sT   t f}| jd|d}|d}|jdtddd |jdtdd	d
d |jtd dS )z
        Register this command to argparse so it's available for the transformer-cli

        Args:
            parser: Root parser to register command-specific arguments
        chat)dataclass_typeszPositional argumentsmodel_name_or_path_or_addressNz7Name of the pre-trained model or address to connect to.)typerb   ra   generate_flagsa  Flags to pass to `generate`, using a space as a separator between flags. Accepts booleans, numbers, and lists of integers, more advanced parameterization should be set through --generation-config. Example: `transformers chat <model_repo> max_new_tokens=100 do_sample=False eos_token_id=[1,2]`. If you're a new user, check this basic flag guide: https://huggingface.co/docs/transformers/llm_tutorial#common-options*)r   rb   ra   nargs)func)r`   
add_parseradd_argument_groupadd_argumentr[   set_defaultsr   )r   r   chat_parsergroupr(   r(   r)   register_subcommand0  s"   
zChatCommand.register_subcommandc                 C   s   |j d urD|j }|ds|ds|dr=d| _|jdks$|jdkr(td|j dd\|_|_|jd u r<td	nd
| _|j |_t sQt	 sQ| jrQt
dt sXt
dt	 sb| jrbt
d|| _d S )Nhttphttpsrz   Fr|   uu   Looks like you’ve set both a server address and a custom host/port. Please pick just one way to specify the server.:   z\When connecting to a server, please specify a model name with the --model_name_or_path flag.TzYou need to install rich to use the chat interface. Additionally, you have not specified a remote endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`)zHYou need to install rich to use the chat interface. (`pip install rich`)zYou have not specified a remote endpoint and are therefore spawning a backend. Torch is required for this: (`pip install rich torch`))r   r<   spawn_backendr{   r}   
ValueErrorrsplitrd   r   r   ImportErrorr   )r'   r   namer(   r(   r)   r*   Q  s6   


zChatCommand.__init__r,   c                   C   s$   t  dkr
t S tt jS )z)Returns the username of the current user.r   )platformsystemosgetloginpwdgetpwuidgetuidpw_namer(   r(   r(   r)   get_usernamex  s   zChatCommand.get_usernameNr   filenamec                 C   s   i }t ||d< | |d< |j}|du r(td}|j d| d}tj||}tjtj	|dd t
|d	}tj||d
d W d   n1 sKw   Y  tj|S )z!Saves the chat history to a file.settingschat_historyNz%Y-%m-%d_%H-%M-%Sz/chat_.jsonT)exist_okwr-   )indent)varsrf   timestrftimer   r   pathr=   makedirsdirnameopenjsondumpabspath)r   r   r   output_dictfoldertime_strfr(   r(   r)   	save_chat  s   
zChatCommand.save_chatre   c                 C   s    | du rg }|S d| dg}|S )zClears the chat history.Nr   roler7   r(   )re   r   r(   r(   r)   clear_chat_history  s
   zChatCommand.clear_chat_historyr   c                    s   t |dkri S dd |D }dd | D }dd | D }dtdtfdd	  fd
d| D }ddd | D }d| d }|dd}|dd}|dd}|dd}|dd}|dd}zt|}W |S  tjy   t	dw )zUParses the generate flags from the user input into a dictionary of `generate` kwargs.r   c                 S   s.   i | ]}d | dd  d  | dd qS )"=r   r   )split).0flagr(   r(   r)   
<dictcomp>  s   . z4ChatCommand.parse_generate_flags.<locals>.<dictcomp>c                 S   s*   i | ]\}}||  d v r|  n|qS ))truefalse)lowerr   kvr(   r(   r)   r     s    c                 S   s"   i | ]\}}||d krdn|qS )Nonenullr(   r   r(   r(   r)   r     s   " sr,   c                 S   s(   |  dr| dd  } | ddd S )N-r   .r0   )r<   replaceisdigit)r   r(   r(   r)   	is_number  s   
z3ChatCommand.parse_generate_flags.<locals>.is_numberc                    s*   i | ]\}}| |sd | d n|qS )r   r(   r   r   r(   r)   r     s   * z, c                 S   s   g | ]\}}| d | qS )z: r(   r   r(   r(   r)   
<listcomp>  s    z4ChatCommand.parse_generate_flags.<locals>.<listcomp>{}z"null"r   z"true"r   z"false"r   z"[[z]"rM   r   r   zFailed to convert `generate_flags` into a valid JSON object.
`generate_flags` = {generate_flags}
Converted JSON string = {generate_flags_string})
lenitemsr[   r^   r=   r   r   loadsJSONDecodeErrorr   )r'   r   generate_flags_as_dictgenerate_flags_stringprocessed_generate_flagsr(   r   r)   parse_generate_flags  s2   z ChatCommand.parse_generate_flagsmodel_generation_configc                 C   s   |j dur&d|j v rtj|j }tj|j }t||}nt|j }nt|}|j	di ddd | 
|j}|j	di |}||fS )zj
        Returns a GenerationConfig object holding the generation parameters for the CLI command.
        Nr   T   )	do_samplemax_new_tokensr(   )rU   r   r   r   basenamer   from_pretrainedcopydeepcopyr?   r   r   )r'   r   r   r   r   rU   parsed_generate_flagsrV   r(   r(   r)   get_generation_parameterization  s   


z+ChatCommand.get_generation_parameterization	tokenizerrU   
eos_tokenseos_token_idsc                 C   s|   |j du r	|j}n|j }g }|dur|| |d |dur.|dd |dD  t|dkr:||j ||fS )z:Retrieves the pad token ID and all possible EOS token IDs.N,c                 S   s   g | ]}t |qS r(   )r]   )r   token_idr(   r(   r)   r     s    z0ChatCommand.parse_eos_tokens.<locals>.<listcomp>r   )pad_token_ideos_token_idextendconvert_tokens_to_idsr   r   r;   )r   rU   r   r   r   all_eos_token_idsr(   r(   r)   parse_eos_tokens  s   
zChatCommand.parse_eos_tokens
model_argsr   c                 C   s@   | j rtd| j| j| j| jd}|S | jrtdd}|S d }|S )NT)ru   bnb_4bit_compute_dtyperx   bnb_4bit_use_double_quantbnb_4bit_quant_storage)rt   )ru   r   rq   rx   ry   rt   )r   quantization_configr(   r(   r)   get_quantization_config  s    z#ChatCommand.get_quantization_configr   c                 C   s   t j|j|j|jd}|jdv r|jntt|j}| |}|j|j	|d|d}t
j|jfd|ji|}t|dd d u rC||j}||fS )N)revisionrr   )rk   Nrk   )r   rs   rq   
device_mapr   rr   hf_device_map)r   r   model_name_or_path_positionalrj   rr   rq   getattrtorchr   rs   r   torl   )r'   r   r   rq   r   rV   modelr(   r(   r)   load_model_and_tokenizer   s.   
z$ChatCommand.load_model_and_tokenizer
user_input	interfaceexamplesrV   r   c                 C   s  d}|dkr|  |j}|  n|dkr|  n|drKt| dk rK| }	t|	dkr6|	d }
nd}
| |||
}
|jd|
 d	d
d n|dr|dd 	 }| }|D ]}d|vrq|jd| ddd  nq^| 
|}|jdi |}|jdi | ng|drt| dkr| d }||v r|  g }||| d  |d|| d d n4d| dt|  d}|j|dd n|dkr|j|j||d nd}|jd| ddd |  ||||fS )z
        Handles all user commands except for `!exit`. May update the chat history (e.g. reset it) or the
        generation config (e.g. set a new flag).
        Tz!clearz!helpz!save   r   NzChat saved in !green)r   rL   z!setr-   r   z(Invalid flag format, missing `=` after `z;`. Please use the format `arg_1=value_1 arg_2=value_2 ...`.red!exampler   r%   r   zExample z* not found in list of available examples: r   z!status)r"   rU   rV   F'z/' is not a valid command. Showing help message.r(   )r   re   rI   rT   r<   r   r   r   rO   r>   r   r?   rK   r;   listkeysrW   rd   )r'   r  r   r  r  rU   rV   r   valid_commandsplit_inputr   new_generate_flagsr   parsed_new_generate_flagsnew_model_kwargsexample_nameexample_errorr(   r(   r)   handle_non_exit_user_commands;  s`   





z)ChatCommand.handle_non_exit_user_commandsc                 C   s   t |   d S N)asynciorun
_inner_runrJ   r(   r(   r)   r    s   zChatCommand.runc                    sr  | j r8t| jj| jj| jj| jj| jj| jj| jj	| jj
| jj| jjdd}t|}t|jd}d|_|  | jjd | jj }| jjdkrJdn| jj}t| d| jj }| j}|jd u rdt}nt|j}	t|	}W d    n1 syw   Y  |jd u r|  }
n|j}
t|j}| ||\}}t|j|
d	}|   | !|j"}|j#dd
 	 z~za|$ }|%dr|dkrW W |& I d H  d S | j'|||||||d\}}}}|r|%dsW W |& I d H  qn|(d|d |j)|d|* |dd}|+|I d H }|(d|d W n t,y&   Y W |& I d H  d S w W |& I d H  n|& I d H  w q)Nerror)rl   rq   rr   rs   rt   ru   rx   ry   r{   r}   	log_level)targetT@rz   zhttp://localhostr   )r"   r#   )rQ   r  z!exit)r  r   r  r  rU   rV   r   r  r%   r   )rU   r  )r+   
extra_bodyr$   )-r   r   r   rl   rq   rr   rs   rt   ru   rx   ry   r{   r}   r   r   r  daemonstartrd   rj   r	   rg   DEFAULT_EXAMPLESr   yaml	safe_loadr%   r   r   r   r   r!   rI   r   re   rT   rH   r<   closer  r;   chat_completionto_json_stringrF   KeyboardInterrupt)r'   
serve_argsserve_commandthreadr  r{   clientr   r  r   r%   r   rU   rV   r  r   r  r  r+   model_outputr(   r(   r)   r    s   



!
	 zChatCommand._inner_runr  )rX   rY   rZ   staticmethodr   r   r*   r[   r   r`   r   r   r  r_   r   r   r   r\   r   r   r]   r   r   r   r  r!   r  r  r  r(   r(   r(   r)   r   /  sh     '
6

	
Tr   __main__z meta-llama/Llama-3.2-3b-Instructzhttp://localhost:8000)Cr  r   r   r   r   r8   stringr   argparser   r   collections.abcr   dataclassesr   r   	threadingr   typingr   r#  huggingface_hubr	   r
   transformersr   r   r   transformers.commandsr   transformers.commands.servingr   r   transformers.utilsr   r   r   r   rich.consoler   	rich.liver   rich.markdownr   r   r   r   setascii_letters
whitespaceALLOWED_KEY_CHARSdigitsALLOWED_VALUE_CHARSr"  rR   r=   r  rS   r!   r`   r   r   rX   r   r   r   r  r(   r(   r(   r)   <module>   sz   	\Y   
@