diff --git a/babi/file.py b/babi/file.py index fd99b74b..4c853a97 100644 --- a/babi/file.py +++ b/babi/file.py @@ -220,6 +220,21 @@ def __init__( self.selection = Selection() self._file_hls: Tuple[FileHL, ...] = () + def refresh_syntax(self) -> None: + file_hls = [] + for factory in self._hl_factories: + if self.filename is not None: + hl = factory.file_highlighter(self.filename, self.buf[0]) + file_hls.append(hl) + else: + file_hls.append(factory.blank_file_highlighter()) + self._file_hls = ( + *file_hls, + self._trailing_whitespace, self._replace_hl, self.selection, + ) + for file_hl in self._file_hls: + file_hl.register_callbacks(self.buf) + def ensure_loaded( self, status: Status, @@ -253,20 +268,7 @@ def ensure_loaded( status.update(f'mixed newlines will be converted to {self.nl!r}') self.modified = True - file_hls = [] - for factory in self._hl_factories: - if self.filename is not None: - hl = factory.file_highlighter(self.filename, self.buf[0]) - file_hls.append(hl) - else: - file_hls.append(factory.blank_file_highlighter()) - self._file_hls = ( - *file_hls, - self._trailing_whitespace, self._replace_hl, self.selection, - ) - for file_hl in self._file_hls: - file_hl.register_callbacks(self.buf) - + self.refresh_syntax() self.go_to_line(self.initial_line, margin) def __repr__(self) -> str: diff --git a/babi/screen.py b/babi/screen.py index 765f1963..d6a40ddd 100644 --- a/babi/screen.py +++ b/babi/screen.py @@ -520,6 +520,8 @@ def save(self) -> Optional[PromptResult]: action.end_modified = not first action.start_modified = True first = False + + self.file.refresh_syntax() return None def save_filename(self) -> Optional[PromptResult]: