Skip to content

Commit fe6f3ba

Browse files
authored
feat: add fixes to linting errors (#4800)
1 parent 34b05d6 commit fe6f3ba

File tree

6 files changed

+421
-126
lines changed

6 files changed

+421
-126
lines changed

sqlmesh/core/linter/definition.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from collections.abc import Iterator, Iterable, Set, Mapping, Callable
99
from functools import reduce
1010
from sqlmesh.core.model import Model
11-
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix
1212
from sqlmesh.core.console import LinterConsole, get_console
1313

1414
if t.TYPE_CHECKING:
@@ -75,6 +75,7 @@ def lint_model(
7575
model=model,
7676
violation_type="error",
7777
violation_range=violation.violation_range,
78+
fixes=violation.fixes,
7879
)
7980
for violation in error_violations
8081
] + [
@@ -84,6 +85,7 @@ def lint_model(
8485
model=model,
8586
violation_type="warning",
8687
violation_range=violation.violation_range,
88+
fixes=violation.fixes,
8789
)
8890
for violation in warn_violations
8991
]
@@ -152,7 +154,8 @@ def __init__(
152154
model: Model,
153155
violation_type: t.Literal["error", "warning"],
154156
violation_range: t.Optional[Range] = None,
157+
fixes: t.Optional[t.List[Fix]] = None,
155158
) -> None:
156-
super().__init__(rule, violation_msg, violation_range)
159+
super().__init__(rule, violation_msg, violation_range, fixes)
157160
self.model = model
158161
self.violation_type = violation_type

sqlmesh/core/linter/rule.py

Lines changed: 27 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,22 @@ class Range:
3939
end: Position
4040

4141

42+
@dataclass(frozen=True)
43+
class TextEdit:
44+
"""A text edit to apply to a file."""
45+
46+
range: Range
47+
new_text: str
48+
49+
50+
@dataclass(frozen=True)
51+
class Fix:
52+
"""A fix that can be applied to resolve a rule violation."""
53+
54+
title: str
55+
edits: t.List[TextEdit]
56+
57+
4258
class _Rule(abc.ABCMeta):
4359
def __new__(cls: Type[_Rule], clsname: str, bases: t.Tuple, attrs: t.Dict) -> _Rule:
4460
attrs["name"] = clsname.lower()
@@ -66,10 +82,14 @@ def violation(
6682
self,
6783
violation_msg: t.Optional[str] = None,
6884
violation_range: t.Optional[Range] = None,
85+
fixes: t.Optional[t.List[Fix]] = None,
6986
) -> RuleViolation:
7087
"""Create a RuleViolation instance for this rule"""
7188
return RuleViolation(
72-
rule=self, violation_msg=violation_msg or self.summary, violation_range=violation_range
89+
rule=self,
90+
violation_msg=violation_msg or self.summary,
91+
violation_range=violation_range,
92+
fixes=fixes,
7393
)
7494

7595
def get_definition_location(self) -> RuleLocation:
@@ -103,11 +123,16 @@ def __repr__(self) -> str:
103123

104124
class RuleViolation:
105125
def __init__(
106-
self, rule: Rule, violation_msg: str, violation_range: t.Optional[Range] = None
126+
self,
127+
rule: Rule,
128+
violation_msg: str,
129+
violation_range: t.Optional[Range] = None,
130+
fixes: t.Optional[t.List[Fix]] = None,
107131
) -> None:
108132
self.rule = rule
109133
self.violation_msg = violation_msg
110134
self.violation_range = violation_range
135+
self.fixes = fixes or []
111136

112137
def __repr__(self) -> str:
113138
return f"{self.rule.name}: {self.violation_msg}"

sqlmesh/core/linter/rules/builtin.py

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from sqlglot.helper import subclasses
99

1010
from sqlmesh.core.linter.helpers import TokenPositionDetails
11-
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range
11+
from sqlmesh.core.linter.rule import Rule, RuleViolation, Range, Fix, TextEdit
1212
from sqlmesh.core.linter.definition import RuleSet
1313
from sqlmesh.core.model import Model, SqlModel
1414

@@ -22,7 +22,8 @@ def check_model(self, model: Model) -> t.Optional[RuleViolation]:
2222
return None
2323
if model.query.is_star:
2424
violation_range = self._get_range(model)
25-
return self.violation(violation_range=violation_range)
25+
fixes = self._create_fixes(model, violation_range)
26+
return self.violation(violation_range=violation_range, fixes=fixes)
2627
return None
2728

