mirror of https://github.com/tinygrad/tinygrad.git
shorter BufferOps.LOAD creation (#5685)
@@ -37,7 +37,7 @@ class ScheduleItem:
 
 # *** DAG transformation: List[LazyBuffer] -> ScheduleItem ***
 
-def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
+def _recursive_lazyop(buf:LazyBuffer, inputs:Dict[LazyBuffer, int], outputs:Tuple[LazyBuffer, ...], var_vals:Dict[Variable, int], st:ShapeTracker,
                        realizes:Dict[LazyBuffer, None], assign_targets:Dict[LazyBuffer, LazyBuffer],
                        reduce_info:Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]], cache) -> LazyOp:
   """recursively create a lazyop"""
@@ -58,16 +58,13 @@ def _recursive_lazyop(buf:LazyBuffer, inputs:List[LazyBuffer], outputs:Tuple[Laz
     unbound_st, st_var_vals = st.simplify().unbind()
     var_vals.update(st_var_vals)
     if buf in assign_targets:
-      # can only assign to contiguous read+write buffer
-      if not unbound_st.contiguous:
-        # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-        if not (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and
-                ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
-          raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                             +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
-      return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
-    if buf not in inputs: inputs.append(buf)
-    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.index(buf), buf.dtype, unbound_st))
+      # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
+      if unbound_st.contiguous or (len(unbound_st.views) == 1 and unbound_st.views[0].mask is not None and\
+         ShapeTracker.from_shape(unbound_st.shape).shrink(unbound_st.views[0].mask) == unbound_st.shrink(unbound_st.views[0].mask)):
+        return LazyOp(BufferOps.LOAD, (), MemBuffer(outputs.index(assign_targets[buf]), buf.dtype, unbound_st))
+      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                         +colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
+    return LazyOp(BufferOps.LOAD, (), MemBuffer(len(outputs)+inputs.setdefault(buf, len(inputs)), buf.dtype, unbound_st))
 
   # if a CONTIGUOUS or ASSIGN made it all the way here, just skip it
   if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
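The shortening is in how a load of a regular input gets its buffer index: the old code appended the buffer to a list and then ran a linear .index() scan, while the new code records and looks up the slot in a single dict.setdefault call. The assign-target contiguity check is also folded from two nested negated ifs into one positive condition with the same behavior (plain De Morgan, no semantic change). A minimal standalone sketch of the list-vs-dict equivalence, where the Buf class and the two helper functions are hypothetical stand-ins and not tinygrad code:

from typing import Dict, List

class Buf:
  """hypothetical stand-in for a hashable LazyBuffer"""
  def __init__(self, name:str): self.name = name
  def __repr__(self): return self.name

def index_with_list(inputs:List[Buf], buf:Buf) -> int:
  # old pattern: append if unseen, then an O(n) .index() scan
  if buf not in inputs: inputs.append(buf)
  return inputs.index(buf)

def index_with_dict(inputs:Dict[Buf, int], buf:Buf) -> int:
  # new pattern: setdefault stores len(inputs) as the index on first
  # sight and returns the stored index on every later call, in O(1)
  return inputs.setdefault(buf, len(inputs))

a, b = Buf("a"), Buf("b")
lst:List[Buf] = []
dct:Dict[Buf, int] = {}
for buf in [a, b, a, b, b]:
  assert index_with_list(lst, buf) == index_with_dict(dct, buf)
assert lst == list(dct)  # same buffers, same order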
@@ -126,7 +123,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
   assign_targets = {x.srcs[1]:x for x in outs if x.op is MetaOps.ASSIGN}
   cache: Dict[Tuple[LazyBuffer, ShapeTracker], LazyOp] = {}
   ast: List[LazyOp] = []
-  inputs: List[LazyBuffer] = []
+  inputs: Dict[LazyBuffer, int] = {}
   reduce_info: Dict[LazyBuffer, Tuple[ShapeTracker, Tuple[int, ...]]] = {}
   seen_ops: Dict[Tuple[LazyBuffer, ShapeTracker], None] = {}
   for i, out in enumerate(outs):
@@ -137,7 +134,7 @@ def _lower_lazybuffer(outs:List[LazyBuffer], realizes:Dict[LazyBuffer, None]):
     output_view, vv = output_view.simplify().unbind()
     if vv: var_vals.update(vv)
     ast.append(LazyOp(BufferOps.STORE, (lop, ), MemBuffer(i, out.dtype, output_view)))
-  return LazyOp(MetaOps.KERNEL, tuple(ast)), inputs, var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
+  return LazyOp(MetaOps.KERNEL, tuple(ast)), list(inputs), var_vals, dedup([x[0].metadata for x in cache if x[0].metadata and x[0] not in inputs])
 
 # *** DAG creation: decide which LazyBuffers should realize ***
 
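Callers of _lower_lazybuffer still get a plain list of input buffers back: list() over the dict yields its keys in insertion order (a language guarantee since Python 3.7), which is exactly the order the old list was built in, and `x[0] not in inputs` keeps its meaning because membership on a dict tests keys. A small sketch of that behavior, with string keys as hypothetical stand-ins for LazyBuffers:

# build the "inputs" mapping the way setdefault does during lowering
inputs = {}
for name in ["w", "x", "w", "b"]:            # repeats collapse to one slot each
  inputs.setdefault(name, len(inputs))

assert inputs == {"w": 0, "x": 1, "b": 2}
assert list(inputs) == ["w", "x", "b"]       # keys come back in insertion order
assert "x" in inputs and "y" not in inputs   # `in` on a dict tests keys, like the old list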