Skip to content

Commit bbac056

Browse files
committed
Fix a few bugs
1 parent 2ef65a8 commit bbac056

File tree

6 files changed

+93
-25
lines changed

6 files changed

+93
-25
lines changed

docs/README.md

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The table below includes the information about all SQL functions exposed by Infe
88
| 2 | `infera_unload_model(name VARCHAR)` | `BOOLEAN` | Unloads a model, freeing its associated resources. Returns `true` on success. |
99
| 3 | `infera_set_autoload_dir(path VARCHAR)` | `VARCHAR (JSON)` | Scans a directory for `.onnx` files, loads them automatically, and returns a JSON report of loaded models and any errors. |
1010
| 4 | `infera_get_loaded_models()` | `VARCHAR (JSON)` | Returns a JSON array containing the names of all currently loaded models. |
11-
| 5 | `infera_get_model_info(name VARCHAR)` | `VARCHAR (JSON)` | Returns a JSON object with metadata about a specific loaded model, including its name, input/output shapes, and status. |
11+
| 5 | `infera_get_model_info(name VARCHAR)` | `VARCHAR (JSON)` | Returns a JSON object with metadata about a specific loaded model (name, input/output shapes). If the model is not loaded, this function raises an error. |
1212
| 6 | `infera_predict(name VARCHAR, features... FLOAT)` | `FLOAT` | Performs inference on a batch of data, returning a single float value for each input row. |
1313
| 7 | `infera_predict_multi(name VARCHAR, features... FLOAT)` | `VARCHAR (JSON)` | Performs inference and returns all outputs as a JSON-encoded array. This is useful for models that produce multiple predictions per sample. |
1414
| 8 | `infera_predict_multi_list(name VARCHAR, features... FLOAT)` | `LIST[FLOAT]` | Performs inference and returns all outputs as a typed list of floats. Useful for multi-output models without JSON parsing. |
@@ -45,7 +45,7 @@ select infera_is_model_loaded('local_model');
4545
select infera_get_loaded_models();
4646
-- Output: ["local_model", "remote_model"]
4747

48-
-- Get information about a specific model
48+
-- Get information about a specific model (throws an error if the model is not loaded)
4949
select infera_get_model_info('local_model');
5050
-- Output: {"name":"local_model","input_shape":[-1,3],"output_shape":[-1,1],"loaded":true}
5151

@@ -97,7 +97,7 @@ select infera_get_loaded_models();
9797
select infera_is_model_loaded('squeezenet');
9898
-- Output: true or false
9999

