diff --git a/src/textual/widgets/_masked_input.py b/src/textual/widgets/_masked_input.py index a48ef9b60b..65b55e2a6e 100644 --- a/src/textual/widgets/_masked_input.py +++ b/src/textual/widgets/_masked_input.py @@ -18,7 +18,7 @@ from textual.reactive import Reactive, var from textual.validation import ValidationResult, Validator -from textual.widgets._input import Input +from textual.widgets._input import Input, Selection InputValidationOn = Literal["blur", "changed", "submitted"] """Possible messages that trigger input validation.""" @@ -200,19 +200,28 @@ def insert_separators(self, value: str, cursor_position: int) -> tuple[str, int] cursor_position += 1 return value, cursor_position - def insert_text_at_cursor(self, text: str) -> str | None: - """Inserts `text` at current cursor position. If not present in `text`, any expected separator is automatically - inserted at the correct position. + def replace(self, text: str, start: int, end: int) -> tuple[str, int] | None: + """Replace the text between the start and end locations with the given text. + If not present in `text`, any expected separator is automatically inserted at the correct position. Args: - text: The text to be inserted. + text: Text to replace the existing text with. + start: Start index to replace (inclusive). + end: End index to replace (inclusive). Returns: A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if `text` matches the template, None otherwise. """ value = self.input.value - cursor_position = self.input.cursor_position + start, end = sorted((max(0, start), min(len(value), end))) + + if not text and end == len(value): + new_value = value[:start] + cursor_position = start + return new_value, cursor_position + + template_len = len(self.template) separators = set( [ char_definition.char @@ -220,6 +229,12 @@ def insert_text_at_cursor(self, text: str) -> str | None: if _CharFlags.SEPARATOR in char_definition.flags ] ) + + empty_text = self.empty_mask[start:end] + new_value = f"{value[:start]}{empty_text}{value[end:]}" + + cursor_position = start + for char in text: if char in separators: if char == self.next_separator(cursor_position): @@ -234,35 +249,81 @@ def insert_text_at_cursor(self, text: str) -> str | None: char = self.template[cursor_position].char else: char = " " - value = ( - value[:cursor_position] + new_value = ( + new_value[:cursor_position] + char - + value[cursor_position + 1 :] + + new_value[cursor_position + 1 :] ) cursor_position += 1 continue - if cursor_position >= len(self.template): + + if cursor_position >= template_len: break + char_definition = self.template[cursor_position] assert _CharFlags.SEPARATOR not in char_definition.flags + if not char_definition.pattern.match(char): return None + if _CharFlags.LOWERCASE in char_definition.flags: char = char.lower() elif _CharFlags.UPPERCASE in char_definition.flags: char = char.upper() - value = value[:cursor_position] + char + value[cursor_position + 1 :] + + new_value = ( + new_value[:cursor_position] + char + new_value[cursor_position + 1 :] + ) cursor_position += 1 - value, cursor_position = self.insert_separators(value, cursor_position) - return value, cursor_position + new_value, cursor_position = self.insert_separators( + new_value, cursor_position + ) - def move_cursor(self, delta: int) -> None: + if ( + new_value[cursor_position:] + == self.empty_mask[cursor_position : len(new_value)] + ): + new_value = new_value[:cursor_position] + + return new_value, cursor_position + + def insert(self, text: str, index: int) -> tuple[str, int] | None: + """Inserts `text` at the given index. If not present in `text`, any expected separator is automatically + inserted at the correct position. + + Args: + text: The text to be inserted. + index: Index to insert the text at (inclusive). + + Returns: + A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if + `text` matches the template, None otherwise. + """ + return self.replace(text, index, index) + + def insert_text_at_cursor(self, text: str) -> tuple[str, int] | None: + """Inserts `text` at current cursor position. If not present in `text`, any expected separator is automatically + inserted at the correct position. + + Args: + text: The text to be inserted. + + Returns: + A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if + `text` matches the template, None otherwise. + """ + return self.insert(text, self.input.cursor_position) + + def move_cursor(self, delta: int) -> int: """Moves the cursor position by `delta` characters, skipping separators if running over them. Args: delta: The number of characters to move; positive moves right, negative moves left. + + Returns: + The new cursor position. """ cursor_position = self.input.cursor_position if delta < 0 and all( @@ -271,7 +332,8 @@ def move_cursor(self, delta: int) -> None: for char_definition in self.template[:cursor_position] ] ): - return + return cursor_position + cursor_position += delta while ( (cursor_position >= 0) @@ -279,7 +341,8 @@ def move_cursor(self, delta: int) -> None: and (_CharFlags.SEPARATOR in self.template[cursor_position].flags) ): cursor_position += delta - self.input.cursor_position = cursor_position + + return cursor_position def delete_at_position(self, position: int | None = None) -> None: """Deletes character at `position`. @@ -474,6 +537,7 @@ def __init__( which determine when to do input validation. The default is to do validation for all messages. valid_empty: Empty values are valid. + select_on_focus: Whether to select all text on focus. name: Optional name for the masked input widget. id: Optional ID for the widget. classes: Optional initial classes for the widget. @@ -572,12 +636,19 @@ def render_line(self, y: int) -> Strip: if char == " ": result.stylize(style, index, index + 1) - if self._cursor_visible and self.has_focus: - if self.cursor_at_end: - result.pad_right(1) - cursor_style = self.get_component_rich_style("input--cursor") - cursor = self.cursor_position - result.stylize(cursor_style, cursor, cursor + 1) + if self.has_focus: + if not self.selection.is_empty: + start, end = self.selection + start, end = sorted((start, end)) + selection_style = self.get_component_rich_style("input--selection") + result.stylize_before(selection_style, start, end) + + if self._cursor_visible: + cursor_style = self.get_component_rich_style("input--cursor") + cursor = self.cursor_position + if self.cursor_at_end: + result.pad_right(1) + result.stylize(cursor_style, cursor, cursor + 1) segments = list(result.render(self.app.console)) line_length = Segment.get_line_length(segments) @@ -598,7 +669,8 @@ async def _on_click(self, event: events.Click) -> None: """Ensure clicking on value does not leave cursor on a separator.""" await super()._on_click(event) if self._template.at_separator(): - self._template.move_cursor(1) + cursor_position = self._template.move_cursor(1) + self.cursor_position = cursor_position def insert_text_at_cursor(self, text: str) -> None: """Insert new text at the cursor, move the cursor to the end of the new text. @@ -613,41 +685,111 @@ def insert_text_at_cursor(self, text: str) -> None: else: self.restricted() + def replace(self, text: str, start: int, end: int) -> None: + """Replace the text between the start and end locations with the given text. + + Args: + text: Text to replace the existing text with. + start: Start index to replace (inclusive). + end: End index to replace (inclusive). + """ + new_value = self._template.replace(text, start, end) + if new_value is not None: + self.value, self.cursor_position = new_value + else: + self.restricted() + def clear(self) -> None: """Clear the masked input.""" self.value, self.cursor_position = self._template.insert_separators("", 0) - def action_cursor_left(self) -> None: - """Move the cursor one position to the left; separators are skipped.""" - self._template.move_cursor(-1) + def action_cursor_left(self, select: bool = False) -> None: + """Move the cursor one position to the left; separators are skipped. + + Args: + select: If `True`, select the text to the left of the cursor. + """ + start, end = self.selection + cursor_position = self._template.move_cursor(-1) + if select: + self.selection = Selection(start, cursor_position) + else: + if self.selection.is_empty: + self.cursor_position = cursor_position + else: + self.cursor_position = min(start, end) + + def action_cursor_right(self, select: bool = False) -> None: + """Move the cursor one position to the right; separators are skipped. + + Args: + select: If `True`, select the text to the right of the cursor. + """ + start, end = self.selection + cursor_position = self._template.move_cursor(1) + if select: + self.selection = Selection(start, cursor_position) + else: + if self.selection.is_empty: + self.cursor_position = cursor_position + else: + self.cursor_position = max(start, end) - def action_cursor_right(self) -> None: - """Move the cursor one position to the right; separators are skipped.""" - self._template.move_cursor(1) + def action_home(self, select: bool = False) -> None: + """Move the cursor to the start of the input. - def action_home(self) -> None: - """Move the cursor to the start of the input.""" - self._template.move_cursor(-len(self.template)) + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + cursor_position = self._template.move_cursor(-len(self.template)) + if select: + self.selection = Selection(self.cursor_position, cursor_position) + else: + self.cursor_position = cursor_position - def action_cursor_left_word(self) -> None: + def action_cursor_left_word(self, select: bool = False) -> None: """Move the cursor left next to the previous separator. If no previous - separator is found, moves the cursor to the start of the input.""" + separator is found, moves the cursor to the start of the input. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ if self._template.at_separator(self.cursor_position - 1): - position = self._template.prev_separator_position(self.cursor_position - 1) + separator_position = self._template.prev_separator_position( + self.cursor_position - 1 + ) else: - position = self._template.prev_separator_position() - if position: - position += 1 - self.cursor_position = position or 0 + separator_position = self._template.prev_separator_position() - def action_cursor_right_word(self) -> None: + if separator_position is None: + cursor_position = 0 + else: + cursor_position = separator_position + 1 + + if select: + start, _ = self.selection + self.selection = Selection(start, cursor_position) + else: + self.cursor_position = cursor_position + + def action_cursor_right_word(self, select: bool = False) -> None: """Move the cursor right next to the next separator. If no next - separator is found, moves the cursor to the end of the input.""" - position = self._template.next_separator_position() - if position is None: - self.cursor_position = len(self._template.mask) + separator is found, moves the cursor to the end of the input. + + Args: + select: If `True`, select the text between the old and new cursor positions. + """ + separator_position = self._template.next_separator_position() + if separator_position is None: + cursor_position = len(self._template.mask) + else: + cursor_position = separator_position + 1 + + if select: + start, _ = self.selection + self.selection = Selection(start, cursor_position) else: - self.cursor_position = position + 1 + self.cursor_position = cursor_position def action_delete_right(self) -> None: """Delete one character at the current cursor position.""" @@ -671,7 +813,8 @@ def action_delete_left(self) -> None: if self.cursor_position <= 0: # Cursor at the start, so nothing to delete return - self._template.move_cursor(-1) + cursor_position = self._template.move_cursor(-1) + self.cursor_position = cursor_position self._template.delete_at_position() def action_delete_left_word(self) -> None: diff --git a/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg b/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg new file mode 100644 index 0000000000..4b609be24a --- /dev/null +++ b/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg @@ -0,0 +1,153 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + MaskedInputApp + + + + + + + + + + ▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔▔ +1230-0000-0000-0000 +▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/tests/snapshot_tests/test_snapshots.py b/tests/snapshot_tests/test_snapshots.py index 4eba817cdd..2f3273fa41 100644 --- a/tests/snapshot_tests/test_snapshots.py +++ b/tests/snapshot_tests/test_snapshots.py @@ -41,6 +41,7 @@ ListView, Log, Markdown, + MaskedInput, OptionList, Placeholder, ProgressBar, @@ -4800,3 +4801,22 @@ async def run_before(pilot: Pilot) -> None: await pilot.press("ctrl+v") assert snap_compare(TextAreaApp(), run_before=run_before) + + +def test_masked_input_highlights_selection(snap_compare) -> None: + """Regression test for https://github.com/Textualize/textual/issues/5495 + + You should see a MaskedInput where the selection is highlighted. + """ + + class MaskedInputApp(App): + def compose(self) -> ComposeResult: + yield MaskedInput( + template="9999-9999-9999-9999;0", + value="123" + ) + + async def run_before(pilot): + pilot.app.query_one(MaskedInput).cursor_blink = False + + assert snap_compare(MaskedInputApp(), run_before=run_before) diff --git a/tests/test_masked_input.py b/tests/test_masked_input.py index 72b4d9dc09..35a0fd11a1 100644 --- a/tests/test_masked_input.py +++ b/tests/test_masked_input.py @@ -1,3 +1,5 @@ +from __future__ import annotations + from typing import Union import pytest @@ -10,16 +12,24 @@ InputEvent = Union[MaskedInput.Changed, MaskedInput.Submitted] -class InputApp(App[None]): - def __init__(self, template: str, placeholder: str = ""): +class MaskedInputApp(App[None]): + def __init__( + self, + template: str, + value: str | None = None, + select_on_focus: bool = True, + ): super().__init__() self.messages: list[InputEvent] = [] self.template = template - self.placeholder = placeholder + self.value = value + self.select_on_focus = select_on_focus def compose(self) -> ComposeResult: yield MaskedInput( - template=self.template, placeholder=self.placeholder, select_on_focus=False + template=self.template, + value=self.value, + select_on_focus=self.select_on_focus, ) @on(MaskedInput.Changed) @@ -29,7 +39,10 @@ def on_changed_or_submitted(self, event: InputEvent) -> None: async def test_missing_required(): - app = InputApp(">9999-99-99") + app = MaskedInputApp( + template=">9999-99-99", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) input.value = "2024-12" @@ -48,7 +61,10 @@ async def test_missing_required(): async def test_valid_required(): - app = InputApp(">9999-99-99") + app = MaskedInputApp( + template=">9999-99-99", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) input.value = "2024-12-31" @@ -59,7 +75,10 @@ async def test_valid_required(): async def test_missing_optional(): - app = InputApp(">9999-99-00") + app = MaskedInputApp( + template=">9999-99-00", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) input.value = "2024-12" @@ -71,7 +90,10 @@ async def test_missing_optional(): async def test_editing(): serial = "ABCDE-FGHIJ-KLMNO-PQRST" - app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_") + app = MaskedInputApp( + template=">NNNNN-NNNNN-NNNNN-NNNNN;_", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) await pilot.press("A", "B", "C", "D") @@ -94,9 +116,75 @@ async def test_editing(): assert input.cursor_position == len(serial) +async def test_overwrite_typing(): + app = MaskedInputApp( + template="9999-9999-9999-9999;0", + select_on_focus=False, + ) + async with app.run_test() as pilot: + input = app.query_one(MaskedInput) + input.value = "0000-99" + input.action_home() + + await pilot.press("1", "2", "3") + assert input.cursor_position == 3 + assert input.value == "1230-99" + + await pilot.press("4") + assert input.cursor_position == 5 + assert input.value == "1234-99" + + await pilot.press("0", "0") + assert input.cursor_position == 7 + assert input.value == "1234-00" + + await pilot.press("7", "8") + assert input.cursor_position == 10 + assert input.value == "1234-0078-" + + await pilot.press("left", "left") + await pilot.press("backspace", "backspace") + assert input.cursor_position == 5 + assert input.value == "1234- 78" + + await pilot.press("5", "6") + assert input.cursor_position == 7 + assert input.value == "1234-5678" + + +async def test_insert_jump_to_next_separator(): + app = MaskedInputApp( + template="9999-9999-9999-9999;0", + select_on_focus=False, + ) + async with app.run_test() as pilot: + input = app.query_one(MaskedInput) + + # If cursor is at the start, input should not jump to next separator + await pilot.press("-") + assert input.value == "" + assert input.cursor_position == 0 + + await pilot.press("1", "-") + assert input.value == "1 -" + assert input.cursor_position == 5 + + # If previous character is a separator, input should not jump to next separator + await pilot.press("-") + assert input.value == "1 -" + assert input.cursor_position == 5 + + await pilot.press("2", "-") + assert input.value == "1 -2 -" + assert input.cursor_position == 10 + + async def test_key_movement_actions(): serial = "ABCDE-FGHIJ-KLMNO-PQRST" - app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_") + app = MaskedInputApp( + template=">NNNNN-NNNNN-NNNNN-NNNNN;_", + select_on_focus=False, + ) async with app.run_test(): input = app.query_one(MaskedInput) input.value = serial @@ -116,7 +204,10 @@ async def test_key_movement_actions(): async def test_key_modification_actions(): serial = "ABCDE-FGHIJ-KLMNO-PQRST" - app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_") + app = MaskedInputApp( + template=">NNNNN-NNNNN-NNNNN-NNNNN;_", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) input.value = serial @@ -153,7 +244,10 @@ async def test_key_modification_actions(): async def test_cursor_word_right_after_last_separator(): - app = InputApp(">NNN-NNN-NNN-NNNNN;_") + app = MaskedInputApp( + template=">NNN-NNN-NNN-NNNNN;_", + select_on_focus=False, + ) async with app.run_test(): input = app.query_one(MaskedInput) input.value = "123-456-789-012" @@ -163,7 +257,10 @@ async def test_cursor_word_right_after_last_separator(): async def test_case_conversion_meta_characters(): - app = InputApp("NN<-N!N>N") + app = MaskedInputApp( + template="NN<-N!N>N", + select_on_focus=False, + ) async with app.run_test() as pilot: input = app.query_one(MaskedInput) await pilot.press("a", "B", "C", "D", "e") @@ -172,7 +269,10 @@ async def test_case_conversion_meta_characters(): async def test_case_conversion_override(): - app = InputApp(">-