diff --git a/src/textual/widgets/_masked_input.py b/src/textual/widgets/_masked_input.py
index a48ef9b60b..65b55e2a6e 100644
--- a/src/textual/widgets/_masked_input.py
+++ b/src/textual/widgets/_masked_input.py
@@ -18,7 +18,7 @@
from textual.reactive import Reactive, var
from textual.validation import ValidationResult, Validator
-from textual.widgets._input import Input
+from textual.widgets._input import Input, Selection
InputValidationOn = Literal["blur", "changed", "submitted"]
"""Possible messages that trigger input validation."""
@@ -200,19 +200,28 @@ def insert_separators(self, value: str, cursor_position: int) -> tuple[str, int]
cursor_position += 1
return value, cursor_position
- def insert_text_at_cursor(self, text: str) -> str | None:
- """Inserts `text` at current cursor position. If not present in `text`, any expected separator is automatically
- inserted at the correct position.
+ def replace(self, text: str, start: int, end: int) -> tuple[str, int] | None:
+ """Replace the text between the start and end locations with the given text.
+ If not present in `text`, any expected separator is automatically inserted at the correct position.
Args:
- text: The text to be inserted.
+ text: Text to replace the existing text with.
+ start: Start index to replace (inclusive).
+ end: End index to replace (inclusive).
Returns:
A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if
`text` matches the template, None otherwise.
"""
value = self.input.value
- cursor_position = self.input.cursor_position
+ start, end = sorted((max(0, start), min(len(value), end)))
+
+ if not text and end == len(value):
+ new_value = value[:start]
+ cursor_position = start
+ return new_value, cursor_position
+
+ template_len = len(self.template)
separators = set(
[
char_definition.char
@@ -220,6 +229,12 @@ def insert_text_at_cursor(self, text: str) -> str | None:
if _CharFlags.SEPARATOR in char_definition.flags
]
)
+
+ empty_text = self.empty_mask[start:end]
+ new_value = f"{value[:start]}{empty_text}{value[end:]}"
+
+ cursor_position = start
+
for char in text:
if char in separators:
if char == self.next_separator(cursor_position):
@@ -234,35 +249,81 @@ def insert_text_at_cursor(self, text: str) -> str | None:
char = self.template[cursor_position].char
else:
char = " "
- value = (
- value[:cursor_position]
+ new_value = (
+ new_value[:cursor_position]
+ char
- + value[cursor_position + 1 :]
+ + new_value[cursor_position + 1 :]
)
cursor_position += 1
continue
- if cursor_position >= len(self.template):
+
+ if cursor_position >= template_len:
break
+
char_definition = self.template[cursor_position]
assert _CharFlags.SEPARATOR not in char_definition.flags
+
if not char_definition.pattern.match(char):
return None
+
if _CharFlags.LOWERCASE in char_definition.flags:
char = char.lower()
elif _CharFlags.UPPERCASE in char_definition.flags:
char = char.upper()
- value = value[:cursor_position] + char + value[cursor_position + 1 :]
+
+ new_value = (
+ new_value[:cursor_position] + char + new_value[cursor_position + 1 :]
+ )
cursor_position += 1
- value, cursor_position = self.insert_separators(value, cursor_position)
- return value, cursor_position
+ new_value, cursor_position = self.insert_separators(
+ new_value, cursor_position
+ )
- def move_cursor(self, delta: int) -> None:
+ if (
+ new_value[cursor_position:]
+ == self.empty_mask[cursor_position : len(new_value)]
+ ):
+ new_value = new_value[:cursor_position]
+
+ return new_value, cursor_position
+
+ def insert(self, text: str, index: int) -> tuple[str, int] | None:
+ """Inserts `text` at the given index. If not present in `text`, any expected separator is automatically
+ inserted at the correct position.
+
+ Args:
+ text: The text to be inserted.
+ index: Index to insert the text at (inclusive).
+
+ Returns:
+ A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if
+ `text` matches the template, None otherwise.
+ """
+ return self.replace(text, index, index)
+
+ def insert_text_at_cursor(self, text: str) -> tuple[str, int] | None:
+ """Inserts `text` at current cursor position. If not present in `text`, any expected separator is automatically
+ inserted at the correct position.
+
+ Args:
+ text: The text to be inserted.
+
+ Returns:
+ A tuple in the form `(value, cursor_position)` with the new control value and current cursor position if
+ `text` matches the template, None otherwise.
+ """
+ return self.insert(text, self.input.cursor_position)
+
+ def move_cursor(self, delta: int) -> int:
"""Moves the cursor position by `delta` characters, skipping separators if
running over them.
Args:
delta: The number of characters to move; positive moves right, negative
moves left.
+
+ Returns:
+ The new cursor position.
"""
cursor_position = self.input.cursor_position
if delta < 0 and all(
@@ -271,7 +332,8 @@ def move_cursor(self, delta: int) -> None:
for char_definition in self.template[:cursor_position]
]
):
- return
+ return cursor_position
+
cursor_position += delta
while (
(cursor_position >= 0)
@@ -279,7 +341,8 @@ def move_cursor(self, delta: int) -> None:
and (_CharFlags.SEPARATOR in self.template[cursor_position].flags)
):
cursor_position += delta
- self.input.cursor_position = cursor_position
+
+ return cursor_position
def delete_at_position(self, position: int | None = None) -> None:
"""Deletes character at `position`.
@@ -474,6 +537,7 @@ def __init__(
which determine when to do input validation. The default is to do
validation for all messages.
valid_empty: Empty values are valid.
+ select_on_focus: Whether to select all text on focus.
name: Optional name for the masked input widget.
id: Optional ID for the widget.
classes: Optional initial classes for the widget.
@@ -572,12 +636,19 @@ def render_line(self, y: int) -> Strip:
if char == " ":
result.stylize(style, index, index + 1)
- if self._cursor_visible and self.has_focus:
- if self.cursor_at_end:
- result.pad_right(1)
- cursor_style = self.get_component_rich_style("input--cursor")
- cursor = self.cursor_position
- result.stylize(cursor_style, cursor, cursor + 1)
+ if self.has_focus:
+ if not self.selection.is_empty:
+ start, end = self.selection
+ start, end = sorted((start, end))
+ selection_style = self.get_component_rich_style("input--selection")
+ result.stylize_before(selection_style, start, end)
+
+ if self._cursor_visible:
+ cursor_style = self.get_component_rich_style("input--cursor")
+ cursor = self.cursor_position
+ if self.cursor_at_end:
+ result.pad_right(1)
+ result.stylize(cursor_style, cursor, cursor + 1)
segments = list(result.render(self.app.console))
line_length = Segment.get_line_length(segments)
@@ -598,7 +669,8 @@ async def _on_click(self, event: events.Click) -> None:
"""Ensure clicking on value does not leave cursor on a separator."""
await super()._on_click(event)
if self._template.at_separator():
- self._template.move_cursor(1)
+ cursor_position = self._template.move_cursor(1)
+ self.cursor_position = cursor_position
def insert_text_at_cursor(self, text: str) -> None:
"""Insert new text at the cursor, move the cursor to the end of the new text.
@@ -613,41 +685,111 @@ def insert_text_at_cursor(self, text: str) -> None:
else:
self.restricted()
+ def replace(self, text: str, start: int, end: int) -> None:
+ """Replace the text between the start and end locations with the given text.
+
+ Args:
+ text: Text to replace the existing text with.
+ start: Start index to replace (inclusive).
+ end: End index to replace (inclusive).
+ """
+ new_value = self._template.replace(text, start, end)
+ if new_value is not None:
+ self.value, self.cursor_position = new_value
+ else:
+ self.restricted()
+
def clear(self) -> None:
"""Clear the masked input."""
self.value, self.cursor_position = self._template.insert_separators("", 0)
- def action_cursor_left(self) -> None:
- """Move the cursor one position to the left; separators are skipped."""
- self._template.move_cursor(-1)
+ def action_cursor_left(self, select: bool = False) -> None:
+ """Move the cursor one position to the left; separators are skipped.
+
+ Args:
+ select: If `True`, select the text to the left of the cursor.
+ """
+ start, end = self.selection
+ cursor_position = self._template.move_cursor(-1)
+ if select:
+ self.selection = Selection(start, cursor_position)
+ else:
+ if self.selection.is_empty:
+ self.cursor_position = cursor_position
+ else:
+ self.cursor_position = min(start, end)
+
+ def action_cursor_right(self, select: bool = False) -> None:
+ """Move the cursor one position to the right; separators are skipped.
+
+ Args:
+ select: If `True`, select the text to the right of the cursor.
+ """
+ start, end = self.selection
+ cursor_position = self._template.move_cursor(1)
+ if select:
+ self.selection = Selection(start, cursor_position)
+ else:
+ if self.selection.is_empty:
+ self.cursor_position = cursor_position
+ else:
+ self.cursor_position = max(start, end)
- def action_cursor_right(self) -> None:
- """Move the cursor one position to the right; separators are skipped."""
- self._template.move_cursor(1)
+ def action_home(self, select: bool = False) -> None:
+ """Move the cursor to the start of the input.
- def action_home(self) -> None:
- """Move the cursor to the start of the input."""
- self._template.move_cursor(-len(self.template))
+ Args:
+ select: If `True`, select the text between the old and new cursor positions.
+ """
+ cursor_position = self._template.move_cursor(-len(self.template))
+ if select:
+ self.selection = Selection(self.cursor_position, cursor_position)
+ else:
+ self.cursor_position = cursor_position
- def action_cursor_left_word(self) -> None:
+ def action_cursor_left_word(self, select: bool = False) -> None:
"""Move the cursor left next to the previous separator. If no previous
- separator is found, moves the cursor to the start of the input."""
+ separator is found, moves the cursor to the start of the input.
+
+ Args:
+ select: If `True`, select the text between the old and new cursor positions.
+ """
if self._template.at_separator(self.cursor_position - 1):
- position = self._template.prev_separator_position(self.cursor_position - 1)
+ separator_position = self._template.prev_separator_position(
+ self.cursor_position - 1
+ )
else:
- position = self._template.prev_separator_position()
- if position:
- position += 1
- self.cursor_position = position or 0
+ separator_position = self._template.prev_separator_position()
- def action_cursor_right_word(self) -> None:
+ if separator_position is None:
+ cursor_position = 0
+ else:
+ cursor_position = separator_position + 1
+
+ if select:
+ start, _ = self.selection
+ self.selection = Selection(start, cursor_position)
+ else:
+ self.cursor_position = cursor_position
+
+ def action_cursor_right_word(self, select: bool = False) -> None:
"""Move the cursor right next to the next separator. If no next
- separator is found, moves the cursor to the end of the input."""
- position = self._template.next_separator_position()
- if position is None:
- self.cursor_position = len(self._template.mask)
+ separator is found, moves the cursor to the end of the input.
+
+ Args:
+ select: If `True`, select the text between the old and new cursor positions.
+ """
+ separator_position = self._template.next_separator_position()
+ if separator_position is None:
+ cursor_position = len(self._template.mask)
+ else:
+ cursor_position = separator_position + 1
+
+ if select:
+ start, _ = self.selection
+ self.selection = Selection(start, cursor_position)
else:
- self.cursor_position = position + 1
+ self.cursor_position = cursor_position
def action_delete_right(self) -> None:
"""Delete one character at the current cursor position."""
@@ -671,7 +813,8 @@ def action_delete_left(self) -> None:
if self.cursor_position <= 0:
# Cursor at the start, so nothing to delete
return
- self._template.move_cursor(-1)
+ cursor_position = self._template.move_cursor(-1)
+ self.cursor_position = cursor_position
self._template.delete_at_position()
def action_delete_left_word(self) -> None:
diff --git a/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg b/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg
new file mode 100644
index 0000000000..4b609be24a
--- /dev/null
+++ b/tests/snapshot_tests/__snapshots__/test_snapshots/test_masked_input_highlights_selection.svg
@@ -0,0 +1,153 @@
+
diff --git a/tests/snapshot_tests/test_snapshots.py b/tests/snapshot_tests/test_snapshots.py
index 4eba817cdd..2f3273fa41 100644
--- a/tests/snapshot_tests/test_snapshots.py
+++ b/tests/snapshot_tests/test_snapshots.py
@@ -41,6 +41,7 @@
ListView,
Log,
Markdown,
+ MaskedInput,
OptionList,
Placeholder,
ProgressBar,
@@ -4800,3 +4801,22 @@ async def run_before(pilot: Pilot) -> None:
await pilot.press("ctrl+v")
assert snap_compare(TextAreaApp(), run_before=run_before)
+
+
+def test_masked_input_highlights_selection(snap_compare) -> None:
+ """Regression test for https://github.com/Textualize/textual/issues/5495
+
+ You should see a MaskedInput where the selection is highlighted.
+ """
+
+ class MaskedInputApp(App):
+ def compose(self) -> ComposeResult:
+ yield MaskedInput(
+ template="9999-9999-9999-9999;0",
+ value="123"
+ )
+
+ async def run_before(pilot):
+ pilot.app.query_one(MaskedInput).cursor_blink = False
+
+ assert snap_compare(MaskedInputApp(), run_before=run_before)
diff --git a/tests/test_masked_input.py b/tests/test_masked_input.py
index 72b4d9dc09..35a0fd11a1 100644
--- a/tests/test_masked_input.py
+++ b/tests/test_masked_input.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
from typing import Union
import pytest
@@ -10,16 +12,24 @@
InputEvent = Union[MaskedInput.Changed, MaskedInput.Submitted]
-class InputApp(App[None]):
- def __init__(self, template: str, placeholder: str = ""):
+class MaskedInputApp(App[None]):
+ def __init__(
+ self,
+ template: str,
+ value: str | None = None,
+ select_on_focus: bool = True,
+ ):
super().__init__()
self.messages: list[InputEvent] = []
self.template = template
- self.placeholder = placeholder
+ self.value = value
+ self.select_on_focus = select_on_focus
def compose(self) -> ComposeResult:
yield MaskedInput(
- template=self.template, placeholder=self.placeholder, select_on_focus=False
+ template=self.template,
+ value=self.value,
+ select_on_focus=self.select_on_focus,
)
@on(MaskedInput.Changed)
@@ -29,7 +39,10 @@ def on_changed_or_submitted(self, event: InputEvent) -> None:
async def test_missing_required():
- app = InputApp(">9999-99-99")
+ app = MaskedInputApp(
+ template=">9999-99-99",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
input.value = "2024-12"
@@ -48,7 +61,10 @@ async def test_missing_required():
async def test_valid_required():
- app = InputApp(">9999-99-99")
+ app = MaskedInputApp(
+ template=">9999-99-99",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
input.value = "2024-12-31"
@@ -59,7 +75,10 @@ async def test_valid_required():
async def test_missing_optional():
- app = InputApp(">9999-99-00")
+ app = MaskedInputApp(
+ template=">9999-99-00",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
input.value = "2024-12"
@@ -71,7 +90,10 @@ async def test_missing_optional():
async def test_editing():
serial = "ABCDE-FGHIJ-KLMNO-PQRST"
- app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_")
+ app = MaskedInputApp(
+ template=">NNNNN-NNNNN-NNNNN-NNNNN;_",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
await pilot.press("A", "B", "C", "D")
@@ -94,9 +116,75 @@ async def test_editing():
assert input.cursor_position == len(serial)
+async def test_overwrite_typing():
+ app = MaskedInputApp(
+ template="9999-9999-9999-9999;0",
+ select_on_focus=False,
+ )
+ async with app.run_test() as pilot:
+ input = app.query_one(MaskedInput)
+ input.value = "0000-99"
+ input.action_home()
+
+ await pilot.press("1", "2", "3")
+ assert input.cursor_position == 3
+ assert input.value == "1230-99"
+
+ await pilot.press("4")
+ assert input.cursor_position == 5
+ assert input.value == "1234-99"
+
+ await pilot.press("0", "0")
+ assert input.cursor_position == 7
+ assert input.value == "1234-00"
+
+ await pilot.press("7", "8")
+ assert input.cursor_position == 10
+ assert input.value == "1234-0078-"
+
+ await pilot.press("left", "left")
+ await pilot.press("backspace", "backspace")
+ assert input.cursor_position == 5
+ assert input.value == "1234- 78"
+
+ await pilot.press("5", "6")
+ assert input.cursor_position == 7
+ assert input.value == "1234-5678"
+
+
+async def test_insert_jump_to_next_separator():
+ app = MaskedInputApp(
+ template="9999-9999-9999-9999;0",
+ select_on_focus=False,
+ )
+ async with app.run_test() as pilot:
+ input = app.query_one(MaskedInput)
+
+ # If cursor is at the start, input should not jump to next separator
+ await pilot.press("-")
+ assert input.value == ""
+ assert input.cursor_position == 0
+
+ await pilot.press("1", "-")
+ assert input.value == "1 -"
+ assert input.cursor_position == 5
+
+ # If previous character is a separator, input should not jump to next separator
+ await pilot.press("-")
+ assert input.value == "1 -"
+ assert input.cursor_position == 5
+
+ await pilot.press("2", "-")
+ assert input.value == "1 -2 -"
+ assert input.cursor_position == 10
+
+
async def test_key_movement_actions():
serial = "ABCDE-FGHIJ-KLMNO-PQRST"
- app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_")
+ app = MaskedInputApp(
+ template=">NNNNN-NNNNN-NNNNN-NNNNN;_",
+ select_on_focus=False,
+ )
async with app.run_test():
input = app.query_one(MaskedInput)
input.value = serial
@@ -116,7 +204,10 @@ async def test_key_movement_actions():
async def test_key_modification_actions():
serial = "ABCDE-FGHIJ-KLMNO-PQRST"
- app = InputApp(">NNNNN-NNNNN-NNNNN-NNNNN;_")
+ app = MaskedInputApp(
+ template=">NNNNN-NNNNN-NNNNN-NNNNN;_",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
input.value = serial
@@ -153,7 +244,10 @@ async def test_key_modification_actions():
async def test_cursor_word_right_after_last_separator():
- app = InputApp(">NNN-NNN-NNN-NNNNN;_")
+ app = MaskedInputApp(
+ template=">NNN-NNN-NNN-NNNNN;_",
+ select_on_focus=False,
+ )
async with app.run_test():
input = app.query_one(MaskedInput)
input.value = "123-456-789-012"
@@ -163,7 +257,10 @@ async def test_cursor_word_right_after_last_separator():
async def test_case_conversion_meta_characters():
- app = InputApp("NN<-N!N>N")
+ app = MaskedInputApp(
+ template="NN<-N!N>N",
+ select_on_focus=False,
+ )
async with app.run_test() as pilot:
input = app.query_one(MaskedInput)
await pilot.press("a", "B", "C", "D", "e")
@@ -172,7 +269,10 @@ async def test_case_conversion_meta_characters():
async def test_case_conversion_override():
- app = InputApp(">-