start simplifying the scheduler context [pr] (#8830)

qazal authored 2025-02-02 11:11:36 -05:00 · committed by GitHub
parent d64af3c884
commit 565c37c681

@@ -35,7 +35,7 @@ class ScheduleItem:
@dataclass(frozen=True)
class ScheduleContext:
- tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop
+ tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop
assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule
realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule
allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops to the actual op
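
With the default_factory gone, tensor_uops becomes a required constructor argument of the frozen dataclass: the caller now builds the map and passes it in, while the later fields keep their defaults. A minimal construction sketch (the import path and the empty dict are assumptions for illustration, not part of this commit):

  from tinygrad.engine.schedule import ScheduleContext  # import path assumed

  ctx = ScheduleContext(tensor_uops={})  # the mapping must now be supplied by the caller
  assert ctx.assigns == set() and ctx.realizes == {} and ctx.allbufs == {}  # other fields keep their defaults
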
@@ -45,16 +45,16 @@ class ScheduleContext:
# wrap tensor uops around a VIEW(BUFFER, <uop>)
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
- def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
+ def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp:
if (r:=cache.get(buf)) is not None: return r
# SINK is passthrough
- if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
+ if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
# skip creating buffers for CONST/BIND/DEVICE/BUFFER
if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf
if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st))
# VIEW is passthrough
if buf is not buf.base:
- cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st))
+ cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st))
return ret
# make things that can't be images not images
dtype = buf.dtype
@@ -64,9 +64,9 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c
# ASSIGN already has a target buffer, otherwise we create a new one
assert isinstance(buf.device, str), f"buf device is str, not {buf.device}"
buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype)
- op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src))
- # track the underlying tensor uop for this buffer
- ctx.tensor_uops[buf_uop] = tensor_map[buf]
+ op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src))
+ # track the buffer uop for the simplified uop
+ buffer_map[buf] = buf_uop
# (early) bufferize
cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st)
return ret
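
The net effect on add_buffers: it no longer threads a ScheduleContext through the recursion, it only records which BUFFER uop backs each simplified uop. A hedged usage sketch of the new signature; simplified_sink is a hypothetical placeholder for the rewritten tensor graph and the import paths are assumptions:

  from tinygrad.ops import UOp                      # import path assumed
  from tinygrad.engine.schedule import add_buffers  # import path assumed

  buffer_map: dict[UOp, UOp] = {}  # simplified uop -> its BUFFER uop, filled in by add_buffers
  sink = add_buffers(simplified_sink, buffer_map, cache={})
  # bufferized uops are now keys of buffer_map; CONST/BIND/DEVICE and pre-existing BUFFERs are skipped
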
@@ -431,12 +431,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v
# we group the rest of UOps into ScheduleItems
- rev_tensor_map: dict[UOp, list[UOp]] = {}
- for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k)
- # add BUFFER uops
- sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={})
+ buffer_map: dict[UOp, UOp] = {}
+ sink = add_buffers(tensor_map[big_sink], buffer_map, cache={})
# get realizes
- realize_map = group_realizes(sink, ctx)
+ buf_tensors: dict[UOp, list[UOp]] = {}
+ for k,v in tensor_map.items():
+   if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k)
+ realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors))
# TODO: this should be the break between the "grouper" and the "linearizer"
# here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign)
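
Where the old code reversed tensor_map up front, the new code composes the two maps after add_buffers runs: tensor_map (tensor uop -> simplified uop) chained through buffer_map (simplified uop -> BUFFER uop) yields buf_tensors (BUFFER uop -> tensor uops), which is exactly what ScheduleContext now takes. The same composition with plain strings standing in for UOps, as a self-contained toy example:

  tensor_map = {"t1": "s1", "t2": "s1", "t3": "s2"}  # tensor uop -> simplified uop
  buffer_map = {"s1": "b1", "s2": "b2"}              # simplified uop -> BUFFER uop
  buf_tensors: dict[str, list[str]] = {}
  for k, v in tensor_map.items():
    if (b := buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k)
  assert buf_tensors == {"b1": ["t1", "t2"], "b2": ["t3"]}  # two tensors resolve to the same buffer
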
@@ -449,7 +450,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va
assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}"
prescheduled.append(schedule_uop(store.sink(), ctx, var_vals))
# can only schedule once
- for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
+ for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st))
# increment refcount for this buffer
buf_uop.buffer.ref(1)