@@ -13,6 +13,8 @@ class VertexAICallbackHandler(BaseCallbackHandler):
     completion_tokens: int = 0
     completion_characters: int = 0
     successful_requests: int = 0
+    total_tokens: int = 0
+    cached_tokens: int = 0
 
     def __init__(self) -> None:
         super().__init__()
@@ -24,6 +26,8 @@ def __repr__(self) -> str:
             f"\tPrompt characters: {self.prompt_characters}\n"
             f"\tCompletion tokens: {self.completion_tokens}\n"
             f"\tCompletion characters: {self.completion_characters}\n"
+            f"\tCached tokens: {self.cached_tokens}\n"
+            f"\tTotal tokens: {self.total_tokens}\n"
             f"Successful requests: {self.successful_requests}\n"
         )
 
@@ -44,7 +48,7 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
 
     def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
         """Collects token usage."""
-        completion_tokens, prompt_tokens = 0, 0
+        completion_tokens, prompt_tokens, total_tokens, cached_tokens = 0, 0, 0, 0
         completion_characters, prompt_characters = 0, 0
         for generations in response.generations:
             if len(generations) > 0 and generations[0].generation_info:
@@ -53,6 +57,8 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
                 )
                 completion_tokens += usage_metadata.get("candidates_token_count", 0)
                 prompt_tokens += usage_metadata.get("prompt_token_count", 0)
+                total_tokens += usage_metadata.get("total_token_count", 0)
+                cached_tokens += usage_metadata.get("cached_content_token_count", 0)
                 completion_characters += usage_metadata.get(
                     "candidates_billable_characters", 0
                 )
@@ -64,3 +70,5 @@ def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
             self.completion_characters += completion_characters
             self.completion_tokens += completion_tokens
             self.successful_requests += 1
+            self.total_tokens += total_tokens
+            self.cached_tokens += cached_tokens
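
For reference, a minimal usage sketch of the updated handler. The import paths, model name, and constructor parameter are assumptions not shown in this diff; only the handler's attributes come from the change above.

    # Minimal usage sketch (import paths and model name are assumptions).
    from langchain_google_vertexai import ChatVertexAI  # assumed import path
    from langchain_google_vertexai.callbacks import VertexAICallbackHandler  # assumed import path

    handler = VertexAICallbackHandler()
    llm = ChatVertexAI(model_name="gemini-1.5-pro")  # hypothetical model name

    # Passing the handler as a callback; on_llm_end then aggregates the
    # usage_metadata counters, including the two fields added in this change.
    llm.invoke("Hello!", config={"callbacks": [handler]})

    print(handler.prompt_tokens, handler.completion_tokens)
    print(handler.total_tokens, handler.cached_tokens)
    print(handler)  # __repr__ now includes "Cached tokens" and "Total tokens"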