mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
from tensor cores + lb touchup (#1127)
This commit is contained in:
@@ -17,8 +17,10 @@ class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = aut
|
||||
|
||||
class LocalBuffer(NamedTuple):
  """A local buffer described by its name, size and dtype."""
  name: str
  size: int
  dtype: DType = dtypes.float32
  # defaults to None; printbufs reads `.realized` and falls back to str(buffer) when it is None
  realized: None = None

  def __str__(self):
    # rendered as localbuffer<name[size]>
    label = f"{self.name}[{self.size}]"
    return f"localbuffer<{label}>"
|
||||
|
||||
class Token(NamedTuple):
|
||||
name: str
|
||||
@@ -218,9 +220,9 @@ class Linearizer:
|
||||
|
||||
# add a local buffer for multistage reduce
|
||||
if len(self.group_for_reduce):
|
||||
self.bufs.append(LocalBuffer("temp"))
|
||||
# TODO: the strides of this can be controlled
|
||||
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
|
||||
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
|
||||
self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
|
||||
|
||||
# print
|
||||
@@ -401,10 +403,11 @@ class Linearizer:
|
||||
assert len(colors) == self.shape_len, "colors size mismatch"
|
||||
return colors
|
||||
|
||||
def colored_shape(self) -> str:
  """Return the full shape as space-joined 4-wide integer fields, each passed through `colored` with its axis color."""
  fields = []
  for dim, color in zip(self.full_shape, self.colors()):
    fields.append(colored(f"{dim:4d}", color))
  return ' '.join(fields)
|
||||
def printbufs(self, prefix=""):
  """Print one line per shape tracker, then the colored shape.

  Each line shows the index, the buffer's `.realized` value (or the buffer's
  own str() when `.realized` is None) padded to 47 characters, and the
  tracker's views.

  prefix: printed as the first item of every per-buffer line.
  """
  for i, st in enumerate(self.sts):
    buf = self.bufs[i]
    # buffers with no realized value (LocalBuffer defaults it to None) are shown via their own __str__
    desc = str(buf.realized) if buf.realized is not None else str(buf)
    print(prefix, f"{i:3d} {desc:47s}", st.views)
  print(self.colored_shape())
|
||||
|
||||
# ******************** base simplifiers ********************
|
||||
|
||||
|
||||
@@ -164,10 +164,10 @@ class ShapeTracker:
|
||||
ret[idxs.index(this_dim.a)] = this_dim.b
|
||||
elif isinstance(this_dim, Variable):
|
||||
ret[idxs.index(this_dim)] = 1
|
||||
render_idx, render_valid = idx.render(), valid.render()
|
||||
for i in range(len(self.shape)):
|
||||
if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
|
||||
elif f'idx{i}' not in render_idx: ret[i] = 0
|
||||
idx_vars, valid_vars = idx.vars(), valid.vars()
|
||||
for i,tidx in enumerate(idxs):
|
||||
if tidx in valid_vars and not ignore_valid: ret[i] = None
|
||||
elif tidx not in idx_vars: ret[i] = 0
|
||||
return tuple(ret)
|
||||
def unit_stride_axes(self, ignore_valid=False) -> List[int]:
  """Return the indices of the axes whose entry in real_strides(ignore_valid) equals 1."""
  strides = self.real_strides(ignore_valid)
  axes = []
  for axis, stride in enumerate(strides):
    if stride == 1:
      axes.append(axis)
  return axes
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ class Node:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
def vars(self):
  """Default: report no variables (a fresh empty list); Variable, OpNode and RedNode override this."""
  return list()
|
||||
@functools.cached_property
def key(self) -> str:
  """String key for this node: its rendering with ctx="DEBUG", computed once per instance."""
  rendered = self.render(ctx="DEBUG")
  return rendered
|
||||
@functools.cached_property
|
||||
@@ -124,6 +125,7 @@ class Variable(Node):
|
||||
|
||||
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
  """Store the variable's expression and its min/max bounds."""
  self.expr = expr
  self.min = nmin
  self.max = nmax
|
||||
def vars(self):
  """A variable reports itself as its only variable."""
  own = [self]
  return own
|
||||
|
||||
class NumNode(Node):
|
||||
def __init__(self, num:int):
|
||||
@@ -138,6 +140,7 @@ class OpNode(Node):
|
||||
def __init__(self, a:Node, b:int):
  """Store both operands, then set min/max from get_bounds()."""
  self.a = a
  self.b = b
  # get_bounds() runs only after a and b are stored
  bounds = self.get_bounds()
  self.min, self.max = bounds
|
||||
def vars(self):
  """Return the variables of operand `a`; operand `b` is not consulted."""
  operand = self.a
  return operand.vars()
|
||||
@abstractmethod
def get_bounds(self) -> Tuple[int, int]:
  """Return the (min, max) pair that __init__ unpacks into self.min and self.max; subclasses implement it."""
|
||||
|
||||
@@ -174,6 +177,7 @@ class ModNode(OpNode):
|
||||
|
||||
class RedNode(Node):
  """Node that holds a list of child nodes."""
  def __init__(self, nodes:List[Node]):
    self.nodes = nodes

  def vars(self):
    """Concatenate the vars() lists of every child node, in order; empty when there are no children."""
    collected = []
    for child in self.nodes:
      # rebinding with + (not extend) keeps each step a new list, as in a fold
      collected = collected + child.vars()
    return collected
|
||||
|
||||
class SumNode(RedNode):
|
||||
def __mul__(self, b: int): return Node.sum([x*b for x in self.nodes]) # distribute mul into sum
|
||||
|
||||
Reference in New Issue
Block a user