Skip to content

Commit 87ef261

Browse files
authored
feat: implement regexp_* text functions (#247)
# Add Regular Expression Functions for Text Processing

- Add regexp_count: count regex pattern matches
- Add regexp_extract: extract specific capture group from match
- Add regexp_extract_all: extract all matches of a capture group (uses extract_groups for PySpark compatibility)
- Add regexp_instr: find 1-based position of regex match
- Add regexp_substr: extract first substring matching pattern
- Validate regex patterns using py_validate_regex

## Implementation Details

### PySpark Regex Compatibility

All regex functions use Polars/Rust regex syntax and validate patterns using `py_validate_regex` to catch errors early.

### regexp_extract_all with Capture Groups

Required a pure Rust implementation for the case where pattern or idx are column expressions, which is not supported natively by Polars.

### regexp_instr with dynamic group index

Required a pure Rust implementation for the case in which the group index is a column expression, which is not supported natively by Polars.

### Other Functions

- **regexp_count**: Uses `str.count_matches(pattern)` directly
- **regexp_extract**: Uses `str.extract(pattern, group_idx)` directly
- **regexp_substr**: Uses `str.extract(pattern, 0)` for entire match
1 parent 727074b commit 87ef261

File tree

16 files changed

+1887
-287
lines changed

16 files changed

+1887
-287
lines changed

protos/logical_plan/v1/expressions.proto

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,11 @@ message LogicalExpr {
116116
FuzzyRatioExpr fuzzy_ratio = 72;
117117
FuzzyTokenSortRatioExpr fuzzy_token_sort_ratio = 73;
118118
FuzzyTokenSetRatioExpr fuzzy_token_set_ratio = 74;
119+
RegexpCountExpr regexp_count = 75;
120+
RegexpExtractExpr regexp_extract = 76;
121+
RegexpExtractAllExpr regexp_extract_all = 77;
122+
RegexpInstrExpr regexp_instr = 78;
123+
RegexpSubstrExpr regexp_substr = 79;
119124

120125
// JSON expressions
121126
JqExpr jq = 80;
@@ -520,6 +525,34 @@ message RegexpSplitExpr {
520525
int32 limit = 3;
521526
}
522527

528+
// regexp_count: counts the matches of a regex pattern in a text expression.
message RegexpCountExpr {
529+
// Text expression to search.
LogicalExpr expr = 1;
530+
// Regular expression pattern, itself given as an expression.
LogicalExpr pattern = 2;
531+
}
532+
533+
// regexp_extract: extracts a specific capture group from a regex match.
message RegexpExtractExpr {
534+
// Text expression to search.
LogicalExpr expr = 1;
535+
// Regular expression pattern, itself given as an expression.
LogicalExpr pattern = 2;
536+
// Capture group index to extract.
LogicalExpr idx = 3;
537+
}
538+
539+
// regexp_extract_all: extracts all matches of a capture group.
message RegexpExtractAllExpr {
540+
// Text expression to search.
LogicalExpr expr = 1;
541+
// Regular expression pattern, itself given as an expression.
LogicalExpr pattern = 2;
542+
// Capture group index (0 for the whole match, 1+ for capture groups).
LogicalExpr idx = 3;
543+
}
544+
545+
// regexp_instr: finds the 1-based position of a regex match.
message RegexpInstrExpr {
546+
// Text expression to search.
LogicalExpr expr = 1;
547+
// Regular expression pattern, itself given as an expression.
LogicalExpr pattern = 2;
548+
// Capture group index (0 for the whole match, 1+ for capture groups).
LogicalExpr idx = 3;
549+
}
550+
551+
// regexp_substr: extracts the first substring matching a regex pattern.
message RegexpSubstrExpr {
552+
// Text expression to search.
LogicalExpr expr = 1;
553+
// Regular expression pattern, itself given as an expression.
LogicalExpr pattern = 2;
554+
}
555+
523556
message SplitPartExpr {
524557
LogicalExpr expr = 1;
525558
LogicalExpr delimiter = 2;

rust/src/regex/mod.rs

Lines changed: 221 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
1+
use polars::datatypes::DataType;
2+
use polars::prelude::*;
13
use pyo3::exceptions::PyValueError;
24
use pyo3::prelude::*;
5+
use pyo3_polars::derive::polars_expr;
36
use regex::Regex;
7+
use std::collections::hash_map::Entry;
8+
use std::collections::HashMap;
49

510
#[pyfunction]
611
pub fn py_validate_regex(regex: &str) -> PyResult<()> {
@@ -9,3 +14,219 @@ pub fn py_validate_regex(regex: &str) -> PyResult<()> {
914
Err(error) => Err(PyValueError::new_err(error.to_string())),
1015
}
1116
}
17+
18+
/// Get or compile a regex from the cache.
19+
/// Uses the Entry API for efficient single-lookup caching.
20+
fn get_or_compile_regex<'a>(
21+
regex_cache: &'a mut HashMap<String, Regex>,
22+
pattern: &str,
23+
) -> PolarsResult<&'a Regex> {
24+
match regex_cache.entry(pattern.to_string()) {
25+
Entry::Occupied(entry) => Ok(entry.into_mut()),
26+
Entry::Vacant(entry) => {
27+
let regex = Regex::new(pattern).map_err(|e| {
28+
PolarsError::ComputeError(
29+
format!("Invalid regex pattern '{}': {}", pattern, e).into(),
30+
)
31+
})?;
32+
Ok(entry.insert(regex))
33+
}
34+
}
35+
}
36+
37+
#[polars_expr(output_type=Int32)]
38+
fn regexp_instr(inputs: &[Series]) -> PolarsResult<Series> {
39+
let text_series = inputs[0].str()?;
40+
let pattern_series = inputs[1].str()?;
41+
let idx_series = inputs[2].i64()?;
42+
43+
let len = text_series.len();
44+
let mut regex_cache: HashMap<String, Regex> = HashMap::new();
45+
46+
// Handle broadcasting: if a series has length 1, it's a literal that should be broadcast
47+
let pattern_is_literal = pattern_series.len() == 1;
48+
let idx_is_literal = idx_series.len() == 1;
49+
50+
let mut result_vec = Vec::with_capacity(len);
51+
52+
for i in 0..len {
53+
let text_opt = text_series.get(i);
54+
// Use index 0 for literals (they'll be broadcast), otherwise use i
55+
let pattern_opt = pattern_series.get(if pattern_is_literal { 0 } else { i });
56+
let idx_opt = idx_series.get(if idx_is_literal { 0 } else { i });
57+
58+
let value = match (text_opt, pattern_opt, idx_opt) {
59+
(Some(text), Some(pattern), Some(idx)) => {
60+
// Validate index is non-negative
61+
if idx < 0 {
62+
Some(0)
63+
} else {
64+
// Get or compile regex, return error if invalid
65+
let regex = get_or_compile_regex(&mut regex_cache, pattern)?;
66+
67+
// Try to find a match
68+
if let Some(captures) = regex.captures(text) {
69+
let idx_usize = idx as usize;
70+
// idx=0 is whole match, idx=1+ are capture groups
71+
if let Some(matched) = captures.get(idx_usize) {
72+
// Return 1-based position (PySpark compatibility)
73+
Some((matched.start() as i32) + 1)
74+
} else {
75+
// No match for this group
76+
Some(0)
77+
}
78+
} else {
79+
// No match
80+
Some(0)
81+
}
82+
}
83+
}
84+
_ => None, // If any input is null, return null
85+
};
86+
87+
result_vec.push(value);
88+
}
89+
90+
Ok(Int32Chunked::from_iter_options(PlSmallStr::EMPTY, result_vec.into_iter()).into_series())
91+
}
92+
93+
#[polars_expr(output_type_func=extract_all_output_type)]
94+
fn regexp_extract_all(inputs: &[Series]) -> PolarsResult<Series> {
95+
let text_series = inputs[0].str()?;
96+
let pattern_series = inputs[1].str()?;
97+
let idx_series = inputs[2].i64()?;
98+
99+
let len = text_series.len();
100+
let mut regex_cache: HashMap<String, Regex> = HashMap::new();
101+
102+
// Handle broadcasting: if a series has length 1, it's a literal that should be broadcast
103+
let pattern_is_literal = pattern_series.len() == 1;
104+
let idx_is_literal = idx_series.len() == 1;
105+
106+
let mut result_vec = Vec::with_capacity(len);
107+
108+
for i in 0..len {
109+
let text_opt = text_series.get(i);
110+
// Use index 0 for literals (they'll be broadcast), otherwise use i
111+
let pattern_opt = pattern_series.get(if pattern_is_literal { 0 } else { i });
112+
let idx_opt = idx_series.get(if idx_is_literal { 0 } else { i });
113+
114+
let value = match (text_opt, pattern_opt, idx_opt) {
115+
(Some(text), Some(pattern), Some(idx)) => {
116+
// Validate index is non-negative
117+
if idx < 0 {
118+
Some(Series::new_empty(PlSmallStr::EMPTY, &DataType::String))
119+
} else {
120+
// Get or compile regex, return error if invalid
121+
let regex = get_or_compile_regex(&mut regex_cache, pattern)?;
122+
let idx_usize = idx as usize;
123+
let mut matches = Vec::new();
124+
125+
// Find all matches
126+
for captures in regex.captures_iter(text) {
127+
// idx=0 is whole match, idx=1+ are capture groups
128+
if let Some(matched) = captures.get(idx_usize) {
129+
matches.push(matched.as_str());
130+
}
131+
}
132+
133+
// Return as Series
134+
Some(
135+
StringChunked::from_iter_values(PlSmallStr::EMPTY, matches.into_iter())
136+
.into_series(),
137+
)
138+
}
139+
}
140+
_ => None, // If any input is null, return null
141+
};
142+
143+
result_vec.push(value);
144+
}
145+
146+
let list_chunked: ListChunked = result_vec.into_iter().collect();
147+
Ok(list_chunked.into_series())
148+
}
149+
150+
fn extract_all_output_type(input_fields: &[Field]) -> PolarsResult<Field> {
151+
let field = &input_fields[0];
152+
Ok(Field::new(
153+
field.name().clone(),
154+
DataType::List(Box::new(DataType::String)),
155+
))
156+
}
157+
158+
#[cfg(test)]
mod tests {
    use super::*;

    // Pure Rust tests - these can run with `cargo test`.

    /// A valid pattern compiles and the returned regex behaves as expected.
    #[test]
    fn test_get_or_compile_regex_valid_pattern() {
        let mut cache = HashMap::new();
        let pattern = r"\d+";

        let result = get_or_compile_regex(&mut cache, pattern);
        assert!(result.is_ok());

        let regex = result.unwrap();
        assert!(regex.is_match("123"));
        assert!(!regex.is_match("abc"));
    }

    /// Repeated lookups of one pattern reuse the entry; a new pattern adds one.
    #[test]
    fn test_get_or_compile_regex_caches() {
        let mut cache = HashMap::new();
        let pattern = r"[a-z]+";

        // First call should compile and cache.
        let result1 = get_or_compile_regex(&mut cache, pattern);
        assert!(result1.is_ok());
        assert_eq!(cache.len(), 1);

        // Second call should use the cache (cache size does not change).
        let result2 = get_or_compile_regex(&mut cache, pattern);
        assert!(result2.is_ok());
        assert_eq!(cache.len(), 1);

        // A different pattern should add to the cache.
        let result3 = get_or_compile_regex(&mut cache, r"\d+");
        assert!(result3.is_ok());
        assert_eq!(cache.len(), 2);
    }

    /// An invalid pattern yields an error naming the offending pattern.
    #[test]
    fn test_get_or_compile_regex_invalid_pattern() {
        let mut cache = HashMap::new();
        let pattern = r"[invalid(";

        let result = get_or_compile_regex(&mut cache, pattern);
        assert!(result.is_err());

        // Verify the error message format.
        let err = result.unwrap_err();
        let err_msg = err.to_string();
        assert!(err_msg.contains("Invalid regex pattern"));
        assert!(err_msg.contains("[invalid("));
    }

    /// Patterns with capture groups and word boundaries compile and match.
    #[test]
    fn test_get_or_compile_regex_special_patterns() {
        let mut cache = HashMap::new();

        // Email pattern: the haystack must contain a real `local@domain.tld`
        // address for the three capture groups to match.
        let email_pattern = r"(\w+)@(\w+)\.(\w+)";
        let result = get_or_compile_regex(&mut cache, email_pattern);
        assert!(result.is_ok());
        assert!(result.unwrap().is_match("user@example.com"));

        // Word boundary pattern.
        let word_boundary = r"\bword\b";
        let result = get_or_compile_regex(&mut cache, word_boundary);
        assert!(result.is_ok());
        assert!(result.unwrap().is_match("a word here"));
    }

    // PyO3 tests - these are tested via Python integration tests.
    // Note: py_validate_regex is tested in tests/_backends/local/functions/test_regexp_functions.py
    // because standalone PyO3 tests require Python runtime linking.
}

src/fenic/_backends/local/polars_plugins/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616
from fenic._backends.local.polars_plugins.markdown import (
1717
MarkdownExtractor,
1818
)
19+
from fenic._backends.local.polars_plugins.regex import (
20+
Regexp,
21+
)
1922
from fenic._backends.local.polars_plugins.tokenization import (
2023
Tokenization,
2124
count_tokens,
@@ -37,4 +40,5 @@
3740
"Jinja",
3841
"Dtypes",
3942
"Fuzz",
43+
"Regexp",
4044
]
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from pathlib import Path
2+
3+
import polars as pl
4+
from polars._typing import IntoExpr
5+
from polars.plugins import register_plugin_function
6+
7+
PLUGIN_PATH = Path(__file__).parents[3]
8+
9+
10+
@pl.api.register_expr_namespace("regexp")
class Regexp:
    """Polars expression namespace (``expr.regexp``) for plugin-backed regex operations."""

    def __init__(self, expr: pl.Expr) -> None:
        """Bind the namespace to the expression it operates on.

        Args:
            expr: A Polars expression containing the text data for regex operations.
        """
        self.expr = expr

    def _call_plugin(self, function_name: str, pattern: IntoExpr, idx: IntoExpr) -> pl.Expr:
        """Invoke a plugin function elementwise with ``(text, pattern, idx)`` as its arguments."""
        return register_plugin_function(
            plugin_path=PLUGIN_PATH,
            function_name=function_name,
            args=[self.expr, pattern, idx],
            is_elementwise=True,
        )

    def instr(self, pattern: IntoExpr, idx: IntoExpr) -> pl.Expr:
        """Locate a regex match inside each string.

        Args:
            pattern: Regular expression pattern to search for.
            idx: Capture group index (0 for whole match, 1+ for capture groups).

        Returns:
            1-based position of the match, or 0 if no match found, or null if input is null.
        """
        return self._call_plugin("regexp_instr", pattern, idx)

    def extract_all(self, pattern: IntoExpr, idx: IntoExpr) -> pl.Expr:
        """Collect every match of a regex pattern, optionally restricted to one capture group.

        Args:
            pattern: Regular expression pattern to search for.
            idx: Capture group index (0 for whole match, 1+ for capture groups).

        Returns:
            List of all matches, or empty list if no matches, or null if input is null.
        """
        return self._call_plugin("regexp_extract_all", pattern, idx)
55+

0 commit comments

Comments
 (0)