diff --git a/client/commands/v2/pysa_server.py b/client/commands/v2/pysa_server.py index f610e0f15fc..d5d2e4dc259 100644 --- a/client/commands/v2/pysa_server.py +++ b/client/commands/v2/pysa_server.py @@ -11,7 +11,12 @@ import asyncio import logging +from collections import defaultdict +from pathlib import Path +from typing import List, Sequence, Dict, Set +from ....api import query, connection as api_connection +from ....api.connection import PyreQueryError from ... import ( json_rpc, command_arguments, @@ -28,6 +33,7 @@ _read_lsp_request, try_initialize, _log_lsp_event, + _publish_diagnostics, InitializationExit, InitializationSuccess, InitializationFailure, @@ -59,6 +65,60 @@ def __init__( self.pyre_arguments = pyre_arguments self.binary_location = binary_location self.server_identifier = server_identifier + self.pyre_connection = api_connection.PyreConnection( + Path(self.pyre_arguments.global_root) + ) + self.file_tracker: Set[Path] = set() + + def invalid_model_to_diagnostic( + self, invalid_model: query.InvalidModel + ) -> lsp.Diagnostic: + return lsp.Diagnostic( + range=lsp.Range( + start=lsp.Position( + line=invalid_model.line - 1, character=invalid_model.column + ), + end=lsp.Position( + line=invalid_model.stop_line - 1, + character=invalid_model.stop_column, + ), + ), + message=invalid_model.full_error_message, + severity=lsp.DiagnosticSeverity.ERROR, + code=None, + source="Pysa", + ) + + def invalid_models_to_diagnostics( + self, invalid_models: Sequence[query.InvalidModel] + ) -> Dict[Path, List[lsp.Diagnostic]]: + result: Dict[Path, List[lsp.Diagnostic]] = defaultdict(list) + for model in invalid_models: + if model.path is None: + self.log_and_show_message_to_client( + f"{model.full_error_message}", lsp.MessageType.ERROR + ) + else: + result[Path(model.path)].append(self.invalid_model_to_diagnostic(model)) + return result + + async def update_errors(self) -> None: + # Publishing empty diagnostics to clear errors in VSCode and reset self.file_tracker + for document_path in self.file_tracker: + await _publish_diagnostics(self.output_channel, document_path, []) + self.file_tracker.clear() + + try: + model_errors = query.get_invalid_taint_models(self.pyre_connection) + diagnostics = self.invalid_models_to_diagnostics(model_errors) + # Keep track of files we publish diagnostics for + self.file_tracker.update(diagnostics.keys()) + + await self.show_model_errors_to_client(diagnostics) + except PyreQueryError as e: + await self.log_and_show_message_to_client( + f"Error querying Pyre: {e}", lsp.MessageType.WARNING + ) async def show_message_to_client( self, message: str, level: lsp.MessageType = lsp.MessageType.INFO @@ -86,6 +146,12 @@ async def log_and_show_message_to_client( LOG.debug(message) await self.show_message_to_client(message, level) + async def show_model_errors_to_client( + self, diagnostics: Dict[Path, List[lsp.Diagnostic]] + ) -> None: + for path, diagnostic in diagnostics.items(): + await _publish_diagnostics(self.output_channel, path, diagnostic or []) + async def wait_for_exit(self) -> int: while True: async with _read_lsp_request( @@ -102,15 +168,18 @@ async def process_open_request( self, parameters: lsp.DidOpenTextDocumentParameters ) -> None: document_path = parameters.text_document.document_uri().to_file_path() + if document_path is None: raise json_rpc.InvalidRequestError( f"Document URI is not a file: {parameters.text_document.uri}" ) + await self.update_errors() async def process_close_request( self, parameters: lsp.DidCloseTextDocumentParameters ) -> None: document_path = parameters.text_document.document_uri().to_file_path() + if document_path is None: raise json_rpc.InvalidRequestError( f"Document URI is not a file: {parameters.text_document.uri}" @@ -129,6 +198,7 @@ async def process_did_save_request( raise json_rpc.InvalidRequestError( f"Document URI is not a file: {parameters.text_document.uri}" ) + await self.update_errors() async def run(self) -> int: while True: