improve error message

This commit is contained in:
JOLIMAITRE Matthieu 2024-05-23 02:27:56 +02:00
parent 6e35db16ba
commit a49b14c12c
5 changed files with 127 additions and 19 deletions

View file

@ -7,6 +7,7 @@ import sys
sys.path.append(f"{dirname(__file__)}/../src") sys.path.append(f"{dirname(__file__)}/../src")
from pyalibert import regex, Parser, just, Declare from pyalibert import regex, Parser, just, Declare
from pyalibert.result import ParseError
# ast modeling # ast modeling
@ -79,9 +80,12 @@ def repl():
print("> ", end="") print("> ", end="")
sys.stdout.flush() sys.stdout.flush()
for line in sys.stdin: for line in sys.stdin:
try:
result = parser.parse(line.strip()) result = parser.parse(line.strip())
print(result) print(result)
print(f"= {result.eval()}") print(f"= {result.eval()}")
except ParseError as err:
print(err)
print("> ", end="") print("> ", end="")
sys.stdout.flush() sys.stdout.flush()

View file

@ -1,7 +1,7 @@
from dataclasses import dataclass from dataclasses import dataclass
from typing import Union, Generic, TypeVar from typing import Generator, Union, Generic, TypeVar
from .utils import TypeInfo from .utils import best_by_key, last, split_lines_into_bounded
P = TypeVar("P") P = TypeVar("P")
@ -16,13 +16,38 @@ class Success(Generic[P]):
class Failure:
    """A failed parse attempt, optionally wrapping the failures that led to it.

    Attributes are declared as annotations only; instances are populated by
    ``__init__`` below.
    """

    # Position in the stream at which parsing failed.
    at_index: int
    # Labels of what would have been accepted, merged with every child's labels.
    expected: set[str]
    # Failures wrapped by this one (empty for a leaf failure).
    children: list["Failure"]

    def __init__(self, at_index: int, expected: set[str], *children: "Failure") -> None:
        """Build a failure at ``at_index``.

        Args:
            at_index: position in the stream at which parsing failed.
            expected: labels expected at that position. The set is copied, so
                the caller's object is never modified nor shared.
            *children: failures wrapped by this one; their ``expected`` labels
                are merged into this failure's ``expected``.
        """
        self.at_index = at_index
        # Work on a private copy: merging children into the argument itself
        # would silently alter a set owned by the caller (or by a child, when
        # called through ``parent()``).
        self.expected = set(expected).union(*(child.expected for child in children))
        self.children = list(children)

    def depth(self) -> int:
        """Return the length of the longest chain of children below this failure.

        A failure without children has depth 0.
        """
        return max((child.depth() + 1 for child in self.children), default=0)

    def deepest_descendency(self) -> "list[Failure]":
        """Return the chain of descendants reached by always following the deepest child.

        The chain starts with the chosen child of ``self`` (``self`` is not
        included) and stops at a failure without children, so it is empty for
        a leaf failure. The child is selected with ``best_by_key`` on
        ``depth()``; tie-breaking is whatever that helper does.
        """
        chain: "list[Failure]" = []
        current = self
        while True:
            deepest = best_by_key(current.children, lambda f: f.depth())
            if deepest is None:
                return chain
            chain.append(deepest)
            current = deepest

    def parent(self) -> "Failure":
        """Return a new failure at the same index that wraps ``self`` as its only child."""
        return Failure(self.at_index, self.expected, self)
