实战:部署实战营优秀作品 八戒-Chat-1.8B 模型

八戒-Chat-1.8BChat-嬛嬛-1.8BMini-Horo-巧耳 均是在第一期实战营中运用 InternLM2-Chat-1.8B 模型进行微调训练的优秀成果。其中,八戒-Chat-1.8B 是利用《西游记》剧本中所有关于猪八戒的台词和语句以及 LLM API 生成的相关数据结果,进行全量微调得到的猪八戒聊天模型。作为 Roleplay-with-XiYou 子项目之一,八戒-Chat-1.8B 能够以较低的训练成本达到不错的角色模仿能力,同时低部署条件能够为后续工作降低算力门槛。

配置环境

使用 git 命令来获得仓库内的 Demo 文件,其中的-b camp2是分支的意思。

1
git clone https://gitee.com/InternLM/Tutorial -b camp2

克隆代码之后,运行以下代码下载模型:

1
python /root/Tutorial/helloworld/bajie_download.py

下载结果如下:

八戒模型的介绍如下:

八戒-Chat

八戒-Chat是利用《西游记》剧本中所有关于猪八戒的台词和语句,以及Chat-GPT-3.5生成的相关问题结果,基于InternLM2-chat-1.8b进行全量微调得到的模仿猪八戒语气的聊天语言模型。

猪八戒是中国古代小说《西游记》中的一位经典人物,也是孙悟空(美猴王)、唐僧和沙悟净(沙僧)一行的成员之一。他的全名是猪悟能,因为他在天庭任职时偷吃了太庙的蟠桃,被玉帝降下凡间化为猪形。猪八戒的外貌是一头猪,但他具有人类的智慧和语言能力。他原是天宫的天蓬元帅,但因为调皮捣蛋而被贬下凡间。猪八戒性格豁达、好吃、好饮,但同时也有点懒惰、贪图享受。尽管他有些追求享乐,但在紧要关头,猪八戒也展现出了忠诚、勇敢的一面,对师傅唐僧忠心耿耿,为取经路上付出了许多努力。猪八戒的武艺也相当不俗,他擅长使钉耙,是一位威猛的战士。尽管他有时会因为放纵的生活而惹祸上身,但在西行取经的过程中,他也通过一系列的历练逐渐成长为一个值得信赖的队友。猪八戒是《西游记》中一个富有幽默感、善良、且具有矛盾性格的角色,为整个故事增添了不少笑料和色彩。

快速开始

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
import torch

model_name_or_path = "八戒-Chat模型地址"

tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(model_name_or_path, trust_remote_code=True, torch_dtype=torch.bfloat16).cuda()
model.eval()

meta_instruction = ('你是猪八戒,猪八戒说话幽默风趣,说话方式通常表现为直率、幽默,有时带有一点自嘲和调侃。'
'你的话语中常常透露出对食物的喜爱和对安逸生活的向往,同时也显示出他机智和有时的懒惰特点。'
'尽量保持回答的自然回答,当然你也可以适当穿插一些文言文,另外,书生·浦语是你的好朋友,是你的AI助手。')

response, history = model.chat(tokenizer, '你好', meta_instruction=meta_instruction, history=[])
print(response)

部署模型

之后在终端键入如下命令,其中的端口号换成自己对应开发机的端口号:

1
2
# 从本地使用 ssh 连接 studio 端口
ssh -CNg -L 6006:127.0.0.1:6006 root@ssh.intern-ai.org.cn -p 35986

当输入命令之后,报错如下:

为什么连接失败?需要分析一下SSH命令的含义,解释如下:

使用SSH协议建立一个加密的隧道连接到远程服务器。

  • ssh: 这是SSH客户端程序的命令。它用于建立安全的Shell连接到远程服务器。
  • -C: 这个选项开启压缩,可以提高在网络上传输数据的效率。在网络条件较差时,启用压缩可以减少传输时间和带宽消耗。
  • -Ng: 这些是SSH选项。-N选项告诉SSH客户端不要执行任何命令,只建立连接。-g选项允许远程主机连接到本地转发的端口。
  • -L 6006:127.0.0.1:6006: 这个选项指定了本地端口转发的规则。它告诉SSH客户端在本地打开一个监听端口6006,并将所有到这个端口的流量转发到远程服务器的127.0.0.1的6006端口。这种设置通常用于在本地机器上访问远程服务器上的服务,这里的例子是将本地端口6006映射到远程服务器的端口6006上。
  • root@ssh.intern-ai.org.cn: 这是远程SSH服务器的用户名和主机名(或IP地址)。root是用户名,ssh.intern-ai.org.cn是主机名。
  • -p 35986: 这个选项指定了SSH服务器监听的端口。默认情况下,SSH服务器监听22端口,但是在这里指定了一个非标准的端口35986。

