diff --git a/src/okipy/lib.py b/src/okipy/lib.py index c8a74fd..6662647 100644 --- a/src/okipy/lib.py +++ b/src/okipy/lib.py @@ -6,6 +6,7 @@ from traceback import print_exception from .utils import CaptureOutput from .colors import red, green, bold, yellow +from .strategies import Sequential, Parallel @dataclass @@ -132,31 +133,33 @@ class Suite: return procedure return decorate - def run(self, filters: list[str] = []): + def run(self, filters: list[str] = [], parallel = False): """ Runs the test suite. Optionnally filter to run by their name : Only keep tests which names contains all filters. """ - if self.name is not None: - print(" ", green("Suite"), bold(self.name)) to_run = [*self.filter_tests(filters)] + strategy = Parallel() if parallel else Sequential() + if self.name is not None: print(" ", green("Suite"), bold(self.name)) print(yellow('Running'), bold(str(len(to_run))), yellow('/'), bold(str(len(self.tests))), yellow('tests')) print() - failures = list[TestFailure]() - for test in to_run: - failure = test.run() - if failure is None: - print(f" {green('Ok')} {bold(test.name)}") - else: - print(f" {red('Err')} {bold(test.name)}") - failures.append(failure) + failures = [fail for fail in strategy.run_all(to_run, Suite.run_one) if fail is not None] print() print("", yellow('Failed'), bold(str(len(failures))), yellow('/'), bold(str(len(to_run))), yellow('tests')) - for (index, failure) in enumerate(failures): - failure.print(index) + for (index, failure) in enumerate(failures): failure.print(index) return failures + @staticmethod + def run_one(test: Test): + failure = test.run() + if failure is None: + print(f" {green('Ok')} {bold(test.name)}") + else: + print(f" {red('Err')} {bold(test.name)}") + return failure + + def filter_tests(self, filters: list[str]): for test in self.tests: oki = True diff --git a/src/okipy/strategies.py b/src/okipy/strategies.py new file mode 100644 index 0000000..796389c --- /dev/null +++ b/src/okipy/strategies.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Callable, TypeVar +from multiprocessing import Pool + + +T = TypeVar("T") +O = TypeVar("O") +class RunStrategy: + """ + Some strategy to run an operation on a collection of items and return a collection of results from those operations. + """ + def run_all(self, items: list[T], oper: Callable[[T], O]) -> list[O]: + raise NotImplementedError("Abstract method must be overwriten") + + +@dataclass +class Parallel(RunStrategy): + procs: None | int = None + def run_all(self, items: list[T], oper: Callable[[T], O]) -> list[O]: + return Pool(self.procs).map(oper, items) + + +class Sequential(RunStrategy): + def run_all(self, items: list[T], oper: Callable[[T], O]) -> list[O]: + return [oper(item) for item in items]