Skip to content
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions crates/lib/src/admin/token_util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{
config::Config,
error::KoraError,
state::{get_request_signer_with_signer_key, get_signer_pool},
token::token::TokenType,
Expand Down Expand Up @@ -94,7 +95,8 @@ pub async fn initialize_atas_with_chunk_size(
for address in addresses_to_initialize_atas {
println!("Initializing ATAs for address: {address}");

let atas_to_create = find_missing_atas(rpc_client, address).await?;
let config = get_config()?;
let atas_to_create = find_missing_atas(&config, rpc_client, address).await?;

if atas_to_create.is_empty() {
println!("✓ All required ATAs already exist for address: {address}");
Expand Down Expand Up @@ -245,11 +247,10 @@ async fn create_atas_for_signer(
}

pub async fn find_missing_atas(
config: &Config,
rpc_client: &RpcClient,
payment_address: &Pubkey,
) -> Result<Vec<ATAToCreate>, KoraError> {
let config = get_config()?;

// Parse all allowed SPL paid token mints
let mut token_mints = Vec::new();
for token_str in &config.validation.allowed_spl_paid_tokens {
Expand All @@ -273,16 +274,17 @@ pub async fn find_missing_atas(
for mint in &token_mints {
let ata = get_associated_token_address(payment_address, mint);

match CacheUtil::get_account(rpc_client, &ata, false).await {
match CacheUtil::get_account(&config, rpc_client, &ata, false).await {
Ok(_) => {
println!("✓ ATA already exists for token {mint}: {ata}");
}
Err(_) => {
// Fetch mint account to determine if it's SPL or Token2022
let mint_account =
CacheUtil::get_account(rpc_client, mint, false).await.map_err(|e| {
KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
})?;
let mint_account = CacheUtil::get_account(&config, rpc_client, mint, false)
.await
.map_err(|e| {
KoraError::RpcError(format!("Failed to fetch mint account for {mint}: {e}"))
})?;

let token_program = TokenType::get_token_program_from_owner(&mint_account.owner)?;

Expand Down Expand Up @@ -328,7 +330,8 @@ mod tests {
let rpc_client = create_mock_rpc_client_account_not_found();
let payment_address = Pubkey::new_unique();

let result = find_missing_atas(&rpc_client, &payment_address).await.unwrap();
let config = get_config().unwrap();
let result = find_missing_atas(&config, &rpc_client, &payment_address).await.unwrap();

assert!(result.is_empty(), "Should return empty vec when no SPL tokens configured");
}
Expand Down Expand Up @@ -366,9 +369,10 @@ mod tests {
cache_ctx
.expect()
.times(3)
.returning(move |_, _, _| responses_clone.lock().unwrap().pop_front().unwrap());
.returning(move |_, _, _, _| responses_clone.lock().unwrap().pop_front().unwrap());

let result = find_missing_atas(&rpc_client, &payment_address).await;
let config = get_config().unwrap();
let result = find_missing_atas(&config, &rpc_client, &payment_address).await;

assert!(result.is_ok(), "Should handle SPL tokens with proper mocking");
let atas = result.unwrap();
Expand Down
29 changes: 15 additions & 14 deletions crates/lib/src/cache.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl CacheUtil {
pub async fn init() -> Result<(), KoraError> {
let config = get_config()?;

let pool = if CacheUtil::is_cache_enabled() {
let pool = if CacheUtil::is_cache_enabled(&config) {
let redis_url = config.kora.cache.url.as_ref().ok_or(KoraError::ConfigError)?;

let cfg = deadpool_redis::Config::from_url(redis_url);
Expand Down Expand Up @@ -179,23 +179,19 @@ impl CacheUtil {
}

/// Check if cache is enabled and available
fn is_cache_enabled() -> bool {
match get_config() {
Ok(config) => config.kora.cache.enabled && config.kora.cache.url.is_some(),
Err(_) => false,
}
fn is_cache_enabled(config: &crate::config::Config) -> bool {
config.kora.cache.enabled && config.kora.cache.url.is_some()
}

/// Get account from cache with optional force refresh
pub async fn get_account(
config: &crate::config::Config,
rpc_client: &RpcClient,
pubkey: &Pubkey,
force_refresh: bool,
) -> Result<Account, KoraError> {
let config = get_config()?;

// If cache is disabled or force refresh is requested, go directly to RPC
if !CacheUtil::is_cache_enabled() {
if !CacheUtil::is_cache_enabled(config) {
return Self::get_account_from_rpc(rpc_client, pubkey).await;
}

Expand Down Expand Up @@ -264,7 +260,8 @@ mod tests {
async fn test_is_cache_enabled_disabled() {
let _m = ConfigMockBuilder::new().with_cache_enabled(false).build_and_setup();

assert!(!CacheUtil::is_cache_enabled());
let config = get_config().unwrap();
assert!(!CacheUtil::is_cache_enabled(&config));
}

#[tokio::test]
Expand All @@ -275,7 +272,8 @@ mod tests {
.build_and_setup();

// Without URL, cache should be disabled
assert!(!CacheUtil::is_cache_enabled());
let config = get_config().unwrap();
assert!(!CacheUtil::is_cache_enabled(&config));
}

#[tokio::test]
Expand All @@ -286,7 +284,8 @@ mod tests {
.build_and_setup();

// Give time for config to be set up
assert!(CacheUtil::is_cache_enabled());
let config = get_config().unwrap();
assert!(CacheUtil::is_cache_enabled(&config));
}

#[tokio::test]
Expand Down Expand Up @@ -336,7 +335,8 @@ mod tests {

let rpc_client = RpcMockBuilder::new().with_account_info(&expected_account).build();

let result = CacheUtil::get_account(&rpc_client, &pubkey, false).await;
let config = get_config().unwrap();
let result = CacheUtil::get_account(&config, &rpc_client, &pubkey, false).await;

assert!(result.is_ok());
let account = result.unwrap();
Expand All @@ -355,7 +355,8 @@ mod tests {
let rpc_client = RpcMockBuilder::new().with_account_info(&expected_account).build();

// force_refresh = true should always go to RPC
let result = CacheUtil::get_account(&rpc_client, &pubkey, true).await;
let config = get_config().unwrap();
let result = CacheUtil::get_account(&config, &rpc_client, &pubkey, true).await;

assert!(result.is_ok());
let account = result.unwrap();
Expand Down
Loading
Loading