mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-12 07:35:16 -05:00
const fold left const operand for ADD and MUL (#4029)
* const fold left const operand for ADD and MUL
* NEG has a dtype issue, so the MUL-by-(-1) -> NEG fold is no longer applied
This commit is contained in:
@@ -132,14 +132,16 @@ class LazyBuffer:
|
||||
if op is UnaryOps.NEG: assert srcs[0].dtype != dtypes.bool, "UnaryOps.NEG does not accept dtype bool"
|
||||
|
||||
# const folding
|
||||
if op is BinaryOps.ADD and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
|
||||
if op is BinaryOps.SUB and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg == 0: return self
|
||||
if op is BinaryOps.MUL and in_srcs[0].is_unrealized_unpadded_const():
|
||||
if (val := in_srcs[0].base.arg) == 1: return self
|
||||
if val == -1: return self.e(UnaryOps.NEG)
|
||||
if val == 0: return self.const(0)
|
||||
if op is BinaryOps.DIV and dtypes.is_float(self.dtype) and in_srcs[0].is_unrealized_unpadded_const() and in_srcs[0].base.arg != 0:
|
||||
return self.e(BinaryOps.MUL, self.const(1 / in_srcs[0].base.arg))
|
||||
if op in BinaryOps: x, y = self, in_srcs[0]
|
||||
if op is BinaryOps.ADD:
|
||||
if y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
|
||||
if x.is_unrealized_unpadded_const() and x.base.arg == 0: return y
|
||||
if op is BinaryOps.SUB and y.is_unrealized_unpadded_const() and y.base.arg == 0: return x
|
||||
if op is BinaryOps.MUL:
|
||||
if x.is_unrealized_unpadded_const() and (val := x.base.arg) in (1, 0): return {1: y, 0: y.const(0)}[val]
|
||||
if y.is_unrealized_unpadded_const() and (val := y.base.arg) in (1, 0): return {1: x, 0: x.const(0)}[val]
|
||||
if op is BinaryOps.DIV and dtypes.is_float(x.dtype) and y.is_unrealized_unpadded_const() and y.base.arg != 0:
|
||||
return x.e(BinaryOps.MUL, x.const(1 / y.base.arg))
|
||||
|
||||
out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPEQ) else srcs[-1].dtype
|
||||
return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, arg, tuple(srcs))
|
||||
|
||||
Reference in New Issue
Block a user