mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
move DEFINE_VAR min/max from src to arg (#6388)
new arg is (Variable, min as CONST, max as CONST)
This commit is contained in:
@@ -131,17 +131,17 @@ class TestGraphRewrite(unittest.TestCase):
|
||||
self.assertEqual(nout.src[1].arg, 3.0)
|
||||
|
||||
def test_consts_go_last(self):
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
|
||||
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
|
||||
for out in outs:
|
||||
sink = graph_rewrite(out, constant_folder)
|
||||
print(sink)
|
||||
self.assertEqual(sink.op, UOps.ALU)
|
||||
self.assertEqual(sink.src[1].op, UOps.CONST)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
|
||||
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
|
||||
|
||||
class TestUOpGraph(unittest.TestCase):
|
||||
def test_add_constant_fold(self):
|
||||
@@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
def test_where_same_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp(UOps.CONST, dtypes.int, (), 0), UOp(UOps.CONST, dtypes.int, (), 1)), arg=Variable('tmp', 0, 1))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
@@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
@@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[0], acc)
|
||||
@@ -268,19 +268,19 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
||||
for i in [4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
@@ -288,17 +288,17 @@ class TestUOpGraph(unittest.TestCase):
|
||||
for i in [2, 4, 8]:
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
|
||||
for i in [2, 4, 8]:
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
|
||||
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
|
||||
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
|
||||
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
|
||||
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
|
||||
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
|
||||
uops = to_uops_list([wmma])
|
||||
assert_equiv_uops(uops[-1], wmma)
|
||||
@@ -324,13 +324,13 @@ class TestUOpGraph(unittest.TestCase):
|
||||
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
uops = to_uops_list([out])
|
||||
self.assertEqual(len(uops), 5)
|
||||
self.assertEqual(len(uops), 3)
|
||||
out = uops[-1]
|
||||
self.assertEqual(out.op, UOps.ALU)
|
||||
self.assertEqual(out.arg, BinaryOps.ADD)
|
||||
@@ -575,8 +575,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_load_dont_fold_different_gated(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
|
||||
sink = float4_rewrite(sink)
|
||||
@@ -591,7 +591,7 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_store_fold_gate(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, None, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
@@ -602,8 +602,8 @@ class TestLoadStoreFolder(unittest.TestCase):
|
||||
|
||||
def test_simple_store_dont_fold(self):
|
||||
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
|
||||
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
|
||||
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
|
||||
sink = UOp(UOps.SINK, None, tuple(load))
|
||||
sink = float4_rewrite(sink)
|
||||
|
||||
@@ -30,7 +30,7 @@ def Variable(expr, nmin, nmax):
|
||||
# TODO: fix DEFINE_VAR to not need this
|
||||
class TempVar:
|
||||
def __init__(self, x): self.expr = x
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
|
||||
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
|
||||
class Node:
|
||||
@staticmethod
|
||||
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)
|
||||
|
||||
Reference in New Issue
Block a user