mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
rewrite all gates
This commit is contained in:
@@ -29,8 +29,9 @@ class TestLinearizerDumb(unittest.TestCase):
|
||||
k.required_optimizations()
|
||||
for opt in opts: k.apply_opt(opt)
|
||||
prg = k.to_program()
|
||||
prg.uops.print()
|
||||
#prg.uops.print()
|
||||
print(prg.src)
|
||||
#print("\n".join(prg.src.splitlines()[-4:]))
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main()
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from __future__ import annotations
|
||||
from dataclasses import replace
|
||||
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING
|
||||
import functools, itertools, heapq, math
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
@@ -473,12 +474,20 @@ class UOpGraph:
|
||||
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
|
||||
if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
|
||||
return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
|
||||
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src:
|
||||
if u.op is UOps.STORE and replace_source[3] is gate: replace_source = replace_source[:3]
|
||||
return UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
if u.op is UOps.STORE:
|
||||
assert len(u.src) == 4
|
||||
assert gate is u.src[3]
|
||||
if_uop = UOp(UOps.IF, None, (gate,))
|
||||
replace_source = u.src[:3]+(if_uop,)
|
||||
return UOp(u.op, u.dtype, replace_source, u.arg)
|
||||
return u
|
||||
sink_srcs = list(self.sink.src)
|
||||
for i, s in enumerate(sink_srcs):
|
||||
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s:
|
||||
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
|
||||
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src, rw.arg)
|
||||
sink = UOp(UOps.SINK, None, tuple(sink_srcs))
|
||||
|
||||
# do graph rewrite
|
||||
|
||||
@@ -43,7 +43,9 @@ class UOp:
|
||||
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
|
||||
self.arg.value, self.dtype, self.src)
|
||||
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
|
||||
def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
|
||||
#def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
|
||||
def __repr__(self):
|
||||
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
|
||||
# *** uop syntactic sugar
|
||||
def ufix(self, x): return UOp.const(self.dtype, x) if not isinstance(x, UOp) else x
|
||||
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
|
||||
@@ -177,7 +179,7 @@ def type_verify(uops):
|
||||
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
|
||||
if uop is UOps.STORE:
|
||||
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
|
||||
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
#if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
|
||||
if uop is UOps.ALU:
|
||||
if arg in UnaryOps:
|
||||
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
|
||||
|
||||
@@ -148,10 +148,9 @@ class PTXRenderer(Renderer):
|
||||
assert src[1].op is UOps.CONST, f"store isn't const {u}"
|
||||
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
|
||||
if src[2].dtype.count > 1:
|
||||
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
|
||||
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
|
||||
kk(f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
|
||||
else:
|
||||
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
|
||||
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=None, ss=mem_type, offset=src[1].arg))
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))
|
||||
|
||||
@@ -123,7 +123,7 @@ class CStyleLanguage(Renderer):
|
||||
elif uop is UOps.STORE:
|
||||
assert src[0].dtype is not None and src[2].dtype is not None
|
||||
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
|
||||
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
|
||||
kk(rendered_store)
|
||||
else:
|
||||
assert dtype is not None, f"None dtype for uop {uop}"
|
||||
if uop is UOps.RANGE:
|
||||
|
||||
@@ -100,11 +100,7 @@ class LLVMRenderer(Renderer):
|
||||
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
|
||||
if uop is UOps.STORE:
|
||||
element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
|
||||
if len(src) > 3:
|
||||
with bb[-1].if_then(lvars[src[3]]):
|
||||
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
else:
|
||||
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
|
||||
elif uop is UOps.ENDRANGE:
|
||||
loop_entry_bb, phis = loop_blocks.pop()
|
||||
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))
|
||||
|
||||
Reference in New Issue
Block a user