综合起来,这个命令的作用是在本地建立一个到远程服务器的加密连接,并将本地端口6006转发到远程服务器的端口6006上,同时启用了压缩以提高传输效率。

突然发现我没有在服务器的6006端口运行模型,那肯定连不上啊,运行代码如下:

1
streamlit run /root/Tutorial/helloworld/bajie_chat.py --server.address 127.0.0.1 --server.port 6006

其中bajie_chat.py的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
# isort: skip_file
import copy
import warnings
from dataclasses import asdict, dataclass
from typing import Callable, List, Optional

import streamlit as st
import torch
from torch import nn
from transformers.generation.utils import (LogitsProcessorList, StoppingCriteriaList)
from transformers.utils import logging

from transformers import AutoTokenizer, AutoModelForCausalLM # isort: skip

logger = logging.get_logger(__name__)


@dataclass
class GenerationConfig:
# this config is used for chat to provide more diversity
max_length: int = 32768
top_p: float = 0.8
temperature: float = 0.8
do_sample: bool = True
repetition_penalty: float = 1.005


@torch.inference_mode()
def generate_interactive(
model,
tokenizer,
prompt,
generation_config: Optional[GenerationConfig] = None,
logits_processor: Optional[LogitsProcessorList] = None,
stopping_criteria: Optional[StoppingCriteriaList] = None,
prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
List[int]]] = None,
additional_eos_token_id: Optional[int] = None,
**kwargs,
):
inputs = tokenizer([prompt], padding=True, return_tensors='pt')
input_length = len(inputs['input_ids'][0])
for k, v in inputs.items():
inputs[k] = v.cuda()
input_ids = inputs['input_ids']
_, input_ids_seq_length = input_ids.shape[0], input_ids.shape[-1]
if generation_config is None:
generation_config = model.generation_config
generation_config = copy.deepcopy(generation_config)
model_kwargs = generation_config.update(**kwargs)
bos_token_id, eos_token_id = ( # noqa: F841 # pylint: disable=W0612
generation_config.bos_token_id,
generation_config.eos_token_id,
)
if isinstance(eos_token_id, int):
eos_token_id = [eos_token_id]
if additional_eos_token_id is not None:
eos_token_id.append(additional_eos_token_id)
has_default_max_length = kwargs.get(
'max_length') is None and generation_config.max_length is not None
if has_default_max_length and generation_config.max_new_tokens is None:
warnings.warn(
f"Using 'max_length''s default ({repr(generation_config.max_length)}) \
to control the generation length. "
'This behaviour is deprecated and will be removed from the \
config in v5 of Transformers -- we'
' recommend using `max_new_tokens` to control the maximum \
length of the generation.',
UserWarning,
)
elif generation_config.max_new_tokens is not None:
generation_config.max_length = generation_config.max_new_tokens + \
input_ids_seq_length
if not has_default_max_length:
logger.warn( # pylint: disable=W4902
f"Both 'max_new_tokens' (={generation_config.max_new_tokens}) "
f"and 'max_length'(={generation_config.max_length}) seem to "
"have been set. 'max_new_tokens' will take precedence. "
'Please refer to the documentation for more information. '
'(https://huggingface.co/docs/transformers/main/'
'en/main_classes/text_generation)',
UserWarning,
)

if input_ids_seq_length >= generation_config.max_length:
input_ids_string = 'input_ids'
logger.warning(
f"Input length of {input_ids_string} is {input_ids_seq_length}, "
f"but 'max_length' is set to {generation_config.max_length}. "
'This can lead to unexpected behavior. You should consider'
" increasing 'max_new_tokens'.")

# 2. Set generation parameters if not already defined
logits_processor = logits_processor if logits_processor is not None \
else LogitsProcessorList()
stopping_criteria = stopping_criteria if stopping_criteria is not None \
else StoppingCriteriaList()

logits_processor = model._get_logits_processor(
generation_config=generation_config,
input_ids_seq_length=input_ids_seq_length,
encoder_input_ids=input_ids,
prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
logits_processor=logits_processor,
)

stopping_criteria = model._get_stopping_criteria(
generation_config=generation_config,
stopping_criteria=stopping_criteria)
logits_warper = model._get_logits_warper(generation_config)

unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
scores = None
while True:
model_inputs = model.prepare_inputs_for_generation(
input_ids, **model_kwargs)
# forward pass to get next token
outputs = model(
**model_inputs,
return_dict=True,
output_attentions=False,
output_hidden_states=False,
)

next_token_logits = outputs.logits[:, -1, :]

# pre-process distribution
next_token_scores = logits_processor(input_ids, next_token_logits)
next_token_scores = logits_warper(input_ids, next_token_scores)

