diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index f5bc8612ee..e0aebefb35 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -35,7 +35,7 @@ class ScheduleItem: @dataclass(frozen=True) class ScheduleContext: - tensor_uops: dict[UOp, list[UOp]] = field(default_factory=dict) # this maps BUFFER uops of this schedule to the tensor uop + tensor_uops: dict[UOp, list[UOp]] # this maps BUFFER uops of this schedule to the tensor uop assigns: set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule realizes: dict[UOp, UOp] = field(default_factory=dict) # this holds all the BUFFER uops we mutate in this schedule allbufs: dict[UOp, UOp] = field(default_factory=dict) # this maps BUFFER uops the actual op @@ -45,16 +45,16 @@ class ScheduleContext: # wrap tensor uops around a VIEW(BUFFER, ) # this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it. -def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp: +def add_buffers(buf:UOp, buffer_map:dict[UOp, UOp], cache:dict[UOp, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r # SINK is passthrough - if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) + if buf.op is Ops.SINK: return buf.replace(src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src)) # skip creating buffers for CONST/BIND/DEVICE/BUFFER if buf.base.op in {Ops.CONST, Ops.BIND, Ops.DEVICE}: return buf if buf.base.op is Ops.BUFFER: return buf.view(unwrap(buf.st)) # VIEW is passthrough if buf is not buf.base: - cache[buf] = ret = add_buffers(buf.base, tensor_map, ctx, cache).view(unwrap(buf.st)) + cache[buf] = ret = add_buffers(buf.base, buffer_map, cache).view(unwrap(buf.st)) return ret # make things that can't be images not images dtype = buf.dtype @@ -64,9 +64,9 @@ def add_buffers(buf:UOp, tensor_map:dict[UOp, list[UOp]], ctx:ScheduleContext, c # ASSIGN already has a target buffer, otherwise we create a new one assert isinstance(buf.device, str), f"buf device is str, not {buf.device}" buf_uop = buf.buf_uop if buf.op is Ops.ASSIGN else UOp.new_buffer(buf.device, buf.size, dtype) - op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, tensor_map, ctx, cache) for x in buf.src)) - # track the underlying tensor uop for this buffer - ctx.tensor_uops[buf_uop] = tensor_map[buf] + op = buf.replace(dtype=dtype, src=tuple(add_buffers(x, buffer_map, cache) for x in buf.src)) + # track the buffer uop for the simplified uop + buffer_map[buf] = buf_uop # (early) bufferize cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (buf_uop, op), buf.st) return ret @@ -431,12 +431,13 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va elif v.op is Ops.CONST and all_int(v.shape): becomes_map[k] = v # we group the rest of UOps into ScheduleItems - rev_tensor_map: dict[UOp, list[UOp]] = {} - for k,v in tensor_map.items(): rev_tensor_map.setdefault(v, []).append(k) - # add BUFFER uops - sink = add_buffers(tensor_map[big_sink], rev_tensor_map, ctx:=ScheduleContext(), cache={}) + buffer_map: dict[UOp, UOp] = {} + sink = add_buffers(tensor_map[big_sink], buffer_map, cache={}) # get realizes - realize_map = group_realizes(sink, ctx) + buf_tensors: dict[UOp, list[UOp]] = {} + for k,v in tensor_map.items(): + if (b:=buffer_map.get(v)) is not None: buf_tensors.setdefault(b, []).append(k) + realize_map = group_realizes(sink, ctx:=ScheduleContext(buf_tensors)) # TODO: this should be the break between the "grouper" and the "linearizer" # here, there should just be one sink UOp with BUFFER/KERNEL/COPY/ASSIGN (assign is the parent if you want the buffer post assign) @@ -449,7 +450,7 @@ def create_schedule_with_vars(big_sink:UOp) -> tuple[list[ScheduleItem], dict[Va assert store.op is Ops.STORE, f"expected a realized BUFFER to get a STORE {sink}" prescheduled.append(schedule_uop(store.sink(), ctx, var_vals)) # can only schedule once - for tensor_uop in ctx.tensor_uops[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) + for tensor_uop in buf_tensors[buf_uop]: becomes_map[tensor_uop] = buf_uop.view(unwrap(tensor_uop.st)) # increment refcount for this buffer buf_uop.buffer.ref(1)