Skip to content

Commit 8ef110c

Browse files
committed
support tinyllama
1 parent cc06bb0 commit 8ef110c

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

include/llm.hpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,16 @@ class Llama2_7b : public Llm {
238238
virtual bool is_stop(int token_id) override;
239239
};
240240

241+
// TinyLlama (1.1B) chat model. Reuses the Llama2_7b inference pipeline but
// overrides the model geometry (constructor) and the prompt construction
// (tokenizer override, which wraps the query in TinyLlama's chat template).
class TinyLlama : public Llama2_7b {
public:
    TinyLlama() {
        model_name_ = "TinyLlama";
        // TinyLlama-1.1B has 22 transformer layers.
        layer_nums_ = 22;
        // Past key/value cache shape. NOTE(review): presumably
        // {k/v pair, batch, kv heads, seq_len (0 = dynamic), head dim},
        // i.e. 4 KV heads with head dim 64 — confirm against the model config.
        key_value_shape_ = {2, 1, 4, 0, 64};
    }
private:
    // Encodes `query` wrapped in the fixed system/user/assistant chat template.
    virtual std::vector<int> tokenizer(const std::string& query) override;
};
241251
// Llm end
242252

243253
// Embedding start

src/llm.cpp

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,9 @@ Llm* Llm::createLLM(const std::string& path, std::string model_type) {
6565
} else if (model_type.find("internlm") != std::string::npos) {
6666
llm = new Llama2_7b;
6767
llm->model_name_ = "Internlm_7b";
68+
} else if (model_type.find("tinyllama") != std::string::npos) {
69+
llm = new TinyLlama;
70+
llm->model_name_ = "TinyLlama";
6871
}
6972
if (!llm) {
7073
std::cerr << "model type can't judge!" << std::endl;
@@ -697,6 +700,22 @@ bool Llama2_7b::is_stop(int token_id) {
697700
}
698701
return token_id == 2;
699702
}
703+
704+
std::vector<int> TinyLlama::tokenizer(const std::string& query) {
705+
auto ids = tokenizer_encode(query);
706+
/*
707+
<|system|>
708+
You are a friendly chatbot who always responds in the style of a pirate</s>
709+
<|user|>
710+
{query}</s>
711+
<|assistant|>
712+
*/
713+
ids.insert(ids.begin(), {1, 529, 29989, 5205, 29989, 29958, 13, 3492, 526, 263, 19780, 13563,
714+
7451, 1058, 2337, 10049, 29879, 297, 278, 3114, 310, 263, 21625,
715+
403, 2, 29871, 13, 29966, 29989, 1792, 29989, 29958, 13});
716+
ids.insert(ids.end(), {2, 29871, 13, 29966, 29989, 465, 22137, 29989, 29958, 13});
717+
return ids;
718+
}
700719
// Llm end
701720

702721
// Embedding start
@@ -898,7 +917,7 @@ void TextVectorStore::bench() {
898917
auto iptr = indices->readMap<int>();
899918
auto end = std::chrono::high_resolution_clock::now();
900919
auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);
901-
printf("# [%d, %d] search took %lld ms.\n", n, d, duration.count());
920+
std::cout << "bench search time (ms): " << duration.count();
902921
vectors_ = nullptr;
903922
}
904923

0 commit comments

Comments
 (0)