Skip to content

Commit 152b632

Browse files
authored
Create model_chat.py
1 parent 9b4a4dd commit 152b632

File tree

1 file changed

+30
-0
lines changed

1 file changed

+30
-0
lines changed

src/model_chat.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
"""Talk to your model."""
2+
3+
import torch
4+
from util import convert
5+
6+
if __name__ == "__main__":
7+
model_path = "./saved/net_save.pkl"
8+
with open(model_path, "rb") as weight_file:
9+
model, args, data_loader = torch.load(
10+
weight_file, map_location=torch.device("cpu"), weights_only=False
11+
)
12+
13+
model.eval()
14+
print("args", args)
15+
seq_len = 256
16+
17+
prompt = input("Enter prompt:")
18+
encoded_promt = [data_loader.vocab[char] for char in prompt]
19+
encoded_promt = torch.tensor(encoded_promt)
20+
init_chars = (
21+
torch.nn.functional.one_hot(encoded_promt, num_classes=data_loader.vocab_size)
22+
.unsqueeze(0)
23+
.type(torch.float32)
24+
)
25+
print(init_chars.shape)
26+
sequences = model.sample(init_chars, seq_len)
27+
sequences = torch.argmax(sequences, -1)
28+
seq_conv = convert(sequences, data_loader.inv_vocab)
29+
30+
print("".join(seq_conv[0]))

0 commit comments

Comments
 (0)