P = TypeVar("P") P = TypeVar("P")
Result = Union[Failure, Success[P]] Result = Union[Failure, Success[P]]
@dataclass
class ParseError(BaseException): class ParseError(BaseException):
failure: Failure failure: Failure
stream: str stream: str
@ -30,11 +55,36 @@ class ParseError(BaseException):
def __init__(self, failure: Failure, stream: str) -> None: def __init__(self, failure: Failure, stream: str) -> None:
self.failure = failure self.failure = failure
self.stream = stream self.stream = stream
message = self.format_error_secton()
fail_section = stream[failure.at_index:80]
message = f"""
Parsing failed at position {self.failure.at_index} of stream :
${fail_section}
Expected one of: {failure.expected}
""".strip()
super().__init__(message) super().__init__(message)
def format_error_secton(self):
    """Render the excerpt of the stream surrounding the failure.

    The reported failure is the last element of
    ``self.failure.deepest_descendency()``, or ``self.failure`` itself when
    that chain yields nothing. The result starts with a newline and holds, in
    order: the previous line (if any), the failing line, an indicator line
    pointing at the failure position and listing what was expected, and the
    following line (if any). Line numbers shown are the zero-based indices
    produced by enumerating the split lines.
    """
    culprit = last(self.failure.deepest_descendency()) or self.failure
    position = culprit.at_index
    bounded_lines = list(split_lines_into_bounded(self.stream))

    # Index of the first line whose inclusive [start, end] range holds the position.
    current = next(
        number
        for number, ((start, end), _text) in enumerate(bounded_lines)
        if start <= position <= end
    )
    (line_start, _line_end), current_text = bounded_lines[current]
    prefix_len, rendered_current = format_error_section_line(current, current_text)

    # A leading empty piece produces the initial newline once joined.
    pieces = [""]
    if current > 0:
        before = current - 1
        pieces.append(format_error_section_line(before, bounded_lines[before][1])[1])
    pieces.append(rendered_current)
    pieces.append(
        format_error_section_indicator(position, line_start, prefix_len, culprit.expected)
    )
    if current < len(bounded_lines) - 1:
        after = current + 1
        pieces.append(format_error_section_line(after, bounded_lines[after][1])[1])
    else:
        # Without a following line the indicator is still newline-terminated.
        pieces.append("")
    return "\n".join(pieces)
def format_error_section_line(index: int, line: str):
    """Prefix ``line`` with a gutter showing ``index``.

    The gutter is ``index`` right-aligned in 4 columns followed by ``" |"``;
    ``index`` is printed as given (no offset is applied).

    Returns:
        A ``(gutter_length, formatted_line)`` tuple.
    """
    gutter = f"{index:>4} |"
    return (len(gutter), gutter + line)
def format_error_section_indicator(failure_pos: int, failure_line_start_pos: int, formatted_line_prefix_len: int, expected: set[str]):
    """Build the indicator line placed under the failing line.

    Args:
        failure_pos: position of the failure in the whole stream.
        failure_line_start_pos: position in the stream where the failing line starts.
        formatted_line_prefix_len: width of the gutter prepended to the failing line.
        expected: labels to list after the caret.

    Returns:
        A string made of padding spaces, a ``^`` under the failing column and
        the expected labels joined by ``", "``.
    """
    column = formatted_line_prefix_len + (failure_pos - failure_line_start_pos)
    # Sets of strings iterate in an order that varies between interpreter
    # runs; sorting keeps the message stable and reproducible.
    expects = ", ".join(sorted(expected))
    return f"{' ' * column}^\\_ expected : {expects}"

View file

@ -15,10 +15,10 @@ class AndTransform(Transformer[tuple[L, R]]):
def parse(self, stream: str, at_index: int) -> Result[tuple[L, R]]: def parse(self, stream: str, at_index: int) -> Result[tuple[L, R]]:
result_left = self.left.parse(stream, at_index) result_left = self.left.parse(stream, at_index)
if isinstance(result_left, Failure): if isinstance(result_left, Failure):
return result_left return result_left.parent()
result_right = self.right.parse(stream, result_left.next_index) result_right = self.right.parse(stream, result_left.next_index)
if isinstance(result_right, Failure): if isinstance(result_right, Failure):
return result_right return result_right.parent()
return Success((result_left.value, result_right.value), result_right.next_index) return Success((result_left.value, result_right.value), result_right.next_index)
@ -42,4 +42,4 @@ def test_and_shifted(ctx):
def test_and_none(ctx): def test_and_none(ctx):
parser = AndTransform(JustTransform("ar"), JustTransform("bre")) parser = AndTransform(JustTransform("ar"), JustTransform("bre"))
input = "......" input = "......"
assert parser.parse(input, 2) == Failure(2, set(["ar"])) assert parser.parse(input, 2) == Failure(2, set(["ar"]), Failure(2, set(["ar"])))

