mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
small changes from O(1) multi [pr] (#10309)
This commit is contained in:
@@ -84,8 +84,9 @@ sym = symbolic_simple+PatternMatcher([
|
||||
# COPY(CONST) creates a new CONST on the destination device
|
||||
(UPat(Ops.COPY, name="root", src=(UPat.cvar("x"), UPat(Ops.DEVICE))), lambda root,x: root.const_like(x.arg)),
|
||||
# store a shrink before COPY, otherwise view after the COPY
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v: v.contiguous().copy_to_device(copy.device) \
|
||||
if prod(v.shape) < prod(v.base.shape) else v.base.copy_to_device(copy.device).view(v.st)),
|
||||
(UPat(Ops.COPY, src=(UPat(Ops.VIEW, name="v"), UPat(Ops.DEVICE)), name="copy"), lambda copy,v:
|
||||
v.contiguous().copy_to_device(copy.device, arg=copy.arg) if prod(v.shape) < prod(v.base.shape) else \
|
||||
v.base.copy_to_device(copy.device, arg=copy.arg).view(v.st)),
|
||||
# remove cast to image when it's already a contiguous image
|
||||
(UPat(Ops.CAST, name="cast", src=(UPat(Ops.VIEW, name="vm", src=(UPat(Ops.CONTIGUOUS, name="base"),)),)),
|
||||
lambda cast,base,vm: base.view(vm.st) if isinstance(cast.dtype, ImageDType) and isinstance(base.dtype, ImageDType) else None),
|
||||
|
||||
@@ -66,4 +66,4 @@ def memory_planner(schedule:list[ScheduleItem]) -> list[ScheduleItem]:
|
||||
# Exclude buffers involved in load ops (e.g transfers) to preserve parallelism in graphs.
|
||||
assigned = _internal_memory_planner([list(si.bufs) for si in schedule],
|
||||
noopt_buffers={b for si in schedule if si.ast.op is not Ops.SINK for b in si.bufs})
|
||||
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata) for si in schedule]
|
||||
return [ScheduleItem(si.ast, tuple(assigned.get(x, x) for x in si.bufs), si.metadata, si.fixedvars) for si in schedule]
|
||||
|
||||
Reference in New Issue
Block a user