100-
-- Get detailed metadata for a specific model
100+
-- Get detailed metadata for a specific model (raises an error if the model is not loaded)
101101
select infera_get_model_info('squeezenet');
102102
/* Output:
103103
{
@@ -121,14 +121,14 @@ select infera_set_autoload_dir('path/to/your/models');
121121
select infera_clear_cache();
122122
-- Output: true
123123

124-
-- Get cache statistics
124+
-- Get cache statistics (field names as returned by the function)
125125
select infera_get_cache_info();
126126
/* Output:
127127
{
128-
"path": "/path/to/cache",
129-
"size_bytes": 204800,
128+
"cache_dir": "/path/to/cache",
129+
"total_size_bytes": 204800,
130130
"file_count": 10,
131-
"size_limit": 10485760
131+
"size_limit_bytes": 10485760
132132
}
133133
*/
134134
```

infera/bindings/infera_extension.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,7 @@ static void Predict(DataChunk &args, ExpressionState &state, Vector &result) {
227227

228228
infera::InferaInferenceResult res = infera::infera_predict(model_name_str.c_str(), features.data(), batch_size, feature_count);
229229
if (res.status != 0) {
230+
infera::infera_free_result(res);
230231
throw InvalidInputException("Inference failed for model '" + model_name_str + "': " + GetInferaError());
231232
}
232233
if (res.rows != batch_size || res.cols != 1) {
@@ -434,16 +435,22 @@ static void GetModelInfo(DataChunk &args, ExpressionState &state, Vector &result
434435
throw InvalidInputException("Model name cannot be NULL");
435436
}
436437
std::string model_name_str = model_name.ToString();
437-
char *json_meta = infera::infera_get_model_info(model_name_str.c_str());
438+
char *json_meta_c = infera::infera_get_model_info(model_name_str.c_str());
439+
440+
// Convert to std::string and free the C string immediately to avoid leaks
441+
std::string json_meta = json_meta_c ? std::string(json_meta_c) : std::string();
442+
if (json_meta_c) {
443+
infera::infera_free(json_meta_c);
444+
}
438445

439-
if (json_meta == nullptr) {
440-
throw InvalidInputException("Failed to get info for model '" + model_name_str + "': " + GetInferaError());
446+
// If Rust returned an error JSON, surface it as a DuckDB error per contract/tests
447+
if (json_meta.empty() || json_meta.find("\"error\"") != std::string::npos) {
448+
throw InvalidInputException("Failed to get info for model '" + model_name_str + "'");
441449
}
442450

443451
result.SetVectorType(VectorType::CONSTANT_VECTOR);
444452
ConstantVector::GetData<string_t>(result)[0] = StringVector::AddString(result, json_meta);
445453
ConstantVector::SetNull(result, false);
446-
infera::infera_free(json_meta);
447454
}
448455

449456
/**

infera/src/ffi_utils.rs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,45 @@ pub unsafe extern "C" fn infera_free(ptr: *mut c_char) {
6969
/// on the same result will lead to undefined behavior.
7070
#[no_mangle]
7171
pub unsafe extern "C" fn infera_free_result(res: InferaInferenceResult) {
72-
if !res.data.is_null() && res.len > 0 {
73-
let _ = Vec::from_raw_parts(res.data, res.len, res.len);
72+
if !res.data.is_null() {
73+
// SAFETY: `res.data` was allocated from a Box<[f32]> via `into_raw` with length `res.len`.
74+
// Reconstruct the slice pointer and drop it to free the allocation correctly.
75+
let slice_ptr: *mut [f32] = std::ptr::slice_from_raw_parts_mut(res.data, res.len);
76+
let _ = Box::from_raw(slice_ptr);
77+
}
78+
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use super::*;
83+
84+
#[test]
85+
fn test_infera_free_result_zero_len_non_null() {
86+
// Build an empty boxed slice (zero-length, so nothing is heap-allocated; the pointer is dangling but non-null) and convert it to a raw pointer
87+
let empty: Box<[f32]> = Vec::<f32>::new().into_boxed_slice();
88+
let ptr = Box::into_raw(empty) as *mut f32;
89+
let res = InferaInferenceResult {
90+
data: ptr,
91+
len: 0,
92+
rows: 0,
93+
cols: 0,
94+
status: 0,
95+
};
96+
unsafe { infera_free_result(res) }; // should not panic or leak
97+
}
98+
99+
#[test]
100+
fn test_infera_free_result_non_empty() {
101+
let data: Vec<f32> = vec![1.0, 2.0, 3.0];
102+
let len = data.len();
103+
let ptr = Box::into_raw(data.into_boxed_slice()) as *mut f32;
104+
let res = InferaInferenceResult {
105+
data: ptr,
106+
len,
107+
rows: 1,
108+
cols: len,
109+
status: 0,
110+
};
111+
unsafe { infera_free_result(res) }; // should free without UB
74112
}
75113
}

infera/src/lib.rs

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -341,10 +341,7 @@ pub extern "C" fn infera_get_cache_info() -> *mut c_char {
341341
}
342342
}
343343

344-
let size_limit = env::var("INFERA_CACHE_SIZE_LIMIT")
345-
.ok()
346-
.and_then(|s| s.parse::<u64>().ok())
347-
.unwrap_or(1024 * 1024 * 1024);
344+
let size_limit = config::CONFIG.cache_size_limit;
348345

349346
Ok(json!({
350347
"cache_dir": cache_dir_str,
@@ -625,4 +622,30 @@ mod tests {
625622
infera_unload_model(model_name.as_ptr());
626623
}
627624
}
625+
626+
#[test]
627+
fn test_infera_get_model_info_nonexistent_returns_error_json() {
628+
let name = CString::new("__missing_model__").unwrap();
629+
let info_ptr = unsafe { infera_get_model_info(name.as_ptr()) };
630+
let info_json = unsafe { CStr::from_ptr(info_ptr).to_str().unwrap() };
631+
let value: serde_json::Value = serde_json::from_str(info_json).unwrap();
632+
assert!(
633+
value.get("error").is_some(),
634+
"expected error field in JSON: {}",
635+
info_json
636+
);
637+
unsafe { infera_free(info_ptr) };
638+
}
639+
640+
#[test]
641+
fn test_infera_get_cache_info_includes_configured_limit() {
642+
let cache_info_ptr = infera_get_cache_info();
643+
let cache_info_json = unsafe { CStr::from_ptr(cache_info_ptr).to_str().unwrap() };
644+
let value: serde_json::Value = serde_json::from_str(cache_info_json).unwrap();
645+
let size_limit = value["size_limit_bytes"]
646+
.as_u64()
647+
.expect("size_limit_bytes should be u64");
648+
assert_eq!(size_limit, crate::config::CONFIG.cache_size_limit);
649+
unsafe { infera_free(cache_info_ptr) };
650+
}
628651
}

test/sql/test_edge_cases_more.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ select infera_unload_model('linear')
3131
----
3232
true
3333

34-
# model info after unload should contain model not found message
35-
query I
36-
select position('Model not found' in infera_get_model_info('linear')) > 0
34+
# model info after unload should error
35+
statement error
36+
select infera_get_model_info('linear')
3737
----
38-
true
38+
Failed to get info for model 'linear'

test/sql/test_integration_and_errors.test

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,11 @@ load 'build/release/extension/infera/infera.duckdb_extension'
1212

1313
# test 1: error handling for non-existent models
1414

15-
# infera_get_model_info for a missing model returns a non-null json containing an error field
16-
query I
17-
select infera_get_model_info('nonexistent_model') is null
15+
# infera_get_model_info for a missing model should error
16+
statement error
17+
select infera_get_model_info('nonexistent_model')
1818
----
19-
0
19+
Failed to get info for model 'nonexistent_model'
2020

2121
# unloading a non-existent model is benign (idempotent) and returns false
2222
statement ok

0 commit comments

Comments
 (0)