from tensor cores + lb touchup (#1127)

This commit is contained in:
George Hotz
2023-07-04 15:45:20 -07:00
committed by GitHub
parent 2f968f8547
commit 793a670187
5 changed files with 29 additions and 8 deletions

View File

@@ -17,8 +17,10 @@ class UOps(Enum): LOOP = auto(); DEFINE_LOCAL = auto(); LOAD = auto(); ALU = aut
class LocalBuffer(NamedTuple):
    """Immutable record naming a local buffer and giving its size.

    ``str()`` of an instance renders as ``localbuffer<NAME[SIZE]>``.
    """

    # Identifier of the buffer.
    name: str
    # Size of the buffer.
    size: int
    # Element type; float32 unless the caller says otherwise.
    dtype: DType = dtypes.float32
    # Always None by default, so code that probes `.realized` on a buffer
    # (as printbufs does) finds the attribute here as well.
    realized: None = None

    def __str__(self):
        return f"localbuffer<{self.name}[{self.size}]>"
class Token(NamedTuple):
name: str
@@ -218,9 +220,9 @@ class Linearizer:
# add a local buffer for multistage reduce
if len(self.group_for_reduce):
self.bufs.append(LocalBuffer("temp"))
# TODO: the strides of this can be controlled
self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
self.bufs.append(LocalBuffer("temp", self.sts[-1].size()))
self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))
# print
@@ -401,10 +403,11 @@ class Linearizer:
assert len(colors) == self.shape_len, "colors size mismatch"
return colors
def colored_shape(self) -> str:
    """Render ``self.full_shape`` as one space-separated string.

    Each dimension is formatted as a width-4 integer and passed through
    ``colored`` together with the matching entry of ``self.colors()``.
    """
    cells = []
    for s, color in zip(self.full_shape, self.colors()):
        cells.append(colored(f"{s:4d}", color))
    return ' '.join(cells)
def printbufs(self, prefix=""):
    """Print one row per shape tracker, followed by the colored shape.

    Each row holds ``prefix``, the buffer index, a 47-column label and the
    tracker's ``views``. The label is ``str(buf.realized)`` when the buffer
    has a non-None ``realized`` attribute and ``str(buf)`` otherwise.

    Args:
        prefix: text printed at the start of every buffer row.
    """
    for i, st in enumerate(self.sts):
        buf = self.bufs[i]
        # The None test is on `.realized`, not on the buffer object itself:
        # a buffer with nothing realized is shown via its own __str__.
        label = str(buf.realized) if buf.realized is not None else str(buf)
        print(prefix, f"{i:3d} {label:47s}", st.views)
    # Emit the shape line once, through the shared helper.
    print(self.colored_shape())
# ******************** base simplifiers ********************

View File

@@ -164,10 +164,10 @@ class ShapeTracker:
ret[idxs.index(this_dim.a)] = this_dim.b
elif isinstance(this_dim, Variable):
ret[idxs.index(this_dim)] = 1
render_idx, render_valid = idx.render(), valid.render()
for i in range(len(self.shape)):
if f'idx{i}' in render_valid and not ignore_valid: ret[i] = None
elif f'idx{i}' not in render_idx: ret[i] = 0
idx_vars, valid_vars = idx.vars(), valid.vars()
for i,tidx in enumerate(idxs):
if tidx in valid_vars and not ignore_valid: ret[i] = None
elif tidx not in idx_vars: ret[i] = 0
return tuple(ret)
def unit_stride_axes(self, ignore_valid=False) -> List[int]:
    """Return the positions in ``real_strides`` whose value equals 1.

    Args:
        ignore_valid: forwarded unchanged to ``self.real_strides``.
    """
    axes = []
    for axis, stride in enumerate(self.real_strides(ignore_valid)):
        if stride == 1:
            axes.append(axis)
    return axes

View File

@@ -16,6 +16,7 @@ class Node:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)
def vars(self):
    """Return the variables of this node; the base implementation has none.

    A fresh empty list is returned on every call.
    """
    return list()
@functools.cached_property
def key(self) -> str:
    """Result of ``self.render(ctx="DEBUG")``, cached after first access."""
    rendered = self.render(ctx="DEBUG")
    return rendered
@functools.cached_property
@@ -124,6 +125,7 @@ class Variable(Node):
def __init__(self, expr:Optional[str], nmin:int, nmax:int):
    """Store the expression and the two bounds on the instance.

    Args:
        expr: stored as ``self.expr``; may be None.
        nmin: stored as ``self.min``.
        nmax: stored as ``self.max``.
    """
    self.expr = expr
    self.min = nmin
    self.max = nmax
def vars(self):
    """Return a one-element list containing this object itself."""
    me = self
    return [me]
class NumNode(Node):
def __init__(self, num:int):
@@ -138,6 +140,7 @@ class OpNode(Node):
def __init__(self, a:Node, b:int):
    """Store both operands, then derive ``min``/``max`` from ``get_bounds``.

    Args:
        a: stored as ``self.a``.
        b: stored as ``self.b``.
    """
    self.a = a
    self.b = b
    # get_bounds() runs only after a and b are stored on the instance.
    bounds = self.get_bounds()
    self.min, self.max = bounds
def vars(self):
    """Return whatever ``self.a.vars()`` returns."""
    inner = self.a
    return inner.vars()
@abstractmethod
def get_bounds(self) -> Tuple[int, int]:
    """Return a pair of ints; no implementation is provided here."""
@@ -174,6 +177,7 @@ class ModNode(OpNode):
class RedNode(Node):
    """Node that holds a list of child nodes."""

    def __init__(self, nodes:List[Node]):
        """Store the child nodes as ``self.nodes``."""
        self.nodes = nodes

    def vars(self):
        """Return every child's ``vars()`` joined into one list, in child order.

        Duplicates are kept; an empty ``nodes`` list yields ``[]``.
        """
        # One flat pass, rather than re-copying the accumulated list once
        # per child as repeated `+` concatenation would.
        return [v for node in self.nodes for v in node.vars()]
class SumNode(RedNode):
def __mul__(self, b: int):
    """Multiply every entry of ``self.nodes`` by ``b`` and re-sum via ``Node.sum``."""
    # distribute mul into sum
    scaled = [x*b for x in self.nodes]
    return Node.sum(scaled)