Skip to content

Commit 7c6e8a1

Browse files
authored
Renew cloud token for parallel upload (#667)
1 parent bc9616a commit 7c6e8a1

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

src/snowflake/connector/file_transfer_agent.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -442,6 +442,7 @@ def _upload_files_in_parallel(self, file_metas: List["SnowflakeFileMeta"]) -> No
442442
qtask.put(meta)
443443
retry_round = 0
444444
presigned_url_handled = False
445+
445446
while not qtask.empty():
446447
len_file_metas = qtask.qsize()
447448
thread_number = min(len_file_metas, self._parallel)
@@ -460,6 +461,8 @@ def _upload_files_in_parallel(self, file_metas: List["SnowflakeFileMeta"]) -> No
460461
pool.shutdown()
461462

462463
# update presigned url before starting next round
464+
# every time we see renew_token triggers, do one renew only
465+
renew_token_handled = False
463466
while not triggers.empty():
464467
result_meta = triggers.get()
465468
if (
@@ -468,7 +471,19 @@ def _upload_files_in_parallel(self, file_metas: List["SnowflakeFileMeta"]) -> No
468471
):
469472
self._update_file_metas_with_presigned_url()
470473
presigned_url_handled = True
471-
474+
if (
475+
not renew_token_handled
476+
and result_meta.result_status == ResultStatus.RENEW_TOKEN
477+
):
478+
logger.debug("renewing expired token")
479+
ret = self._cursor._execute_helper(
480+
self._command
481+
) # rerun the command to get the credential
482+
self._stage_info = ret["data"]["stageInfo"]
483+
484+
for meta in qtask:
485+
meta.client_meta.stage_info = self._stage_info
486+
renew_token_handled = True
472487
retry_round += 1
473488

474489
@staticmethod
@@ -505,10 +520,7 @@ def _upload_files_in_queue_thread(
505520
result_meta = SnowflakeFileTransferAgent.upload_one_file(meta)
506521
if result_meta.result_status == ResultStatus.RENEW_TOKEN:
507522
# need to retry this upload. the meta will be added back once renew is done
508-
thread_client = cln_meta.storage_client.create_client(
509-
cln_meta.stage_info,
510-
use_accelerate_endpoint=cln_meta.use_accelerate_endpoint,
511-
)
523+
triggers.put(result_meta)
512524
qtask.put(meta)
513525
elif result_meta.result_status == ResultStatus.RENEW_PRESIGNED_URL:
514526
# now stop this round - by adding one item to triggers

0 commit comments

Comments
 (0)