mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
make uops.add private (#4950)
* make uops.add private * modernize all tests
This commit is contained in:
@@ -15,8 +15,8 @@ def uops_to_rdna(function_name:str, uops:UOpGraph) -> str:
|
||||
u.vin = tuple(n if x == o else x for x in u.vin)
|
||||
# pointer indexing
|
||||
if u.uop in {UOps.LOAD, UOps.STORE} and u.vin[0].dtype.itemsize > 1:
|
||||
val = uops.add(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
|
||||
ptr = uops.add(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
|
||||
val = UOp(UOps.CONST, dtypes.int, tuple(), arg=u.vin[0].dtype.itemsize, insert_before=uops.uops.index(u))
|
||||
ptr = UOp(UOps.ALU, dtypes.int, (u.vin[1], val), arg=BinaryOps.MUL, insert_before=uops.uops.index(u))
|
||||
u.vin = (u.vin[0], ptr) + u.vin[2:]
|
||||
#uops.print()
|
||||
|
||||
|
||||
@@ -141,7 +141,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_folds(self):
|
||||
uops = UOpGraph()
|
||||
uops.add(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
|
||||
UOp(UOps.CONST, dtypes.float, arg=2.0, simplify=False)
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.float),
|
||||
lambda x: UOp(UOps.CAST, dtypes.int, (UOp(UOps.ALU, x.dtype, (x, x), BinaryOps.ADD),)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
@@ -153,7 +153,7 @@ class TestPatternMatcher(unittest.TestCase):
|
||||
@unittest.skip("no longer supported")
|
||||
def test_rewrite_graph_adds(self):
|
||||
uops = UOpGraph()
|
||||
uops.add(UOps.CONST, dtypes.int, arg=2, simplify=False)
|
||||
UOp(UOps.CONST, dtypes.int, arg=2, simplify=False)
|
||||
matcher = PatternMatcher([(UPat(UOps.CONST, name="x", dtype=dtypes.int),
|
||||
lambda x: UOp(UOps.STORE, x.dtype, (UOp(UOps.DEFINE_GLOBAL, x.dtype, tuple(), None), x)))])
|
||||
matcher.rewrite_graph(uops)
|
||||
|
||||
@@ -2,75 +2,69 @@ import unittest
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.dtype import PtrDType
|
||||
from tinygrad.ops import BinaryOps, TernaryOps, UnaryOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
|
||||
|
||||
class TestUOpGraph(unittest.TestCase):
|
||||
def test_add_constant_fold(self):
|
||||
g = UOpGraph()
|
||||
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = g.add(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = g.add(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (c1, c2), BinaryOps.ADD)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
self.assertEqual(out.arg, 3.0)
|
||||
|
||||
def test_where_same_fold(self):
|
||||
g = UOpGraph()
|
||||
v = g.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
c0 = g.add(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = g.add(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
|
||||
out = g.add(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
c0 = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (vc, c1, c1), TernaryOps.WHERE)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
self.assertEqual(out.arg, 1.0)
|
||||
|
||||
def test_where_const_fold(self):
|
||||
g = UOpGraph()
|
||||
bf = g.add(UOps.CONST, dtypes.bool, arg=False)
|
||||
c1 = g.add(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = g.add(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = g.add(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
|
||||
c2 = UOp(UOps.CONST, dtypes.float, arg=2.0)
|
||||
out = UOp(UOps.ALU, dtypes.float, (bf, c1, c2), TernaryOps.WHERE)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
self.assertEqual(out.arg, 2.0)
|
||||
|
||||
def test_const_cast(self):
|
||||
g = UOpGraph()
|
||||
bf = g.add(UOps.CONST, dtypes.bool, arg=False)
|
||||
out = g.add(UOps.CAST, dtypes.int, (bf,))
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
bf = UOp(UOps.CONST, dtypes.bool, arg=False)
|
||||
out = UOp(UOps.CAST, dtypes.int, (bf,))
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len(g.uops), 1)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.CONST)
|
||||
self.assertEqual(out.arg, 0)
|
||||
|
||||
def test_cast_vectorized_fold(self):
|
||||
g = UOpGraph()
|
||||
d0 = g.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
|
||||
idx = g.add(UOps.CONST, dtypes.int, arg=0)
|
||||
ld = g.add(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
cast = g.add(UOps.CAST, dtypes.float.vec(2), (ld,))
|
||||
x = g.add(UOps.GEP, dtypes.float, (cast, ), arg=0)
|
||||
alu = g.add(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = g.add(UOps.STORE, dtypes.float, (d0, idx, alu))
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
d0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), arg=(0, True))
|
||||
idx = UOp(UOps.CONST, dtypes.int, arg=0)
|
||||
ld = UOp(UOps.LOAD, dtypes.float.vec(2), (d0, idx))
|
||||
cast = UOp(UOps.CAST, dtypes.float.vec(2), (ld,))
|
||||
x = UOp(UOps.GEP, dtypes.float, (cast, ), arg=0)
|
||||
alu = UOp(UOps.ALU, dtypes.float, (x, ), UnaryOps.SQRT)
|
||||
out = UOp(UOps.STORE, dtypes.float, (d0, idx, alu))
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len([x for x in g.uops if x.uop is UOps.CAST]), 0)
|
||||
|
||||
def test_depth_2_const_fold(self):
|
||||
g = UOpGraph()
|
||||
v = g.add(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
c2 = g.add(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = g.add(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = g.add(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = g.add(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
g.add(UOps.SINK, None, (out,))
|
||||
v = UOp(UOps.DEFINE_VAR, dtypes.int, arg=Variable('tmp', 0, 1))
|
||||
c2 = UOp(UOps.CONST, dtypes.int, arg=2)
|
||||
c4 = UOp(UOps.CONST, dtypes.int, arg=4)
|
||||
vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
|
||||
out = UOp(UOps.ALU, dtypes.int, (vc, c4), BinaryOps.ADD)
|
||||
g = UOpGraph([UOp(UOps.SINK, None, (out,))])
|
||||
self.assertEqual(len(g.uops), 3)
|
||||
out = g.uops[-1]
|
||||
self.assertEqual(out.uop, UOps.ALU)
|
||||
|
||||
@@ -14,8 +14,7 @@ from tinygrad.codegen.uops import UOpGraph
|
||||
from test.helpers import is_dtype_supported
|
||||
|
||||
def _uops_to_prg(uops_list, print=False):
|
||||
uops = UOpGraph()
|
||||
for l in uops_list: uops.add(l.uop, l.dtype, l.vin, l.arg)
|
||||
uops = UOpGraph(uops_list)
|
||||
src = Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
if print: uops.print()
|
||||
has_local = Device[Device.DEFAULT].renderer.has_local
|
||||
@@ -32,10 +31,10 @@ def _test_single_value(vals, op, dts):
|
||||
buf_loads = [uop(uops, UOps.DEFINE_GLOBAL, PtrDType(dtype), (), (i+1, False)) for i,dtype in enumerate(dts)]
|
||||
loads = (uop(uops, UOps.LOAD, dtype, [buf_loads[i], uop(uops, UOps.CONST, dtypes.int32, (), 0)]) for i,dtype in enumerate(dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
buf2 = [Buffer(Device.DEFAULT, 1, dtype).allocate().copyin(np.array([a], dtype=dtype.np).data) for a,dtype in zip(vals, dts)]
|
||||
prg = _uops_to_prg(uops)
|
||||
prg = _uops_to_prg([out])
|
||||
prg.exec([buf]+buf2)
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
@@ -47,9 +46,9 @@ def _test_single_value_const(vals, op, dts):
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
|
||||
loads = (uop(uops, UOps.CONST, dtype, [], a) for a,dtype in zip(vals, dts))
|
||||
alu = uop(uops, UOps.ALU, output_dtype, loads, op)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), alu))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
prg = _uops_to_prg(uops)
|
||||
prg = _uops_to_prg([out])
|
||||
prg.exec([buf])
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
@@ -59,9 +58,9 @@ def _test_uops_result(output_dtype, uops, res):
|
||||
# uops = []
|
||||
buf_store = uop(uops, UOps.DEFINE_GLOBAL, PtrDType(output_dtype), (), (0, True))
|
||||
# res = output_fn(uops)
|
||||
uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||
out = uop(uops, UOps.STORE, None, (buf_store, uop(uops, UOps.CONST, dtypes.int32, (), 0), res))
|
||||
buf = Buffer(Device.DEFAULT, 1, output_dtype).allocate()
|
||||
prg = _uops_to_prg(uops, print=True)
|
||||
prg = _uops_to_prg([out], print=True)
|
||||
prg.exec([buf])
|
||||
ret = np.empty(1, output_dtype.np)
|
||||
buf.copyout(ret.data)
|
||||
@@ -235,14 +234,13 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
@unittest.skip("not yet implemented")
|
||||
def test_wrap_store_parents(self):
|
||||
# wraps all store parents in the valid branch
|
||||
uops = UOpGraph()
|
||||
gmem = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
gmem = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
value = uops.add(UOps.CONST, dtypes.float, (), 42.0)
|
||||
value = UOp(UOps.CONST, dtypes.float, (), 42.0)
|
||||
|
||||
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
|
||||
uops.add(UOps.STORE, None, (gmem, idx, value, gate))
|
||||
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
|
||||
uops = UOpGraph([UOp(UOps.STORE, None, (gmem, idx, value, gate))])
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.uop is UOps.IF)
|
||||
endif = next(u for u in uops if u.uop is UOps.ENDIF)
|
||||
@@ -253,17 +251,17 @@ class TestGatedStoreRewrite(unittest.TestCase):
|
||||
@unittest.skip("not yet implemented")
|
||||
def test_wrap_some_parents(self):
|
||||
# some parents are used outside the branch
|
||||
uops = UOpGraph()
|
||||
gmem0 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gmem1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
|
||||
gidx0 = uops.add(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
gmem0 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (0, True))
|
||||
gmem1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), (1, True))
|
||||
gidx0 = UOp(UOps.SPECIAL, dtypes.int, (), (0, 'gidx0', 4))
|
||||
idx = gidx0 * UOp.const(dtypes.int, 2)
|
||||
value0 = uops.add(UOps.CONST, dtypes.float, (), 42.0)
|
||||
value1 = uops.add(UOps.CONST, dtypes.float, (), 43.0)
|
||||
value0 = UOp(UOps.CONST, dtypes.float, (), 42.0)
|
||||
value1 = UOp(UOps.CONST, dtypes.float, (), 43.0)
|
||||
|
||||
gate = uops.add(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
|
||||
uops.add(UOps.STORE, None, (gmem0, idx, value0, gate))
|
||||
uops.add(UOps.STORE, None, (gmem1, idx, value1))
|
||||
gate = UOp(UOps.ALU, dtypes.bool, (gidx0, UOp.const(dtypes.int, 1)), arg=BinaryOps.CMPLT)
|
||||
outs = [UOp(UOps.STORE, None, (gmem0, idx, value0, gate))]
|
||||
outs.append(UOp(UOps.STORE, None, (gmem1, idx, value1)))
|
||||
uops = UOpGraph(outs)
|
||||
if DEBUG >= 4: print(Device[Device.DEFAULT].renderer.render("test", uops))
|
||||
if_uop = next(u for u in uops if u.uop is UOps.IF)
|
||||
endif = next(u for u in uops if u.uop is UOps.ENDIF)
|
||||
@@ -297,27 +295,25 @@ class TestLocalAccess(unittest.TestCase):
|
||||
@unittest.skipUnless(Device.DEFAULT in {"CUDA"} and getenv("PTX"), "This only tests assembly backends")
|
||||
class TestAssembly(unittest.TestCase):
|
||||
def test_bitshift_left(self):
|
||||
uops = UOpGraph()
|
||||
g1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
|
||||
c1 = uops.add(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = uops.add(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = uops.add(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||
uops.add(UOps.SINK, None, (a1,a2))
|
||||
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
|
||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.MUL)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.MUL)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.MUL)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHL)
|
||||
|
||||
def test_bitshift_right(self):
|
||||
uops = UOpGraph()
|
||||
g1 = uops.add(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
|
||||
c1 = uops.add(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = uops.add(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = uops.add(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = uops.add(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = uops.add(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
uops.add(UOps.SINK, None, (a1,a2))
|
||||
g1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), (0, True))
|
||||
c1 = UOp(UOps.CONST, dtypes.int, (), 2)
|
||||
c2 = UOp(UOps.CONST, dtypes.int, (), 3)
|
||||
l1 = UOp(UOps.LOAD, dtypes.int, (g1, c1))
|
||||
a1 = UOp(UOps.ALU, dtypes.int, (l1, c1), BinaryOps.IDIV)
|
||||
a2 = UOp(UOps.ALU, dtypes.int, (l1, c2), BinaryOps.IDIV)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (a1,a2))])
|
||||
Device[Device.DEFAULT].renderer.render("test", uops)
|
||||
self.assertEqual(uops.uops[-1].arg, BinaryOps.IDIV)
|
||||
self.assertEqual(uops.uops[-2].arg, BinaryOps.SHR)
|
||||
|
||||
@@ -2,7 +2,7 @@ import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps
|
||||
from tinygrad.codegen.uops import UOpGraph, UOps, UOp
|
||||
from tinygrad.ops import BinaryOps, TernaryOps
|
||||
from tinygrad.dtype import dtypes
|
||||
|
||||
@@ -55,26 +55,24 @@ class TestUOpsStats(unittest.TestCase):
|
||||
|
||||
#MULACC should have the same stats as MUL + ADD
|
||||
def test_mulacc(self):
|
||||
uops = UOpGraph()
|
||||
globl = uops.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = uops.add(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = uops.add(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = uops.add(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = uops.add(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = uops.add(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = uops.add(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
||||
u5 = uops.add(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||
uops.add(UOps.SINK, None, (u5,))
|
||||
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2), BinaryOps.MUL)
|
||||
u5 = UOp(UOps.ALU, dtypes.int, (u4,u3), BinaryOps.ADD)
|
||||
uops = UOpGraph([UOp(UOps.SINK, None, (u5,))])
|
||||
|
||||
uops_fma = UOpGraph()
|
||||
globl = uops_fma.add(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = uops_fma.add(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = uops_fma.add(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = uops_fma.add(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
||||
uops_fma.add(UOps.SINK, None, (u4,))
|
||||
globl = UOp(UOps.DEFINE_GLOBAL, dtypes.int, tuple())
|
||||
o1 = UOp(UOps.CONST, dtypes.int, tuple(), 1)
|
||||
o2 = UOp(UOps.CONST, dtypes.int, tuple(), 2)
|
||||
u1 = UOp(UOps.LOAD, dtypes.int, (globl, o1))
|
||||
u2 = UOp(UOps.LOAD, dtypes.int, (globl, o2))
|
||||
u3 = UOp(UOps.CONST, dtypes.int, tuple(), 3)
|
||||
u4 = UOp(UOps.ALU, dtypes.int, (u1,u2,u3), TernaryOps.MULACC)
|
||||
uops_fma = UOpGraph([UOp(UOps.SINK, None, (u4,))])
|
||||
|
||||
self.assertEqual(uops.flops_mem(), uops_fma.flops_mem())
|
||||
|
||||
|
||||
@@ -254,12 +254,12 @@ class UOpGraph:
|
||||
def __init__(self, add_nodes:Optional[List[UOp]]=None):
|
||||
self.nodes: Dict[Tuple, UOp] = {}
|
||||
self._uops: Optional[List[UOp]] = None
|
||||
if add_nodes is not None: self.multiadd(add_nodes)
|
||||
if add_nodes is not None: self._multiadd(add_nodes)
|
||||
|
||||
def __iter__(self) -> Iterator[UOp]: return iter(self.uops)
|
||||
def __getitem__(self, index) -> UOp: return self.uops[index]
|
||||
|
||||
def multiadd(self, unprocessed_nodes:List[UOp]):
|
||||
def _multiadd(self, unprocessed_nodes:List[UOp]):
|
||||
# add nodes to graph in reverse BFS order
|
||||
# TODO: i feel like this is written in a few places, possible to library it?
|
||||
in_degree: DefaultDict[UOp, int] = defaultdict(int)
|
||||
@@ -278,7 +278,7 @@ class UOpGraph:
|
||||
while len(queue):
|
||||
n = queue.pop(0)
|
||||
if n in replace_nodes: continue
|
||||
replace_nodes[n] = self.add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
|
||||
replace_nodes[n] = self._add(n.uop, n.dtype, tuple(replace_nodes.get(x, x) for x in n.vin), n.arg)
|
||||
for x in children[n]:
|
||||
in_degree[x] -= 1
|
||||
if in_degree[x] == 0:
|
||||
@@ -411,7 +411,7 @@ class UOpGraph:
|
||||
|
||||
if type_verify: self.type_verify()
|
||||
|
||||
def add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
|
||||
def _add(self, uop:UOps, dtype:Optional[DType]=None, vin:Tuple[UOp, ...]=tuple(), arg:Any=None) -> UOp:
|
||||
if found:=self.nodes.get(key:=(uop, dtype, vin, arg)): return found
|
||||
self.nodes[key] = ret = UOp(*key)
|
||||
return ret
|
||||
|
||||
Reference in New Issue
Block a user