Skip to content

Commit fea1c61

Browse files
committed
refactor code and support cuda.
1 parent 5f59b9b commit fea1c61

File tree

6 files changed

+253
-20257
lines changed

6 files changed

+253
-20257
lines changed

CMakeLists.txt

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,10 @@ link_directories(${CMAKE_CURRENT_LIST_DIR}/libs)
1212
# source files
1313
FILE(GLOB SRCS ${CMAKE_CURRENT_LIST_DIR}/src/*.cpp)
1414

15-
# target
16-
add_executable(chat ${SRCS})
15+
# compile dynamic lib
16+
add_library(chat SHARED ${SRCS})
17+
target_link_libraries(chat MNN MNN_Express)
1718

18-
if (MSVC)
19-
target_link_libraries(chat MNN)
20-
else()
21-
target_link_libraries(chat MNN)
22-
endif()
19+
# demo target
20+
add_executable(demo ${CMAKE_CURRENT_LIST_DIR}/demo/main.cpp)
21+
target_link_libraries(demo chat)

README.md

Lines changed: 30 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,11 @@
11
# ChatGLM-MNN
2+
## Describe
3+
该模型使用ChatGLM-6B, 将其转换到MNN模型并使用C++进行推理。在实现上做了如下优化:
4+
5+
1. 对其中的词表做了部分删减,删除了模型中未使用的前20000个词;将vocab, embedding, lm_head的大小从130528缩小到130528;
6+
2. `Embedding`操作调用次数较少,使用`fseek`, `fread`加载的方式降低内存;
7+
3. `lm_head`操作为`[num, 4096] @ [4096, 130528]`,将其优化为分段实现的矩阵乘`[130528, 4096] @ [4096, 1]`;
8+
2. 原模型对显存要求较高;将模型按层拆分成28个模型,可以根据用户显存大小动态将计算任务分配给GPU和CPU,充分利用显卡算力;
29

310
## Usage
411
### 1. Compile MNN library
@@ -20,7 +27,7 @@ mkdir build
2027
cd build
2128
cmake ..
2229
make -j8
23-
./chat
30+
./demo
2431
```
2532

2633
## Example
@@ -75,4 +82,25 @@ A: 内存泄漏(Memory Leak)是指程序在运行过程中,占用的内存空间
7582

7683
5. 定期清理无用的数据:如果程序中存在大量的无用数据,也会导致内存泄漏。可以通过定期清理这些数据,释放内存空间。
7784

78-
内存泄漏问题的解决需要从多个方面入手,通过不断地调试和优化程序,来找到内存泄漏的根本原因,并有效地解决问题。
85+
内存泄漏问题的解决需要从多个方面入手,通过不断地调试和优化程序,来找到内存泄漏的根本原因,并有效地解决问题。
86+
87+
---
88+
Q: 北京有什么特色小吃?
89+
90+
A: 北京是中国历史文化名城,也是中国美食之都之一,有许多特色小吃。以下是一些著名的北京特色小吃:
91+
92+
1. 炸酱面:炸酱面是中国传统面食之一,以黄酱和肉末为主要材料,配以豆瓣酱、黄瓜丝和豆芽等配料,味道鲜美。
93+
94+
2. 烤鸭:烤鸭是北京最著名的美食之一,以薄饼和鸭肉为主要材料,烤制过程中还会加入葱、姜等调料,口感鲜美。
95+
96+
3. 豆汁:豆汁是一种传统的北京小吃,以黄豆为主要原料,配以辣椒油、醋、蒜泥等调料,味道酸甜可口。
97+
98+
4. 羊蝎子:羊蝎子是一道以羊肉和羊肝为主要材料的炖菜,口感鲜美,营养丰富。
99+
100+
5. 糖葫芦:糖葫芦是一种传统的北京小吃,以草莓、山楂等水果为主料,沾上糖浆,口感酸甜可口。
101+
102+
6. 煎饼果子:煎饼果子是一种流行的中式早餐,以薄饼和蛋、肉松、油条等为主要材料,口感酥脆。
103+
104+
7. 驴打滚:驴打滚是一种传统的北京小吃,以糯米粉和豆沙为主要材料,通过卷起来和炸的方式制作,口感香甜。
105+
106+
这只是北京众多特色小吃中的一小部分,北京还有很多其他美食,如北京火锅、北京炸酱面、北京小吃街等等,值得一试。

demo/main.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
//
2+
// chat.cpp
3+
//
4+
// Created by MNN on 2023/03/24.
5+
// ZhaodeWang
6+
//
7+
8+
#include "chat.hpp"
9+
#include <iostream>
10+
11+
int main(int argc, const char* argv[]) {
12+
ChatGLM chatglm;
13+
chatglm.chat();
14+
return 0;
15+
}

include/chat.hpp

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//
2+
// chat.hpp
3+
//
4+
// Created by MNN on 2023/03/24.
5+
// ZhaodeWang
6+
//
7+
8+
#ifndef CHAT_hpp
9+
#define CHAT_hpp
10+
11+
#include <vector>
12+
#include <memory>
13+
#include <string>
14+
#include <unordered_map>
15+
16+
#include <MNN/AutoTime.hpp>
17+
#include <MNN/expr/Expr.hpp>
18+
#include <MNN/expr/Module.hpp>
19+
#include <MNN/expr/MathOp.hpp>
20+
#include <MNN/expr/NeuralNetWorkOp.hpp>
21+
22+
using namespace MNN;
23+
using namespace Express;
24+
25+
static constexpr int MASK = 130000;
26+
static constexpr int gMASK = 130001;
27+
static constexpr int BOS = 130004;
28+
static constexpr int EOS = 130005;
29+
30+
static constexpr int VOCAB_SIZE = 130528;
31+
static constexpr int HIDDEN_SIZE = 4096;
32+
static constexpr int LAYER_SIZE = 28;
33+
34+
35+
class ChatGLM {
36+
public:
37+
// your cuda memory size (G)
38+
ChatGLM(float cuda_memory = 0) {
39+
init(cuda_memory);
40+
}
41+
void chat();
42+
std::string response(const std::string& input_str, bool debuginfo = false);
43+
private:
44+
void init(float cuda_memory);
45+
void loadModel(const char* fileName, bool cuda = false);
46+
std::vector<int> tokenizer_encode(std::string input_str);
47+
VARP gen_embedding(const std::vector<int>& input_ids);
48+
VARP gen_attention_mask(const std::vector<int>& input_ids);
49+
VARP gen_position_ids(const std::vector<int>& input_ids);
50+
int var_to_token(VARP var);
51+
int forward(const std::vector<int>& input_ids);
52+
private:
53+
std::vector<std::string> mWordDecode;
54+
std::unordered_map<std::string, int> mWordEncode;
55+
// MNN Modules
56+
std::shared_ptr<Executor::RuntimeManager> mCPURtmgr;
57+
std::shared_ptr<Executor::RuntimeManager> mCUDARtmgr;
58+
std::vector<std::shared_ptr<Module>> mModules;
59+
std::vector<VARP> mHistoryVars;
60+
// mask info
61+
int mSeqLen = 0, mContextLen = -1, mMaskIdx = -1;
62+
};
63+
64+
#endif // CHAT_hpp

0 commit comments

Comments
 (0)