File tree Expand file tree Collapse file tree 1 file changed +30
-0
lines changed
Expand file tree Collapse file tree 1 file changed +30
-0
lines changed Original file line number Diff line number Diff line change 1+ """Talk to your model."""
2+
3+ import torch
4+ from util import convert
5+
6+ if __name__ == "__main__" :
7+ model_path = "./saved/net_save.pkl"
8+ with open (model_path , "rb" ) as weight_file :
9+ model , args , data_loader = torch .load (
10+ weight_file , map_location = torch .device ("cpu" ), weights_only = False
11+ )
12+
13+ model .eval ()
14+ print ("args" , args )
15+ seq_len = 256
16+
17+ prompt = input ("Enter prompt:" )
18+ encoded_promt = [data_loader .vocab [char ] for char in prompt ]
19+ encoded_promt = torch .tensor (encoded_promt )
20+ init_chars = (
21+ torch .nn .functional .one_hot (encoded_promt , num_classes = data_loader .vocab_size )
22+ .unsqueeze (0 )
23+ .type (torch .float32 )
24+ )
25+ print (init_chars .shape )
26+ sequences = model .sample (init_chars , seq_len )
27+ sequences = torch .argmax (sequences , - 1 )
28+ seq_conv = convert (sequences , data_loader .inv_vocab )
29+
30+ print ("" .join (seq_conv [0 ]))
You can’t perform that action at this time.
0 commit comments