|
|
|
|
@@ -5,7 +5,7 @@ from tinygrad.uop.ops import PatternMatcher, UPat, Ops, UOp, resolve, GroupOp, _
|
|
|
|
|
from tinygrad.uop.ops import track_rewrites, graph_rewrite, identity_element, sint, AxisType, BottomUpGate
|
|
|
|
|
from tinygrad.uop.symbolic import symbolic_flat
|
|
|
|
|
from tinygrad.helpers import argsort, prod, all_same, pluralize, getenv, flatten, dedup, all_int, DEBUG, SPLIT_REDUCEOP, Metadata, DEBUG_RANGEIFY
|
|
|
|
|
from tinygrad.helpers import PCONTIG, partition
|
|
|
|
|
from tinygrad.helpers import PCONTIG, partition, get_single_element
|
|
|
|
|
from tinygrad.codegen.simplify import pm_flatten_range, pm_reduce_simplify
|
|
|
|
|
from tinygrad.codegen.opt import Opt
|
|
|
|
|
from tinygrad.schedule.indexing import run_rangeify, BufferizeOpts, ALWAYS_CONTIGUOUS, IndexingContext, apply_movement_op
|
|
|
|
|
@@ -299,11 +299,11 @@ pm_limit_bufs = PatternMatcher([(UPat(set.union(GroupOp.Binary, GroupOp.Ternary)
|
|
|
|
|
# BUFFERIZE returns the BUFFER ready for INDEXing (doing this will make splitting a lot easier)
|
|
|
|
|
# NOTE: this has been fixed up a bit
|
|
|
|
|
|
|
|
|
|
def bufferize_to_store(x:UOp, allow_locals=True):
|
|
|
|
|
rngs = x.src[1:]
|
|
|
|
|
shape = x.shape
|
|
|
|
|
size = prod(shape)
|
|
|
|
|
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {shape}"
|
|
|
|
|
def bufferize_to_store(x:UOp, idx:UOp, allow_locals=True):
|
|
|
|
|
#assert isinstance(x.tag, Flat), "bufferize must be flat"
|
|
|
|
|
size = prod(x.shape)
|
|
|
|
|
rngs = sorted(idx.ranges, key=lambda x: x.arg)
|
|
|
|
|
assert size > 0 and isinstance(size, int), f"no zero sized or symbolic sized buffers {size}"
|
|
|
|
|
|
|
|
|
|
sdtype = x.dtype.ptr(size=size, addrspace=x.arg.addrspace)
|
|
|
|
|
if x.src[0].op is Ops.ASSIGN:
|
|
|
|
|
@@ -311,7 +311,7 @@ def bufferize_to_store(x:UOp, allow_locals=True):
|
|
|
|
|
assert assign_target.op is Ops.INDEX, f"{assign_target.op} is not index"
|
|
|
|
|
# in assign, this is the buffer size, not the bufferize size
|
|
|
|
|
# TODO: assign_mops here
|
|
|
|
|
do_store = assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*[x for x in rngs if x.op is Ops.RANGE])
|
|
|
|
|
do_store = assign_target.replace(dtype=sdtype).store(assign_src, tag=x.tag).end(*rngs)
|
|
|
|
|
ret = assign_target.src[0].after(do_store)
|
|
|
|
|
mops = []
|
|
|
|
|
walk = assign_mops
|
|
|
|
|
@@ -319,37 +319,44 @@ def bufferize_to_store(x:UOp, allow_locals=True):
|
|
|
|
|
mops.append((walk.op, walk.marg))
|
|
|
|
|
walk = walk.src[0]
|
|
|
|
|
for m in mops[::-1]: ret = ret._mop(*m)
|
|
|
|
|
return ret.forced_reshape(shape).replace(tag=x.tag)
|
|
|
|
|
return ret
|
|
|
|
|
|
|
|
|
|
# NOTE: the DEFINE_LOCAL needs to be disambiguated here
|
|
|
|
|
if sdtype.addrspace == AddrSpace.GLOBAL:
|
|
|
|
|
buf = UOp.new_buffer(x.arg.device, size, x.dtype)
|
|
|
|
|
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0], tag=x.tag).end(*[x for x in rngs if x.op is Ops.RANGE])
|
|
|
|
|
ret = buf.after(do_store).forced_reshape(shape)
|
|
|
|
|
# TODO: is this right? what if it's offset
|
|
|
|
|
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
|
|
|
|
|
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
|
|
|
|
|
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
|
|
|
|
return ret.replace(tag=x.tag)
|
|
|
|
|
do_store = buf.index(idx, dtype=sdtype).store(x.src[0], tag=x.tag).end(*rngs)
|
|
|
|
|
return buf.after(do_store)
|
|
|
|
|
|
|
|
|
|
if allow_locals:
|
|
|
|
|
# handle locals
|
|
|
|
|
tag = x.arg.device
|
|
|
|
|
if tag is None: tag = UOp.unique().arg # TODO: hack
|
|
|
|
|
buf = UOp(Ops.DEFINE_LOCAL, sdtype, arg=tag)
|
|
|
|
|
do_store = buf.reshape(shape).index(*rngs, dtype=sdtype).store(x.src[0]).end(*[x for x in rngs if x.op is Ops.RANGE])
|
|
|
|
|
return buf.after(do_store.barrier()).reshape(shape)
|
|
|
|
|
do_store = buf.index(idx, dtype=sdtype).store(x.src[0]).end(*rngs)
|
|
|
|
|
return buf.after(do_store.barrier())
|
|
|
|
|
|
|
|
|
|
pm_add_buffers = pm_mops+to_bufferview+PatternMatcher([
|
|
|
|
|
(UPat(Ops.BUFFERIZE, name="x"), lambda x: bufferize_to_store(x, allow_locals=False)),
|
|
|
|
|
# collapse any BUFFERIZE to single input BUFFERIZE. move the tag to a reshape
|
|
|
|
|
def flatten_bufferize(x:UOp):
|
|
|
|
|
if x.tag is None and len(x.src) == 2: return None
|
|
|
|
|
ret = x.replace(tag=None, src=(x.src[0], get_single_element(apply_movement_op(Ops.RESHAPE, (prod(x.shape),), x.shape, x.src[1:]))))
|
|
|
|
|
rngs = x.src[1:]
|
|
|
|
|
ret = ret.forced_reshape(x.shape)
|
|
|
|
|
if any(r.op is Ops.RANGE and r.src[0].op is not Ops.CONST for r in rngs):
|
|
|
|
|
sym_shape = tuple([ssimplify(r.src[0]) if r.op is not Ops.CONST else 1 for r in rngs])
|
|
|
|
|
ret = ret.shrink(tuple([(0,x) for x in sym_shape]))
|
|
|
|
|
return ret.rtag(x.tag)
|
|
|
|
|
pm_flatten_bufferize = PatternMatcher([(UPat(Ops.BUFFERIZE, name="x"), flatten_bufferize)])
|
|
|
|
|
|
|
|
|
|
pm_add_buffers = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|
|
|
|
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), lambda x, idx: bufferize_to_store(x, idx, allow_locals=False)),
|
|
|
|
|
|
|
|
|
|
# move RESHAPEs through MSELECT/MSTACK
|
|
|
|
|
(UPat((Ops.MSELECT, Ops.MSTACK), src=UPat(Ops.RESHAPE), name="m"),
|
|
|
|
|
lambda m: m.replace(src=tuple([x.src[0].base for x in m.src]), tag=None).reshape(m.shape).rtag(m.tag)),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
pm_add_buffers_local = pm_mops+to_bufferview+PatternMatcher([
|
|
|
|
|
(UPat(Ops.BUFFERIZE, name="x"), bufferize_to_store),
|
|
|
|
|
pm_add_buffers_local = pm_mops+pm_flatten_bufferize+to_bufferview+PatternMatcher([
|
|
|
|
|
(UPat(Ops.BUFFERIZE, src=(UPat(), UPat(name="idx")), name="x"), bufferize_to_store),
|
|
|
|
|
])
|
|
|
|
|
|
|
|
|
|
# *****************
|
|
|
|
|
@@ -435,7 +442,7 @@ rangeify_codegen = PatternMatcher([
|
|
|
|
|
|
|
|
|
|
def remove_metadata_tags(ctx:LocalAddBufferContext, x:UOp):
|
|
|
|
|
if x.tag is None or x.tag == (): return None
|
|
|
|
|
ctx.parent_tags += list(x.tag)
|
|
|
|
|
if isinstance(x.tag, tuple): ctx.parent_tags += list(x.tag)
|
|
|
|
|
return x.replace(tag=None)
|
|
|
|
|
|
|
|
|
|
pm_remove_tags = PatternMatcher([
|
|
|
|
|
|