11from dataclasses import dataclass
22from pathlib import Path
3+ import uuid
34from sqlmesh .core .context import Context
45import typing as t
56
89from sqlmesh .lsp .custom import ModelForRendering
910from sqlmesh .lsp .custom import AllModelsResponse , RenderModelEntry
1011from sqlmesh .lsp .uri import URI
12+ from lsprotocol import types
1113
1214
1315@dataclass
@@ -33,8 +35,14 @@ class LSPContext:
3335 map : t .Dict [Path , t .Union [ModelTarget , AuditTarget ]]
3436 _render_cache : t .Dict [Path , t .List [RenderModelEntry ]]
3537 _lint_cache : t .Dict [Path , t .List [AnnotatedRuleViolation ]]
38+ _version_id : str
39+ """
40+ This is a version ID for the context. It is used to track changes to the context. It can be used to
41+ return a version number to the LSP client.
42+ """
3643
3744 def __init__ (self , context : Context ) -> None :
45+ self ._version_id = str (uuid .uuid4 ())
3846 self .context = context
3947 self ._render_cache = {}
4048 self ._lint_cache = {}
@@ -62,6 +70,11 @@ def __init__(self, context: Context) -> None:
6270 ** audit_map ,
6371 }
6472
73+ @property
74+ def version_id (self ) -> str :
75+ """Get the version ID for the context."""
76+ return self ._version_id
77+
6578 def render_model (self , uri : URI ) -> t .List [RenderModelEntry ]:
6679 """Get rendered models for a file, using cache when available.
6780
@@ -150,6 +163,86 @@ def lint_model(self, uri: URI) -> t.List[AnnotatedRuleViolation]:
150163 self ._lint_cache [path ] = diagnostics
151164 return diagnostics
152165
166+ def get_code_actions (
167+ self , uri : URI , params : types .CodeActionParams
168+ ) -> t .Optional [t .List [t .Union [types .Command , types .CodeAction ]]]:
169+ """Get code actions for a file."""
170+
171+ # Get the violations (which contain the fixes)
172+ violations = self .lint_model (uri )
173+
174+ # Convert violations to a map for quick lookup
175+ # Use a hashable representation of Range as the key
176+ violation_map : t .Dict [
177+ t .Tuple [str , t .Tuple [int , int , int , int ]], AnnotatedRuleViolation
178+ ] = {}
179+ for violation in violations :
180+ if violation .violation_range :
181+ lsp_diagnostic = self .diagnostic_to_lsp_diagnostic (violation )
182+ if lsp_diagnostic :
183+ # Create a hashable key from the diagnostic message and range
184+ key = (
185+ lsp_diagnostic .message ,
186+ (
187+ lsp_diagnostic .range .start .line ,
188+ lsp_diagnostic .range .start .character ,
189+ lsp_diagnostic .range .end .line ,
190+ lsp_diagnostic .range .end .character ,
191+ ),
192+ )
193+ violation_map [key ] = violation
194+
195+ # Get diagnostics in the requested range
196+ diagnostics = params .context .diagnostics if params .context else []
197+
198+ code_actions : t .List [t .Union [types .Command , types .CodeAction ]] = []
199+
200+ for diagnostic in diagnostics :
201+ # Find the corresponding violation
202+ key = (
203+ diagnostic .message ,
204+ (
205+ diagnostic .range .start .line ,
206+ diagnostic .range .start .character ,
207+ diagnostic .range .end .line ,
208+ diagnostic .range .end .character ,
209+ ),
210+ )
211+ found_violation = violation_map .get (key )
212+
213+ if found_violation is not None and found_violation .fixes :
214+ # Create code actions for each fix
215+ for fix in found_violation .fixes :
216+ # Convert our Fix to LSP TextEdits
217+ text_edits = []
218+ for edit in fix .edits :
219+ text_edits .append (
220+ types .TextEdit (
221+ range = types .Range (
222+ start = types .Position (
223+ line = edit .range .start .line ,
224+ character = edit .range .start .character ,
225+ ),
226+ end = types .Position (
227+ line = edit .range .end .line ,
228+ character = edit .range .end .character ,
229+ ),
230+ ),
231+ new_text = edit .new_text ,
232+ )
233+ )
234+
235+ # Create the code action
236+ code_action = types .CodeAction (
237+ title = fix .title ,
238+ kind = types .CodeActionKind .QuickFix ,
239+ diagnostics = [diagnostic ],
240+ edit = types .WorkspaceEdit (changes = {params .text_document .uri : text_edits }),
241+ )
242+ code_actions .append (code_action )
243+
244+ return code_actions if code_actions else None
245+
153246 def list_of_models_for_rendering (self ) -> t .List [ModelForRendering ]:
154247 """Get a list of models for rendering.
155248
@@ -186,3 +279,68 @@ def get_completions(
186279 from sqlmesh .lsp .completions import get_sql_completions
187280
188281 return get_sql_completions (self , uri , file_content )
282+
283+ @staticmethod
284+ def diagnostics_to_lsp_diagnostics (
285+ diagnostics : t .List [AnnotatedRuleViolation ],
286+ ) -> t .List [types .Diagnostic ]:
287+ """
288+ Converts a list of AnnotatedRuleViolations to a list of LSP diagnostics. It will remove duplicates based on the message and range.
289+ """
290+ lsp_diagnostics = {}
291+ for diagnostic in diagnostics :
292+ lsp_diagnostic = LSPContext .diagnostic_to_lsp_diagnostic (diagnostic )
293+ if lsp_diagnostic is not None :
294+ # Create a unique key combining message and range
295+ diagnostic_key = (
296+ lsp_diagnostic .message ,
297+ lsp_diagnostic .range .start .line ,
298+ lsp_diagnostic .range .start .character ,
299+ lsp_diagnostic .range .end .line ,
300+ lsp_diagnostic .range .end .character ,
301+ )
302+ if diagnostic_key not in lsp_diagnostics :
303+ lsp_diagnostics [diagnostic_key ] = lsp_diagnostic
304+ return list (lsp_diagnostics .values ())
305+
306+ @staticmethod
307+ def diagnostic_to_lsp_diagnostic (
308+ diagnostic : AnnotatedRuleViolation ,
309+ ) -> t .Optional [types .Diagnostic ]:
310+ if diagnostic .model ._path is None :
311+ return None
312+ if not diagnostic .violation_range :
313+ with open (diagnostic .model ._path , "r" , encoding = "utf-8" ) as file :
314+ lines = file .readlines ()
315+ diagnostic_range = types .Range (
316+ start = types .Position (line = 0 , character = 0 ),
317+ end = types .Position (line = len (lines ) - 1 , character = len (lines [- 1 ])),
318+ )
319+ else :
320+ diagnostic_range = types .Range (
321+ start = types .Position (
322+ line = diagnostic .violation_range .start .line ,
323+ character = diagnostic .violation_range .start .character ,
324+ ),
325+ end = types .Position (
326+ line = diagnostic .violation_range .end .line ,
327+ character = diagnostic .violation_range .end .character ,
328+ ),
329+ )
330+
331+ # Get rule definition location for diagnostics link
332+ rule_location = diagnostic .rule .get_definition_location ()
333+ rule_uri_wihout_extension = URI .from_path (rule_location .file_path )
334+ rule_uri = f"{ rule_uri_wihout_extension .value } #L{ rule_location .start_line } "
335+
336+ # Use URI format to create a link for "related information"
337+ return types .Diagnostic (
338+ range = diagnostic_range ,
339+ message = diagnostic .violation_msg ,
340+ severity = types .DiagnosticSeverity .Error
341+ if diagnostic .violation_type == "error"
342+ else types .DiagnosticSeverity .Warning ,
343+ source = "sqlmesh" ,
344+ code = diagnostic .rule .name ,
345+ code_description = types .CodeDescription (href = rule_uri ),
346+ )
0 commit comments