expr: add CompExpr

This commit is contained in:
Shiz 2021-06-24 04:14:31 +02:00
parent c8919903cf
commit 3055156ef5
1 changed files with 39 additions and 3 deletions

View File

@ -116,14 +116,14 @@ class Expr(G[T]):
return CallExpr(self, args, kwargs)
for x in ('lt', 'le', 'eq', 'ne', 'ge', 'gt'):
locals()['__' + x + '__'] = functools.partialmethod(lambda self, x, other: CompExpr(getattr(operator, x), self, other), x)
locals()['__' + x.strip('_') + '__'] = functools.partialmethod(lambda self, x, other: CompExpr(getattr(operator, x), self, other), x)
for x in ('not_', 'truth', 'abs', 'index', 'inv', 'neg', 'pos'):
locals()['__' + x + '__'] = functools.partialmethod(lambda self, x: UnaryExpr(getattr(operator, x), self), x)
locals()['__' + x.strip('_') + '__'] = functools.partialmethod(lambda self, x: UnaryExpr(getattr(operator, x), self), x)
for x in (
'add', 'and_', 'floordiv', 'lshift', 'mod', 'mul', 'matmul', 'or_', 'pow', 'rshift', 'sub', 'truediv', 'xor',
'concat', 'contains', 'delitem', 'getitem', 'delitem', 'getitem', 'setitem',
):
locals()['__' + x + '__'] = functools.partialmethod(lambda self, x, other: BinExpr(getattr(operator, x), self, other), x)
locals()['__' + x.strip('_') + '__'] = functools.partialmethod(lambda self, x, other: BinExpr(getattr(operator, x), self, other), x)
del x
class AttrExpr(G[T], Expr[T]):
@ -278,6 +278,42 @@ class BinExpr(G[T], Expr[T]):
def __repr__(self) -> str:
return f'({self.__left!r} {symbols[self.__op]} {self.__right!r})'
class CompExpr(Expr[bool]):
def __init__(self, op: Callable[[Expr, Expr], bool], left: Expr, right: Expr) -> None:
self.__op = op
self.__left = left
self.__right = right
def _sx_get_(self) -> bool:
return self.__op(get(self.__left), get(self.__right))
def _sx_peek_(self) -> bool:
return self.__op(peek(self.__left), peek(self.__right))
def _sx_put_(self, value: bool) -> None:
if not isinstance(self.__left, Expr):
operand = self.__left
target = self.__right
elif not isinstance(self.__right, Expr):
operand = self.__right
target = self.__left
else:
raise NotImplementedError(f'{self.__class__.__name__} has two expression operands and is not invertible')
if self.__op == operator.eq and value:
value = operand
elif self.__op == operator.ne and not value:
value = operand
else:
raise NotImplementedError(f'{self.__class__.__name__} {symbols[self.__op]!r} is not invertible')
put(target, value)
def __str__(self) -> str:
return f'({self.__left} {symbols[self.__op]} {self.__right})'
def __repr__(self) -> str:
return f'({self.__left!r} {symbols[self.__op]} {self.__right!r})'
def get(expr: Any) -> Any:
if isinstance(expr, Expr):