mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-24 06:18:01 -05:00
start simplifying the scheduler context [pr] (#8830)
This commit is contained in:
@@ -35,7 +35,7 @@ class ScheduleItem:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ScheduleContext:
|
||||
tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop
|
||||
tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop
|
||||
assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
|
||||
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
|
||||
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op
|
||||
@@ -45,16 +45,16 @@ class ScheduleContext:
|
||||
|
||||
# wrap tensor uops around a VIEW(BUFFER, <uop>)
|
||||
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
||||
def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
# SINK is passthrough
|
||||
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
|
||||
if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
|
||||
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
|
||||
if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
|
||||
if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
|
||||
# VIEW is passthrough
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
|
||||
cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
dtype = buf.dtype
|
||||
@@ -64,9 +64,9 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
|
||||
# ASSIGN already has a target buffer, otherwise we create a new one
|
||||
assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
|
||||
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
|
||||
# track the underlying tensor uop for this buffer
|
||||
ctx.tensor_uops[buf_uop] = tensor_map[buf]
|
||||
op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
|
||||
# track the buffer uop for the simplified uop
|
||||
buffer_map[buf] = buf_uop
|
||||
# (early) bufferize
|
||||
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
|
||||
return ret
|
||||
@@ -431,12 +431,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
|
||||
|
||||
# we group the rest of UOps into ScheduleItems
|
||||
rev_tensor_map: dict[UOp, list[UOp]] = {}
|
||||
for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
|
||||
# add BUFFER uops
|
||||
sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
|
||||
buffer_map: dict[UOp, UOp] = {}
|
||||
sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
|
||||
# get realizes
|
||||
realize_map = group_realizes(sink, ctx)
|
||||
buf_tensors: dict[UOp, list[UOp]] = {}
|
||||
for k,v in tensor_map.items():
|
||||
if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k)
|
||||
realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors))
|
||||
|
||||
# TODO: this should be the break between the "grouper" and the "linearizer"
|
||||
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
|
||||
@@ -449,7 +450,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
|
||||
assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
|
||||
prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
|
||||
# can only schedule once
|
||||
for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
|
||||
for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
|
||||
# increment refcount for this buffer
|
||||
buf_uop.buffer.ref(1)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user