move DEFINE_VAR min/max from src to arg (#6388)

new arg is (Variable, min as CONST, max as CONST)
This commit is contained in:
chenyu
2024-09-06 05:01:02 -04:00
committed by GitHub
parent f1bd2a5519
commit 9a9fea7b8c
7 changed files with 42 additions and 40 deletions

View File

@@ -131,17 +131,17 @@ class TestGraphRewrite(unittest.TestCase):
self.assertEqual(nout.src[1].arg, 3.0)
def test_consts_go_last(self):
a = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('a', 0, 1))
b = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('b', 0, 1))
c = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('c', 0, 1))
d = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('d', 0, 1))
a = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('a', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
b = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('b', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('c', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
d = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('d', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
for out in outs:
sink = graph_rewrite(out, constant_folder)
print(sink)
self.assertEqual(sink.op, UOps.ALU)
self.assertEqual(sink.src[1].op, UOps.CONST)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 3)
self.assertEqual(len([x for x in sink.sparents if x.op is UOps.CONST]), 1)
class TestUOpGraph(unittest.TestCase):
def test_add_constant_fold(self):
@@ -155,7 +155,7 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(out.arg, 3.0)
def test_where_same_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp(UOps.CONST, dtypes.int, (), 0), UOp(UOps.CONST, dtypes.int, (), 1)), arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -249,7 +249,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@@ -258,7 +258,7 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[0], acc)
@@ -268,19 +268,19 @@ class TestUOpGraph(unittest.TestCase):
for i in [4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
for i in [4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 0.0) for _ in range(i//2)) +
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=Variable(f'tmp{j}', 0.0, 1.0)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
tuple(UOp(UOps.DEFINE_VAR, dtypes.half, arg=(Variable(f'tmp{j}', 0.0, 1.0),)) for j in range(i//2)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
@@ -288,17 +288,17 @@ class TestUOpGraph(unittest.TestCase):
for i in [2, 4, 8]:
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
for i in [2, 4, 8]:
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable(f'tmp{i}', 0.0, 1.0))
var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable(f'tmp{i}', 0.0, 1.0),))
vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i),
tuple(UOp.const(dtypes.half, 1.0 if j == 0 else 0.0) for j in range(i)))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=Variable('acc', 0.0, 1.0))
acc = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i), arg=(Variable('acc', 0.0, 1.0),))
wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
uops = to_uops_list([wmma])
assert_equiv_uops(uops[-1], wmma)
@@ -324,13 +324,13 @@ class TestUOpGraph(unittest.TestCase):
self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)
def test_depth_2_const_fold(self):
v = UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)), arg=Variable('tmp', 0, 1))
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=(Variable('tmp', 0, 1), UOp.const(dtypes.int, 0), UOp.const(dtypes.int, 1)))
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
uops = to_uops_list([out])
self.assertEqual(len(uops), 5)
self.assertEqual(len(uops), 3)
out = uops[-1]
self.assertEqual(out.op, UOps.ALU)
self.assertEqual(out.arg, BinaryOps.ADD)
@@ -575,8 +575,8 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_load_dont_fold_different_gated(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.EXPAND, dtypes.float, tuple(load), ((0,4),))
sink = float4_rewrite(sink)
@@ -591,7 +591,7 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_store_fold_gate(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)
@@ -602,8 +602,8 @@ class TestLoadStoreFolder(unittest.TestCase):
def test_simple_store_dont_fold(self):
buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g1")
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg="g2")
gate = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g1", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
gate2 = UOp(UOps.DEFINE_VAR, dtypes.bool, arg=("g2", UOp.const(dtypes.bool, False), UOp.const(dtypes.bool, True)))
load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
sink = UOp(UOps.SINK, None, tuple(load))
sink = float4_rewrite(sink)

View File

@@ -30,7 +30,7 @@ def Variable(expr, nmin, nmax):
# TODO: fix DEFINE_VAR to not need this
class TempVar:
def __init__(self, x): self.expr = x
return UOp(UOps.DEFINE_VAR, dtypes.int, (UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)), TempVar(expr))
return UOp(UOps.DEFINE_VAR, dtypes.int, arg=(TempVar(expr), UOp.const(dtypes.int, nmin), UOp.const(dtypes.int, nmax)))
class Node:
@staticmethod
def sum(ops): return functools.reduce(lambda x,y: x+y, ops)