View file

@ -19,7 +19,7 @@ class OrTransform(Transformer[Union[L, R]]):
result_right = self.right.parse(stream, at_index) result_right = self.right.parse(stream, at_index)
if isinstance(result_right, Success): if isinstance(result_right, Success):
return Success(result_right.value, result_right.next_index) return Success(result_right.value, result_right.next_index)
return Failure(at_index, result_left.expected.union(result_right.expected)) return Failure(at_index, set(), result_left, result_right)
from okipy.lib import test from okipy.lib import test
@ -48,4 +48,4 @@ def test_or_shifted(ctx):
def test_or_fail(ctx): def test_or_fail(ctx):
parser = OrTransform(JustTransform("ar"), JustTransform("bre")) parser = OrTransform(JustTransform("ar"), JustTransform("bre"))
input = "......" input = "......"
assert parser.parse(input, 2) == Failure(2, set(["ar", "bre"])) assert parser.parse(input, 2) == Failure(2, set(["ar", "bre"]), Failure(2, set(["ar"])), Failure(2, set(["bre"])))

View file

@ -1,5 +1,59 @@
from typing import Generic, TypeVar from typing import Callable, Collection, Generator, Generic, Iterable, TypeVar, Union
from okipy.lib import test
T = TypeVar("T") T = TypeVar("T")
class TypeInfo(Generic[T]): class TypeInfo(Generic[T]):
pass pass
T = TypeVar("T")
K = TypeVar("K", int, float)


def best_by_key(collection: Iterable[T], extract: Callable[[T], K], max_: bool = True) -> Union[T, None]:
    """Return the item of ``collection`` whose extracted key is the best.

    Args:
        collection: items to choose from; it is iterated once.
        extract: function computing the comparison key of an item.
        max_: when true (the default) the greatest key wins, otherwise the
            smallest one does.

    Returns:
        The winning item, or ``None`` when ``collection`` is empty. When
        several items share the best key, the first one encountered is kept.
    """
    # The builtins already keep the first of several equal candidates and
    # evaluate the key once per item.
    pick = max if max_ else min
    return pick(collection, key=extract, default=None)
def split_lines_into_bounded(text: str) -> Generator[tuple[tuple[int, int], str], None, None]:
    """Yield every line of ``text`` together with its bounds in ``text``.

    Lines are separated by ``"\\n"`` only. Each yielded value is
    ``((start, end), line)`` where ``start`` is the index of the line's first
    character and ``end`` is ``start + len(line)``, i.e. the index of the
    newline terminating it (or ``len(text)`` for the last line). An empty
    ``text`` yields a single empty line bounded by ``(0, 0)``.
    """
    start = 0
    # A single split avoids re-copying the remaining text for every line.
    for line in text.split("\n"):
        end = start + len(line)
        yield ((start, end), line)
        # Skip over the newline separating this line from the next one.
        start = end + 1
@test()
def test_split_lines_bounded(ctx):
    """Check the bounds produced for empty, inner and trailing lines."""
    text = "\nhello\n\nworld\n"
    expected = [
        ((0, 0), ''),
        ((1, 6), 'hello'),
        ((7, 7), ''),
        ((8, 13), 'world'),
        ((14, 14), ''),
    ]
    assert list(split_lines_into_bounded(text)) == expected
T = TypeVar("T")


def last(collection: Collection[T]):
    """Return the final item produced by iterating ``collection``.

    Returns ``None`` when the collection is empty.
    """
    final = None
    # The loop variable keeps the most recent item once iteration ends.
    for final in collection:
        pass
    return final