Skip to content
218 changes: 217 additions & 1 deletion crates/postgresql-cst-parser/src/tree_sitter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ impl Tree {
}
}

pub fn root_node(&self) -> Node {
pub fn root_node(&self) -> Node<'_> {
Node {
input: &self.src,
range_map: Rc::clone(&self.range_map),
Expand All @@ -63,6 +63,14 @@ pub struct Node<'a> {
pub node_or_token: NodeOrToken<'a>,
}

impl<'a> PartialEq for Node<'a> {
fn eq(&self, other: &Self) -> bool {
self.node_or_token == other.node_or_token
}
}

impl<'a> Eq for Node<'a> {}

#[derive(Debug, Clone)]
pub struct TreeCursor<'a> {
pub input: &'a str,
Expand Down Expand Up @@ -98,6 +106,28 @@ impl std::fmt::Display for Range {
}
}

impl Range {
pub fn extended_by(&self, other: &Self) -> Self {
Range {
start_byte: self.start_byte.min(other.start_byte),
end_byte: self.end_byte.max(other.end_byte),

start_position: Point {
row: self.start_position.row.min(other.start_position.row),
column: self.start_position.column.min(other.start_position.column),
},
end_position: Point {
row: self.end_position.row.max(other.end_position.row),
column: self.end_position.column.max(other.end_position.column),
},
}
}

pub fn is_adjacent(&self, other: &Self) -> bool {
self.end_byte == other.start_byte || self.start_byte == other.end_byte
}
}

impl<'a> Node<'a> {
pub fn walk(&self) -> TreeCursor<'a> {
TreeCursor {
Expand Down Expand Up @@ -144,6 +174,48 @@ impl<'a> Node<'a> {
}
}

pub fn children(&self) -> Vec<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.children_with_tokens()
.map(|node| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: node,
})
.collect()
} else {
vec![]
}
}

/// Returns the first child element of this node.
/// this is not tree-sitter's API
pub fn first_child(&self) -> Option<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.first_child_or_token().map(|child| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: child,
})
} else {
None
}
}

/// Returns the last child element of this node.
/// this is not tree-sitter's API
pub fn last_child(&self) -> Option<Node<'a>> {
if let Some(node) = self.node_or_token.as_node() {
node.last_child_or_token().map(|child| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: child,
})
} else {
None
}
}

pub fn next_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.next_sibling_or_token()
Expand All @@ -154,6 +226,16 @@ impl<'a> Node<'a> {
})
}

pub fn prev_sibling(&self) -> Option<Node<'a>> {
self.node_or_token
.prev_sibling_or_token()
.map(|sibling| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: sibling,
})
}

pub fn parent(&self) -> Option<Node<'a>> {
self.node_or_token.parent().map(|parent| Node {
input: self.input,
Expand All @@ -165,6 +247,82 @@ impl<'a> Node<'a> {
pub fn is_comment(&self) -> bool {
matches!(self.kind(), SyntaxKind::C_COMMENT | SyntaxKind::SQL_COMMENT)
}

/// Returns the rightmost token in the subtree of this node.
/// this is not tree-sitter's API
pub fn last_node(&self) -> Option<Node<'a>> {
match &self.node_or_token {
NodeOrToken::Node(node) => node.last_token().map(|token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
NodeOrToken::Token(token) => Some(Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(token),
}),
}
}

/// Returns the next token in the tree.
/// This is not necessarily a direct sibling of this node/token,
/// but will always be further right in the tree.
/// this is not tree-sitter's API
pub fn next_token(&self) -> Option<Node<'a>> {
match &self.node_or_token {
NodeOrToken::Token(token) => token.next_token().map(|next_token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(next_token),
}),
NodeOrToken::Node(node) => {
// For a node, find its last token and then get the next token
node.last_token()
.and_then(|last_token| last_token.next_token())
.map(|next_token| Node {
input: self.input,
range_map: Rc::clone(&self.range_map),
node_or_token: NodeOrToken::Token(next_token),
})
}
}
}

/// Returns an iterator over all descendant nodes (including tokens)
/// this is not tree-sitter's API
pub fn descendants(&self) -> impl Iterator<Item = Node<'a>> + '_ {
struct Descendants<'a> {
iter: Box<dyn Iterator<Item = Node<'a>> + 'a>,
}

impl<'a> Iterator for Descendants<'a> {
type Item = Node<'a>;

fn next(&mut self) -> Option<Self::Item> {
self.iter.next()
}
}

if let Some(node) = self.node_or_token.as_node() {
let input = self.input;
let range_map = Rc::clone(&self.range_map);
Descendants {
iter: Box::new(
node.descendants_with_tokens()
.map(move |node_or_token| Node {
input,
range_map: Rc::clone(&range_map),
node_or_token,
}),
),
}
} else {
Descendants {
iter: Box::new(std::iter::empty()),
}
}
}
}

impl<'a> From<Node<'a>> for TreeCursor<'a> {
Expand Down Expand Up @@ -214,6 +372,15 @@ impl<'a> TreeCursor<'a> {
}
}

pub fn goto_prev_sibling(&mut self) -> bool {
if let Some(sibling) = self.node_or_token.prev_sibling_or_token() {
self.node_or_token = sibling;
true
} else {
false
}
}

pub fn is_comment(&self) -> bool {
matches!(
self.node_or_token.kind(),
Expand Down Expand Up @@ -462,4 +629,53 @@ from

assert_eq!(stmt_count, 2);
}

#[test]
fn test_last_node_returns_rightmost_node() {
let src = "SELECT u.*, (v).id, name;";
let tree = parse(src).unwrap();
let root = tree.root_node();

let target_list = root
.descendants()
.find(|node| node.kind() == SyntaxKind::target_list)
.expect("should find target_list");

// last node of the target_list is returned
let last_node = target_list.last_node().expect("should have last node");
assert_eq!(last_node.text(), "name");

let target_els = target_list
.children()
.into_iter()
.filter(|node| node.kind() == SyntaxKind::target_el)
.collect::<Vec<_>>();

let mut last_nodes = target_els
.iter()
.map(|node| node.last_node().expect("should have last node"));

// last node of each target_el is returned
assert_eq!(last_nodes.next().unwrap().text(), "*");
assert_eq!(last_nodes.next().unwrap().text(), "id");
assert_eq!(last_nodes.next().unwrap().text(), "name");
assert!(last_nodes.next().is_none());
}

#[test]
fn test_next_token() {
let src = "SELECT tbl.name as n from TBL;";
let tree = parse(src).unwrap();
let root = tree.root_node();

let name = root
.descendants()
.find(|node| node.kind() == SyntaxKind::NAME_P)
.expect("should find NAME_P");

// Even if not a direct sibling or not belonging to the same subtree, the next_token can retrieve the next token.
let next_token = name.next_token().expect("should have next token");
assert_eq!(next_token.text(), "as");
assert_eq!(next_token.kind(), SyntaxKind::AS);
}
}
Loading