rewrite all gates

This commit is contained in:
qazal
2024-07-22 14:52:14 +03:00
parent 11d9035fe0
commit 2425a443f3
6 changed files with 21 additions and 14 deletions

View File

@@ -29,8 +29,9 @@ class TestLinearizerDumb(unittest.TestCase):
k.required_optimizations()
for opt in opts: k.apply_opt(opt)
prg = k.to_program()
prg.uops.print()
#prg.uops.print()
print(prg.src)
#print("\n".join(prg.src.splitlines()[-4:]))
if __name__ == '__main__':
unittest.main()

View File

@@ -1,4 +1,5 @@
from __future__ import annotations
from dataclasses import replace
from typing import Iterator, Optional, Tuple, Dict, List, Set, Union, cast, TYPE_CHECKING
import functools, itertools, heapq, math
from tinygrad.dtype import dtypes, PtrDType, ImageDType
@@ -473,12 +474,20 @@ class UOpGraph:
if u.op is UOps.LOAD and u.src[-1].op is UOps.BARRIER:
if_uop = UOp(UOps.IF, None, (gate, u.src[-1]))
return UOp(u.op, u.dtype, u.src[:-1]+(if_uop,), u.arg)
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src: return UOp(u.op, u.dtype, replace_source, u.arg)
if (replace_source:=tuple(_replace_gates(x, gate) for x in u.src)) != u.src:
if u.op is UOps.STORE and replace_source[3] is gate: replace_source = replace_source[:3]
return UOp(u.op, u.dtype, replace_source, u.arg)
if u.op is UOps.STORE:
assert len(u.src) == 4
assert gate is u.src[3]
if_uop = UOp(UOps.IF, None, (gate,))
replace_source = u.src[:3]+(if_uop,)
return UOp(u.op, u.dtype, replace_source, u.arg)
return u
sink_srcs = list(self.sink.src)
for i, s in enumerate(sink_srcs):
if s.op is UOps.STORE and len(s.src) == 4 and (rw:=_replace_gates(s, s.src[3])) != s:
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src[:3], rw.arg)
sink_srcs[i] = UOp(rw.op, rw.dtype, rw.src, rw.arg)
sink = UOp(UOps.SINK, None, tuple(sink_srcs))
# do graph rewrite

View File

@@ -43,7 +43,9 @@ class UOp:
return (self.op.value, (self.arg if self.op is not UOps.DEFINE_VAR else self.arg.expr) if self.op is not UOps.ALU else \
self.arg.value, self.dtype, self.src)
# Ordering: one UOp sorts before another when its cmp_tuple compares lower than the other's cmp_tuple.
def __lt__(self, x:UOp): return self.cmp_tuple < x.cmp_tuple
def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
#def __repr__(self): return pretty_print(self, lambda x: f"UOp({x.op}, {x.dtype}, arg={x.arg}, src=(%s))")
# Compact single-line debug form with fixed-width columns: str(op) padded to 20 characters, str(dtype) padded to 25
# (empty string when dtype is None), the list of the sources' ops padded to 32, then arg. Only the direct sources' ops are
# shown; nothing is printed recursively.
def __repr__(self):
return f"{str(self.op):20s}: {str(self.dtype) if self.dtype is not None else '':25s} {str([x.op for x in self.src]):32s} {self.arg}"
# *** uop syntactic sugar
# Coerce an operand to a UOp: an existing UOp is returned unchanged, anything else is wrapped with UOp.const using this uop's dtype.
def ufix(self, x): return UOp.const(self.dtype, x) if not isinstance(x, UOp) else x
# Build a new UOps.CAST uop with the given dtype (None when omitted) whose single source is this uop.
def cast(self, dtype=None): return UOp(UOps.CAST, dtype, (self,))
@@ -177,7 +179,7 @@ def type_verify(uops):
if uop is UOps.LOAD and len(src) > 3 and src[2].op is UOps.ALU: assert src[2].dtype == dtypes.bool and src[3].dtype == dtype
if uop is UOps.STORE:
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
#if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if uop is UOps.ALU:
if arg in UnaryOps:
assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"

View File

@@ -148,10 +148,9 @@ class PTXRenderer(Renderer):
assert src[1].op is UOps.CONST, f"store isn't const {u}"
mem_type = '.shared' if src[0].op is UOps.DEFINE_LOCAL or any(x.op is UOps.DEFINE_LOCAL for x in src[0].parents) else '.global'
if src[2].dtype.count > 1:
kk((f"@{r[src[3]]} " if len(src)>3 else "") + \
f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
kk(f"st{mem_type}.v{src[2].dtype.count}.{self.mem_types[src[2].dtype.scalar()]} [{r[src[0]]}+{src[1].arg}], {{{', '.join(r[src[2]])}}};")
else:
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=r[src[3]] if len(src)>3 else None, ss=mem_type, offset=src[1].arg))
kk(*self.render_store(r[src[0]], r[src[2]], src[2].dtype, gate=None, ss=mem_type, offset=src[1].arg))
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:]))

View File

@@ -123,7 +123,7 @@ class CStyleLanguage(Renderer):
elif uop is UOps.STORE:
assert src[0].dtype is not None and src[2].dtype is not None
rendered_store = self.render_store(r[src[0]], src[0].dtype, r[src[2]], src[2].dtype, strip_parens(r[src[1]]), src[0].op is UOps.DEFINE_LOCAL)
kk(f"if ({r[src[3]]}) {{ {rendered_store} }}" if len(src) > 3 else rendered_store)
kk(rendered_store)
else:
assert dtype is not None, f"None dtype for uop {uop}"
if uop is UOps.RANGE:

View File

@@ -100,11 +100,7 @@ class LLVMRenderer(Renderer):
uop,dtype,src,args = u.op,u.dtype,u.src,u.arg
if uop is UOps.STORE:
element = cast(bb, lvars[src[2]], src[2].dtype, src[0].dtype)
if len(src) > 3:
with bb[-1].if_then(lvars[src[3]]):
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
else:
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
bb[-1].store(element, bb[-1].gep(lvars[src[0]], [lvars[src[1]]], inbounds=True))
elif uop is UOps.ENDRANGE:
loop_entry_bb, phis = loop_blocks.pop()
idx_p1 = bb[-1].add(lvars[src[0]], ir.Constant(ir.IntType(32), 1))