diff --git a/tinygrad/engine/grouper.py b/tinygrad/engine/grouper.py index 54d0a26c69..ef0a33d211 100644 --- a/tinygrad/engine/grouper.py +++ b/tinygrad/engine/grouper.py @@ -84,8 +84,9 @@ sym = symbolic_simple+PatternMatcher([ # COPY(CONST) creates a new CONST on the destination device (UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)), # store a shrink before COPY, otherwise view after the COPY - (UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \ - if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)), + (UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v: + v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \ + v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)), # remove cast to image when it's already a contiguous image (UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)), lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None), diff --git a/tinygrad/engine/memory.py b/tinygrad/engine/memory.py index 5314b0af79..c102d9a52e 100644 --- a/tinygrad/engine/memory.py +++ b/tinygrad/engine/memory.py @@ -66,4 +66,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]: # Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs. assigned = _internal_memory_planner([list(si.bufs) for si in schedule], noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs}) - return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule] + return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]