2829
def _get_range(self, model: SqlModel) -> t.Optional[Range]:
@@ -37,6 +38,28 @@ def _get_range(self, model: SqlModel) -> t.Optional[Range]:
3738

3839
return None
3940

41+
def _create_fixes(
42+
self, model: SqlModel, violation_range: t.Optional[Range]
43+
) -> t.Optional[t.List[Fix]]:
44+
"""Create fixes for the SELECT * violation."""
45+
if not violation_range:
46+
return None
47+
columns = model.columns_to_types
48+
if not columns:
49+
return None
50+
new_text = ", ".join(columns.keys())
51+
return [
52+
Fix(
53+
title="Replace SELECT * with explicit column list",
54+
edits=[
55+
TextEdit(
56+
range=violation_range,
57+
new_text=new_text,
58+
)
59+
],
60+
)
61+
]
62+
4063

4164
class InvalidSelectStarExpansion(Rule):
4265
def check_model(self, model: Model) -> t.Optional[RuleViolation]:

sqlmesh/lsp/context.py

Lines changed: 158 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from dataclasses import dataclass
22
from pathlib import Path
3+
import uuid
34
from sqlmesh.core.context import Context
45
import typing as t
56

@@ -8,6 +9,7 @@
89
from sqlmesh.lsp.custom import ModelForRendering
910
from sqlmesh.lsp.custom import AllModelsResponse, RenderModelEntry
1011
from sqlmesh.lsp.uri import URI
12+
from lsprotocol import types
1113

1214

1315
@dataclass
@@ -33,8 +35,14 @@ class LSPContext:
3335
map: t.Dict[Path, t.Union[ModelTarget, AuditTarget]]
3436
_render_cache: t.Dict[Path, t.List[RenderModelEntry]]
3537
_lint_cache: t.Dict[Path, t.List[AnnotatedRuleViolation]]
38+
_version_id: str
39+
"""
40+
This is a version ID for the context. It is used to track changes to the context. It can be used to
41+
return a version number to the LSP client.
42+
"""
3643

3744
def __init__(self, context: Context) -> None:
45+
self._version_id = str(uuid.uuid4())
3846
self.context = context
3947
self._render_cache = {}
4048
self._lint_cache = {}
@@ -62,6 +70,11 @@ def __init__(self, context: Context) -> None:
6270
**audit_map,
6371
}
6472