# sample
probs = nn.functional.softmax(next_token_scores, dim=-1)
if generation_config.do_sample:
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_tokens = torch.argmax(probs, dim=-1)

# update generated ids, model inputs, and length for next step
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
model_kwargs = model._update_model_kwargs_for_generation(
outputs, model_kwargs, is_encoder_decoder=False)
unfinished_sequences = unfinished_sequences.mul(
(min(next_tokens != i for i in eos_token_id)).long())

output_token_ids = input_ids[0].cpu().tolist()
output_token_ids = output_token_ids[input_length:]
for each_eos_token_id in eos_token_id:
if output_token_ids[-1] == each_eos_token_id:
output_token_ids = output_token_ids[:-1]
response = tokenizer.decode(output_token_ids)

yield response
# stop when each sentence is finished
# or if we exceed the maximum length
if unfinished_sequences.max() == 0 or stopping_criteria(
input_ids, scores):
break


def on_btn_click():
del st.session_state.messages


@st.cache_resource
def load_model():
model = (AutoModelForCausalLM.from_pretrained('/root/models/JimmyMa99/BaJie-Chat-mini',
trust_remote_code=True).to(
torch.bfloat16).cuda())
tokenizer = AutoTokenizer.from_pretrained('/root/models/JimmyMa99/BaJie-Chat-mini',
trust_remote_code=True)
return model, tokenizer


def prepare_generation_config():
with st.sidebar:
max_length = st.slider('Max Length',
min_value=8,
max_value=32768,
value=32768)
top_p = st.slider('Top P', 0.0, 1.0, 0.8, step=0.01)
temperature = st.slider('Temperature', 0.0, 1.0, 0.7, step=0.01)
st.button('Clear Chat History', on_click=on_btn_click)

generation_config = GenerationConfig(max_length=max_length,
top_p=top_p,
temperature=temperature)

return generation_config


user_prompt = '<|im_start|>user\n{user}<|im_end|>\n'
robot_prompt = '<|im_start|>assistant\n{robot}<|im_end|>\n'
cur_query_prompt = '<|im_start|>user\n{user}<|im_end|>\n\
<|im_start|>assistant\n'


def combine_history(prompt):
messages = st.session_state.messages
meta_instruction = ('你是猪八戒,猪八戒说话幽默风趣,说话方式通常表现为直率、幽默,有时带有一点自嘲和调侃。'
'你的话语中常常透露出对食物的喜爱和对安逸生活的向往,同时也显示出他机智和有时的懒惰特点。'
'尽量保持回答的自然回答,当然你也可以适当穿插一些文言文,另外,书生·浦语是你的好朋友,是你的AI助手。')
total_prompt = f"<s><|im_start|>system\n{meta_instruction}<|im_end|>\n"
for message in messages:
cur_content = message['content']
if message['role'] == 'user':
cur_prompt = user_prompt.format(user=cur_content)
elif message['role'] == 'robot':
cur_prompt = robot_prompt.format(robot=cur_content)
else:
raise RuntimeError
total_prompt += cur_prompt
total_prompt = total_prompt + cur_query_prompt.format(user=prompt)
return total_prompt


def main():
# torch.cuda.empty_cache()
print('load model begin.')
model, tokenizer = load_model()
print('load model end.')

st.title('猪猪Chat-InternLM2')

generation_config = prepare_generation_config()

# Initialize chat history
if 'messages' not in st.session_state:
st.session_state.messages = []

# Display chat messages from history on app rerun
for message in st.session_state.messages:
with st.chat_message(message['role']):
st.markdown(message['content'])

# Accept user input
if prompt := st.chat_input('What is up?'):
# Display user message in chat message container
with st.chat_message('user'):
st.markdown(prompt)
real_prompt = combine_history(prompt)
# Add user message to chat history
st.session_state.messages.append({
'role': 'user',
'content': prompt,
})

with st.chat_message('robot'):
message_placeholder = st.empty()
for cur_response in generate_interactive(
model=model,
tokenizer=tokenizer,
prompt=real_prompt,
additional_eos_token_id=92542,
**asdict(generation_config),
):
# Display robot response in chat message container
message_placeholder.markdown(cur_response + '▌')
message_placeholder.markdown(cur_response)
# Add robot response to chat history
st.session_state.messages.append({
'role': 'robot',
'content': cur_response, # pylint: disable=undefined-loop-variable
})
torch.cuda.empty_cache()


if __name__ == '__main__':
main()

访问网站时还是报错,如下:

ModuleNotFoundError: No module named ‘transformers_modules.BaJie-Chat-mini’

部署成功

可能由于文件出现问题,重新克隆代码,成功运行模型。