mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
add LUNIQUE op (#13554)
This commit is contained in:
@@ -1,8 +1,21 @@
|
||||
import itertools
|
||||
from examples.beautiful_mnist import Model
|
||||
from typing import Callable
|
||||
from tinygrad import nn, Tensor, dtypes, Device
|
||||
from tinygrad.helpers import getenv, trange, partition
|
||||
|
||||
class Model:
|
||||
def __init__(self):
|
||||
self.layers: list[Callable[[Tensor], Tensor]] = [
|
||||
nn.Conv2d(1, 32, 5), Tensor.relu,
|
||||
nn.Conv2d(32, 32, 5), Tensor.relu,
|
||||
nn.BatchNorm(32), Tensor.max_pool2d,
|
||||
nn.Conv2d(32, 64, 3), Tensor.relu,
|
||||
nn.Conv2d(64, 64, 3), Tensor.relu,
|
||||
nn.BatchNorm(64), Tensor.max_pool2d,
|
||||
lambda x: x.flatten(1), nn.Linear(576, 10)]
|
||||
|
||||
def __call__(self, x:Tensor) -> Tensor: return x.sequential(self.layers)
|
||||
|
||||
# TODO: refactor this into optim/onnx
|
||||
def functional_adam(g:Tensor, m:Tensor, v:Tensor, b1_t:Tensor, b2_t:Tensor, lr=0.001, b1=0.9, b2=0.999, eps=1e-6) -> Tensor:
|
||||
b1_t *= b1
|
||||
|
||||
@@ -12,8 +12,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.uop.ops import UOp, Ops, GroupOp, UPat
|
||||
from tinygrad.helpers import CI, DEBUG, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.schedule.rangeify import get_rangeify_map, Kernel
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
from tinygrad.schedule.rangeify import Kernel
|
||||
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
@@ -24,9 +23,7 @@ def check_schedule(t:Tensor|list[Tensor]|UOp, allowed:int, to_prerealize:list[Te
|
||||
elif isinstance(t, list) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
|
||||
else:
|
||||
assert isinstance(t, UOp), f"can't schedule {t}"
|
||||
sink = UOp.sink(t) if t.op is not Ops.SINK else t
|
||||
becomes_map = get_rangeify_map(sink)
|
||||
sched, _ = create_schedule_with_vars(sink.substitute(becomes_map))
|
||||
sched = Tensor(t).schedule()
|
||||
# test lowering all the ScheduleItems to ExecItems
|
||||
kernel_cnt = len([si for si,ei in lower_schedule(sched.copy()) if isinstance(ei.prg, CompiledRunner) or not filter_sink])
|
||||
if kernel_cnt != allowed:
|
||||
|
||||
@@ -27,7 +27,7 @@ def _apply_map_to_tensors(applied_map:dict[UOp, UOp], name:str) -> None:
|
||||
# get tensors in scope
|
||||
in_scope: dict[UOp, bool] = {}
|
||||
def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_scope.get(s, False) for s in node.src)
|
||||
scope_tensors = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
|
||||
scope_tensors: list[Tensor] = [t for tref in list(all_tensors) if (t:=tref()) is not None and t.uop.topovisit(visitor, in_scope)]
|
||||
|
||||
# get all Tensors and apply the map
|
||||
sink = UOp.sink(*[t.uop for t in scope_tensors])
|
||||
@@ -242,6 +242,8 @@ class Tensor(OpMixin):
|
||||
NOTE: A Tensor can only be scheduled once.
|
||||
"""
|
||||
big_sink = UOp.sink(*[x.uop for x in (self,)+lst])
|
||||
|
||||
# this is where the schedule cache should go
|
||||
becomes_map, schedule, var_vals = complete_create_schedule_with_vars(big_sink)
|
||||
_apply_map_to_tensors(becomes_map, name="Apply Schedule Map")
|
||||
return schedule, var_vals
|
||||
|
||||
@@ -76,6 +76,9 @@ class Ops(FastEnum):
|
||||
# tensor graph ops
|
||||
UNIQUE = auto(); DEVICE = auto(); KERNEL = auto(); ASSIGN = auto()
|
||||
|
||||
# local unique
|
||||
LUNIQUE = auto()
|
||||
|
||||
# ops that adjust the behavior of the scheduler
|
||||
CONTIGUOUS = auto(); CONTIGUOUS_BACKWARD = auto(); DETACH = auto()
|
||||
|
||||
|
||||
@@ -216,7 +216,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
def _shape(self) -> tuple[sint, ...]|None:
|
||||
match self.op:
|
||||
# late ops don't have shape
|
||||
case Ops.UNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
case Ops.UNIQUE | Ops.LUNIQUE | Ops.DEVICE | Ops.RANGE | Ops.LOAD | Ops.IF | Ops.BARRIER | Ops.CUSTOM | Ops.CUSTOMI | \
|
||||
Ops.VECTORIZE | Ops.VCONST | Ops.GEP | Ops.SPECIAL | Ops.UNROLL | Ops.CONTRACT:
|
||||
return None
|
||||
|
||||
@@ -663,6 +663,7 @@ class UOp(OpMixin, metaclass=UOpMetaClass):
|
||||
assert all_same([x.size for x in ret.bufs]) and all_same([x.dtype for x in ret.bufs]), "multibuffers mismatch buffers"
|
||||
return ret
|
||||
assert self.op is Ops.BUFFER, f"must be BUFFER {self.op}"
|
||||
assert self.src[0].op is Ops.UNIQUE, "buffer src[0] must be UNIQUE"
|
||||
if (cret:=buffers.get(self)) is not None: return cret
|
||||
rdtype = self.dtype if isinstance(self.dtype, ImageDType) else self.dtype.base
|
||||
if isinstance(self.device, tuple): ret = MultiBuffer(self.device, self.size, rdtype).ref(1)
|
||||
|
||||
@@ -60,9 +60,10 @@ movement_ops = PatternMatcher([
|
||||
_tensor_spec = PatternMatcher([
|
||||
# buffer spec
|
||||
(UPat(Ops.UNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.LUNIQUE, dtypes.void, ()), lambda: True),
|
||||
(UPat(Ops.DEVICE, dtypes.void, (), name="d"), lambda d:
|
||||
isinstance(d.arg, str) or (isinstance(d.arg, tuple) and all(isinstance(s, str) for s in d.arg))),
|
||||
(UPat(Ops.BUFFER, src=(UPat(Ops.UNIQUE), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
|
||||
(UPat(Ops.BUFFER, src=(UPat((Ops.LUNIQUE, Ops.UNIQUE)), UPat(Ops.DEVICE)), allow_any_len=True, name="buf"),
|
||||
lambda buf: isinstance(buf.arg, int) and isinstance(buf.dtype, (DType, ImageDType))),
|
||||
(UPat(Ops.BUFFER_VIEW, src=(UPat(Ops.BUFFER),), name="buf_view"),
|
||||
lambda buf_view: isinstance(buf_view.arg, tuple) and len(buf_view.arg) == 2 and all(isinstance(arg, (int, UOp)) for arg in buf_view.arg)),
|
||||
@@ -85,7 +86,7 @@ _tensor_spec = PatternMatcher([
|
||||
|
||||
# device or unique
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE),)), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat(Ops.UNIQUE))), lambda: True),
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.DEVICE), UPat((Ops.LUNIQUE, Ops.UNIQUE)))), lambda: True),
|
||||
|
||||
# DETACH and CONTIGUOUS change how we interpret the source UOp
|
||||
# CONTIGUOUS ensures the source UOp realizes
|
||||
|
||||
@@ -69,7 +69,7 @@ def uop_to_json(x:UOp) -> dict[int, dict]:
|
||||
excluded: set[UOp] = set()
|
||||
for u in (toposort:=x.toposort()):
|
||||
# always exclude DEVICE/CONST/UNIQUE
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op in {Ops.DEVICE, Ops.CONST, Ops.UNIQUE, Ops.LUNIQUE} and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VCONST and u.dtype.scalar() == dtypes.index and u is not x: excluded.add(u)
|
||||
if u.op is Ops.VECTORIZE and len(u.src) == 0: excluded.add(u)
|
||||
for u in toposort:
|
||||
|
||||
Reference in New Issue
Block a user