73+
@property
74+
def version_id(self) -> str:
75+
"""Get the version ID for the context."""
76+
return self._version_id
77+
6578
def render_model(self, uri: URI) -> t.List[RenderModelEntry]:
6679
"""Get rendered models for a file, using cache when available.
6780
@@ -150,6 +163,86 @@ def lint_model(self, uri: URI) -> t.List[AnnotatedRuleViolation]:
150163
self._lint_cache[path] = diagnostics
151164
return diagnostics
152165

166+
def get_code_actions(
167+
self, uri: URI, params: types.CodeActionParams
168+
) -> t.Optional[t.List[t.Union[types.Command, types.CodeAction]]]:
169+
"""Get code actions for a file."""
170+
171+
# Get the violations (which contain the fixes)
172+
violations = self.lint_model(uri)
173+
174+
# Convert violations to a map for quick lookup
175+
# Use a hashable representation of Range as the key
176+
violation_map: t.Dict[
177+
t.Tuple[str, t.Tuple[int, int, int, int]], AnnotatedRuleViolation
178+
] = {}
179+
for violation in violations:
180+
if violation.violation_range:
181+
lsp_diagnostic = self.diagnostic_to_lsp_diagnostic(violation)
182+
if lsp_diagnostic:
183+
# Create a hashable key from the diagnostic message and range
184+
key = (
185+
lsp_diagnostic.message,
186+
(
187+
lsp_diagnostic.range.start.line,
188+
lsp_diagnostic.range.start.character,
189+
lsp_diagnostic.range.end.line,
190+
lsp_diagnostic.range.end.character,
191+
),
192+
)
193+
violation_map[key] = violation
194+
195+
# Get diagnostics in the requested range
196+
diagnostics = params.context.diagnostics if params.context else []
197+
198+
code_actions: t.List[t.Union[types.Command, types.CodeAction]] = []
199+
200+
for diagnostic in diagnostics:
201+
# Find the corresponding violation
202+
key = (
203+
diagnostic.message,
204+
(
205+
diagnostic.range.start.line,
206+
diagnostic.range.start.character,
207+
diagnostic.range.end.line,
208+
diagnostic.range.end.character,
209+
),
210+
)
211+
found_violation = violation_map.get(key)
212+
213+
if found_violation is not None and found_violation.fixes:
214+
# Create code actions for each fix
215+
for fix in found_violation.fixes:
216+
# Convert our Fix to LSP TextEdits
217+
text_edits = []
218+
for edit in fix.edits:
219+
text_edits.append(
220+
types.TextEdit(
221+
range=types.Range(
222+
start=types.Position(
223+
line=edit.range.start.line,
224+
character=edit.range.start.character,
225+
),
226+
end=types.Position(
227+
line=edit.range.end.line,
228+
character=edit.range.end.character,
229+
),
230+
),
231+
new_text=edit.new_text,
232+
)
233+
)
234+
235+
# Create the code action
236+
code_action = types.CodeAction(
237+
title=fix.title,
238+
kind=types.CodeActionKind.QuickFix,
239+
diagnostics=[diagnostic],
240+
edit=types.WorkspaceEdit(changes={params.text_document.uri: text_edits}),
241+
)
242+
code_actions.append(code_action)
243+
244+
return code_actions if code_actions else None
245+
153246
def list_of_models_for_rendering(self) -> t.List[ModelForRendering]:
154247
"""Get a list of models for rendering.
155248
@@ -186,3 +279,68 @@ def get_completions(
186279
from sqlmesh.lsp.completions import get_sql_completions
187280

188281
return get_sql_completions(self, uri, file_content)
282+
283+
@staticmethod
284+
def diagnostics_to_lsp_diagnostics(
285+
diagnostics: t.List[AnnotatedRuleViolation],
286+
) -> t.List[types.Diagnostic]:
287+
"""
288+
Converts a list of AnnotatedRuleViolations to a list of LSP diagnostics. It will remove duplicates based on the message and range.
289+
"""
290+
lsp_diagnostics = {}
291+
for diagnostic in diagnostics:
292+
lsp_diagnostic = LSPContext.diagnostic_to_lsp_diagnostic(diagnostic)
293+
if lsp_diagnostic is not None:
294+
# Create a unique key combining message and range
295+
diagnostic_key = (
296+
lsp_diagnostic.message,
297+
lsp_diagnostic.range.start.line,
298+
lsp_diagnostic.range.start.character,
299+
lsp_diagnostic.range.end.line,
300+
lsp_diagnostic.range.end.character,
301+
)
302+
if diagnostic_key not in lsp_diagnostics:
303+
lsp_diagnostics[diagnostic_key] = lsp_diagnostic
304+
return list(lsp_diagnostics.values())
305+
306+
@staticmethod
307+
def diagnostic_to_lsp_diagnostic(
308+
diagnostic: AnnotatedRuleViolation,
309+
) -> t.Optional[types.Diagnostic]:
310+
if diagnostic.model._path is None:
311+
return None
312+
if not diagnostic.violation_range:
313+
with open(diagnostic.model._path, "r", encoding="utf-8") as file:
314+
lines = file.readlines()
315+
diagnostic_range = types.Range(
316+
start=types.Position(line=0, character=0),
317+
end=types.Position(line=len(lines) - 1, character=len(lines[-1])),
318+
)
319+
else:
320+
diagnostic_range = types.Range(
321+
start=types.Position(
322+
line=diagnostic.violation_range.start.line,
323+
character=diagnostic.violation_range.start.character,
324+
),
325+
end=types.Position(
326+
line=diagnostic.violation_range.end.line,
327+
character=diagnostic.violation_range.end.character,
328+
),
329+
)
330+
331+
# Get rule definition location for diagnostics link
332+
rule_location = diagnostic.rule.get_definition_location()
333+
rule_uri_wihout_extension = URI.from_path(rule_location.file_path)
334+
rule_uri = f"{rule_uri_wihout_extension.value}#L{rule_location.start_line}"
335+
336+
# Use URI format to create a link for "related information"
337+
return types.Diagnostic(
338+
range=diagnostic_range,
339+
message=diagnostic.violation_msg,
340+
severity=types.DiagnosticSeverity.Error
341+
if diagnostic.violation_type == "error"
342+
else types.DiagnosticSeverity.Warning,
343+
source="sqlmesh",
344+
code=diagnostic.rule.name,
345+
code_description=types.CodeDescription(href=rule_uri),
346+
)

0 commit comments

Comments
 (0)