mirror of https://github.com/tinygrad/tinygrad.git
various touchups (#1058)
* op isn't optional
* barrier + named local buffers
* end global and local loop together to avoid useless if statement
* better comments
@@ -63,7 +63,7 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
   local_size = []
   pend_close = None
 
-  bufnames = ["temp" if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
+  bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
 
   depth = 0
   def kk(s): kernel.append(" "*depth+s)
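Note: this is the renderer-side half of the named local buffers. Instead of hardcoding "temp", the name is read off the LocalBuffer itself. A minimal sketch of that comprehension, with stand-in classes rather than tinygrad's real ones:

from typing import NamedTuple

class LocalBuffer(NamedTuple):  # stand-in for tinygrad's LocalBuffer
  name: str

class FakeLazyBuffer: pass      # stand-in for a global (non-local) buffer

bufs = [FakeLazyBuffer(), LocalBuffer("temp"), FakeLazyBuffer()]
bufnames = [b.name if isinstance(b, LocalBuffer) else f"data{i}" for i,b in enumerate(bufs)]
assert bufnames == ["data0", "temp", "data2"]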
@@ -88,10 +88,11 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
       else:
         kk(f"for (int {var.expr} = {var.min}; {var.expr} <= {var.max}; ++{var.expr}) {{")
         depth += 1
+    elif uop == UOps.BARRIER:
+      kk(lang.barrier)
     elif uop == UOps.ENDLOOP:
       if args[1] == "local" and len(lang.lid):
         # TODO: this is a bit of a hack. the local loop isn't real on the GPU
-        kk(lang.barrier)
         kk(f"if ({Variable.sum(args[0]).render(render_cl)} == 0) {{")
         pend_close = "}"*(len(args[0])+1) + f" /* {args[1]} */"
       else:
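Note: with BARRIER split out as its own uop, the renderer emits the synchronization point wherever the linearizer placed it instead of baking it into the local ENDLOOP. A rough sketch of the generated-code shape for a grouped reduce; the barrier string is an assumed OpenCL flavor of lang.barrier, and lidx0 stands in for the rendered sum of the local loop variables:

kernel = []
def kk(s): kernel.append(s)

barrier = "barrier(CLK_LOCAL_MEM_FENCE);"  # assumed OpenCL-flavored lang.barrier
kk(barrier)                # rendered from the standalone UOps.BARRIER
kk("if (lidx0 == 0) {")    # rendered from ENDLOOP "local": only workitem 0 proceeds
print("\n".join(kernel))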
@@ -10,10 +10,11 @@ from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
 from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
 from tinygrad.shape.symbolic import Variable
 
-class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); \
+class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = auto(); CONST = auto(); ENDLOOP = auto(); STORE = auto(); CAST = auto(); BARRIER = auto(); \
   SPECIAL = auto(); DEFINE_REGISTER = auto(); LABEL = auto(); COND_BRANCH = auto() # noqa: E702
 
 class LocalBuffer(NamedTuple):
+  name: str
   dtype: DType = dtypes.float32
   realized: None = None
 
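Note: the new first field is what makes the renderer's b.name lookup above work. A small sketch of the updated NamedTuple, with DType/dtypes replaced by stand-ins so the snippet runs on its own:

from typing import NamedTuple

class DType: pass        # stand-in for tinygrad's DType
float32 = DType()        # stand-in for dtypes.float32

class LocalBuffer(NamedTuple):
  name: str              # new required field; positional, so LocalBuffer("temp") works
  dtype: DType = float32
  realized: None = None

assert LocalBuffer("temp").name == "temp"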
@@ -223,7 +224,7 @@ class Linearizer:
 
     # add a local buffer for multistage reduce
     if len(self.group_for_reduce):
-      self.bufs.append(LocalBuffer())
+      self.bufs.append(LocalBuffer("temp"))
       # TODO: the strides of this can be controlled
       self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
       self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
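Note: the ShapeTracker line builds the temp buffer's shape as ones everywhere except the grouped-reduce axes (plus any upcasted axes, omitted here). A worked sketch with assumed illustrative numbers, not values from the commit:

# assumed example: 2 non-reduce dims, one reduce axis grouped by 16, no upcast
first_reduce, shape_len, upcasted = 2, 5, 0
group_for_reduce = [16]

shape = [1]*first_reduce + group_for_reduce + \
        [1]*(shape_len - upcasted - len(group_for_reduce) - first_reduce)
assert shape == [1, 1, 16, 1, 1]  # temp only has extent along the group axis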
@@ -280,7 +281,8 @@ class Linearizer:
     if self.group_for_reduce:
       fake_global_idxs = [x*0 for x in global_idxs]
       self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs, acc, ssa)  # store accumulators
-      self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local")) # this is a barrier on GPUs
+      self.uop(UOps.BARRIER, None, [], ())
+      self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
 
       # local indexs are over, 0 them out
       local_idxs = [x*0 for x in local_idxs]
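Note: the "this is a barrier on GPUs" comment becomes an explicit BARRIER uop between the accumulator store and the local ENDLOOP. Since ENDLOOP "local" renders the workitem-0 guard, emitting the barrier first ensures the temp buffer is fully written before that one workitem reads it back. A sketch of the emission order, with a stand-in recorder instead of the real self.uop:

trace = []
def uop(kind): trace.append(kind)

uop("STORE")    # accumulators land in the temp local buffer
uop("BARRIER")  # all workitems must finish storing first
uop("ENDLOOP")  # the "local" end renders the workitem-0 guard in C-style code
assert trace == ["STORE", "BARRIER", "ENDLOOP"]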
@@ -320,11 +322,11 @@ class Linearizer:
       self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs, val, ssa)
 
     if not self.group_for_reduce:
-      # end the local loop
-      self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))
-
-      # end the global loop
-      self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
+      # end the global+local loop
+      self.uop(UOps.ENDLOOP, None, [], (global_idxs+local_idxs, "global+local"))
+    else:
+      # end the global loop
+      self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))
 
   _OT = TypeVar("_OT")
   def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
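Note: when there is no grouped reduce, the local loop was previously closed on its own and the global loop after it; ending them with one combined ENDLOOP lets the C-style backend emit a single run of closing braces instead of a lone local ENDLOOP and its useless guard. A control-flow sketch with stand-in state:

def end_loops(group_for_reduce):
  # one combined end while the local loop is still open; otherwise the local
  # loop was already closed during the grouped reduce, so only global remains
  return ["global+local"] if not group_for_reduce else ["global"]

assert end_loops([]) == ["global+local"]
assert end_loops([16]) == ["global"]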
@@ -90,9 +90,8 @@ class LazyBuffer:
     # TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
     self.children: LightWeakSet = LightWeakSet()
     # NOTE: op should be read only after construction of LazyBuffer
-    if op:
-      self.op: LazyOp = op
-      for x in op.buffers: x.children.add(self)
+    self.op: LazyOp = op
+    for x in op.buffers: x.children.add(self)
     if not LAZY: self.realize()
 
     # log phantom ops to the graph
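Note: dropping the `if op:` guard encodes the invariant from the commit message, that op isn't optional, so every LazyBuffer unconditionally registers itself as a child of its sources. A self-contained sketch with simplified stand-ins (a plain set instead of LightWeakSet, minimal LazyOp):

class LazyOp:
  def __init__(self, buffers=()): self.buffers = buffers

class LazyBuffer:
  def __init__(self, op):  # op is required: no Optional, no guard
    self.children = set()  # stand-in for LightWeakSet
    self.op = op           # read only after construction, by convention
    for x in op.buffers: x.children.add(self)

parent = LazyBuffer(LazyOp())
child = LazyBuffer(LazyOp(buffers=(parent,)))
assert child in parent.children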
@@ -33,7 +33,7 @@ class LazyOp:
     self.op = op
     self.src = src
     self.arg = arg
-    # TODO: this hasattr is required because the key function maps the buffers to ints
+    # TODO: this hasattr is required because the linearizer's key function maps the buffers to ints
     self.buffers = functools.reduce(lambda x,s: (x+s.buffers) if hasattr(s, 'buffers') else x, src, tuple())
 
   def __repr__(self): return f"LazyOp(op={self.op}, src={self.src}, arg={self.arg})"
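Note: the reworded TODO points at the linearizer's key function, which replaces buffers in src with ints; the hasattr check is what lets the reduce skip those. A runnable sketch of that flattening, where Src and the bare int stand in for real srcs:

import functools

class Src:
  def __init__(self, buffers): self.buffers = buffers

src = (Src(("a",)), 3, Src(("b", "c")))  # 3 plays the role of a key-mapped buffer
buffers = functools.reduce(lambda x,s: (x+s.buffers) if hasattr(s, 'buffers') else x, src, tuple())
assert buffers == ("a", "b", "c")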