remove lines (unused code) (#2319)

* remove lines

* uhh, i'm tired

* that function never worked

* types for ast_parse
Author: George Hotz (committed by GitHub)
Date: 2023-11-15 14:36:11 -08:00
parent 628365eab6, commit 294e71de15
4 changed files with 12 additions and 28 deletions


@@ -143,7 +143,7 @@ class Kernel:
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
# TODO: is there a better way to write this?
- def acc_offsets(self, i):
+ def acc_offsets(self, i) -> List[int]:
if self.upcasted == 0: return [0]
upcasted_i = self.upcasted_axis(i)
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
@@ -279,6 +279,7 @@ class Kernel:
for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))
# ******************** GPU simplifiers ********************
def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
new_shape,dims = list(x), len(x)
for i in range(dims):
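
Note on the new -> List[int] annotation: acc_offsets derives its offsets from the stride arithmetic of strides_for_shape over the upcasted axes. Below is a minimal standalone sketch of contiguous row-major stride computation (the helper name is a stand-in and it ignores any canonicalization the real helper may do; this is not code from this commit):

# Rough sketch of row-major stride computation, in the spirit of the
# strides_for_shape call used by acc_offsets above. Illustration only.
from typing import List, Tuple

def strides_for_shape_sketch(shape: Tuple[int, ...]) -> List[int]:
  strides: List[int] = []
  acc = 1
  for s in reversed(shape):    # walk dims right-to-left
    strides.insert(0, acc)     # innermost dim gets stride 1
    acc *= s
  return strides

assert strides_for_shape_sketch((2, 3, 4)) == [12, 4, 1]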


@@ -223,7 +223,7 @@ class Linearizer(Kernel):
# parse AST
loaded_buffers = {}
- acc = []
+ acc: List[UOp] = []
self.load_cache: Dict[str, UOp] = {}
# reduce op
@@ -349,7 +349,7 @@ class Linearizer(Kernel):
loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)
# there's no AST here (and there's no shape for the reduce LazyOp)
- self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
+ self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.MEM, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)
# end the late reduce loop
self.load_cache.clear()
@@ -471,11 +471,10 @@ class Linearizer(Kernel):
if cachable: self.saved_exprs[key] = ret
return ret
- def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
- if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
+ def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
if x.op in BufferOps: return loaded_buffers[x.arg]
- if x.op == UnaryOps.NOOP: return self.ast_parse(x.src[0], acc, offs, loaded_buffers)
- if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+ if x.op == UnaryOps.NOOP: return self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)
+ if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)]
if x.op in ReduceOps and not do_reduce:
assert offs is None, "not available if we aren't doing reduce"
return acc
@@ -484,12 +483,12 @@ class Linearizer(Kernel):
x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
- values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
+ values = [self.ast_parse(cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
if x.op in ops:
ret = []
input_acc = acc[:]
- for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
+ for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), cast(List[int], offs)):
acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
ret.append((idx, acc[off]))
for off in range(len(acc)):
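
The gist of the ast_parse typing change above: x is now always a LazyOp, so the untyped LocalBuffer escape hatch goes away, and the recursive calls need cast() because x.src can hold either sub-ops or buffer descriptors. A self-contained toy version of that typed-recursive-walk pattern (stand-in names, not tinygrad's real classes):

# Toy sketch of a typed recursive walk over a LazyOp-like tree, showing why an
# annotated ast_parse needs cast() on x.src elements (sources are a union type).
# Stand-in names below; these are not tinygrad's real classes.
from dataclasses import dataclass
from typing import Tuple, Union, cast

@dataclass(frozen=True)
class ToyOp:
  op: str
  src: Tuple[Union["ToyOp", int], ...]  # a source is either a sub-op or a leaf value

def walk(x: ToyOp) -> int:
  if x.op == "CONST": return cast(int, x.src[0])            # leaf: read the value
  if x.op == "ADD": return sum(walk(cast(ToyOp, s)) for s in x.src)
  if x.op == "MUL":
    out = 1
    for s in x.src: out *= walk(cast(ToyOp, s))
    return out
  raise NotImplementedError(x.op)

tree = ToyOp("ADD", (ToyOp("CONST", (2,)), ToyOp("MUL", (ToyOp("CONST", (3,)), ToyOp("CONST", (4,))))))
assert walk(tree) == 14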


@@ -35,12 +35,6 @@ class Node:
def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None
- @property
- def val(self):
- ret = self.substitute({x:NumNode(x.val) for x in self.vars()})
- assert isinstance(ret, NumNode), f"val must be NumNode, it's {ret}"
- return ret.b
@functools.cached_property
def key(self) -> str: return self.render(ctx="DEBUG")
@functools.cached_property
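
For reference, the deleted val property substituted a concrete NumNode for every variable and then read the value off the resulting constant node. A toy sketch of that substitute-then-evaluate idea (hypothetical classes, not tinygrad's symbolic Node/NumNode):

# Toy sketch of the removed Node.val idea: substitute concrete numbers for all
# variables, then read the value off the resulting constant node.
# Hypothetical classes for illustration only.
from typing import Dict

class Expr:
  def substitute(self, env: Dict[str, int]) -> "Expr": raise NotImplementedError

class Num(Expr):
  def __init__(self, b: int): self.b = b
  def substitute(self, env): return self

class Var(Expr):
  def __init__(self, name: str): self.name = name
  def substitute(self, env): return Num(env[self.name])

class Add(Expr):
  def __init__(self, a: Expr, b: Expr): self.a, self.b = a, b
  def substitute(self, env):
    sa, sb = self.a.substitute(env), self.b.substitute(env)
    return Num(sa.b + sb.b) if isinstance(sa, Num) and isinstance(sb, Num) else Add(sa, sb)

def val(e: Expr, env: Dict[str, int]) -> int:
  ret = e.substitute(env)
  assert isinstance(ret, Num), f"val must be Num, it's {ret}"
  return ret.b

assert val(Add(Var("x"), Add(Var("y"), Var("y"))), {"x": 1, "y": 3}) == 7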


@@ -125,12 +125,6 @@ class Tensor:
assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)
- # TODO: if things are realized this won't work
- def to_(self, device:str):
- assert self.lazydata.realized is None
- self.lazydata.device = device
- if self.grad: self.grad.to_(device)
def to(self, device:str) -> Tensor:
ret = Tensor(self.lazydata, device)
if self.grad: ret.grad = self.grad.to(device)
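
Context for the deletion above: to_ mutated lazydata.device in place and, as the TODO notes, couldn't work once data was realized; the surviving to() builds a fresh Tensor instead. A toy sketch of that return-a-new-object pattern (made-up classes, not the real Tensor):

# Toy sketch of the functional pattern that survives this commit: moving devices
# returns a new object and the caller rebinds, rather than mutating in place.
# Made-up classes for illustration; not tinygrad's Tensor.
class FakeLazyData:
  def __init__(self, data, device: str): self.data, self.device = data, device

class FakeTensor:
  def __init__(self, lazydata: FakeLazyData): self.lazydata = lazydata
  def to(self, device: str) -> "FakeTensor":
    return FakeTensor(FakeLazyData(self.lazydata.data, device))

t = FakeTensor(FakeLazyData([1, 2, 3], "GPU"))
t = t.to("CPU")            # rebind; no in-place .to_() anymore
assert t.lazydata.device == "CPU"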
@@ -223,10 +217,8 @@ class Tensor:
cdf = p.cumsum(1)
cdf /= cdf[:, -1].unsqueeze(1)
unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
- indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2)
- indices = indices.permute((1, 0))
- if self.ndim == 1:
- indices = indices.squeeze(0)
+ indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2).permute((1, 0))
+ if self.ndim == 1: indices = indices.squeeze(0)
return indices.cast(dtypes.int32)
# ***** toposort and backward pass *****
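
The fused line above is the usual inverse-CDF trick for multinomial sampling: compare uniform samples against the normalized cumulative distribution and count how many bins each sample clears. A NumPy sketch of the same idea for a single 1-D probability vector (illustration only, not tinygrad code):

# NumPy sketch of inverse-CDF multinomial sampling for one 1-D probability
# vector, mirroring the cumsum / normalize / compare-and-sum steps above.
import numpy as np

def multinomial_sketch(p: np.ndarray, num_samples: int) -> np.ndarray:
  cdf = np.cumsum(p)
  cdf = cdf / cdf[-1]                  # normalize so the last entry is 1.0
  u = np.random.rand(num_samples)      # uniform samples in [0, 1)
  # a sample's index is the number of cdf entries it meets or exceeds
  return (u[:, None] >= cdf[None, :]).sum(axis=1)

counts = np.bincount(multinomial_sketch(np.array([0.1, 0.2, 0.7]), 10_000), minlength=3)
# expect counts roughly proportional to [0.1, 0.2, 0.7]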
@@ -794,9 +786,7 @@ class Tensor:
def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)
# register functions to move between devices
- for device in Device._buffers:
- setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
- setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
+ for device in Device._buffers: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
if IMAGE:
# if IMAGE>0 we install these replacement functions in Tensor (hack!)
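
The collapsed registration loop relies on functools.partialmethod: each device name becomes a bound shortcut for Tensor.to, and the to_ shortcuts disappear along with to_ itself. A minimal sketch of how that wiring behaves on a toy class (made-up names, not the real Tensor or Device):

# Minimal sketch of the partialmethod registration pattern used above,
# on a made-up class. Illustration only.
from functools import partialmethod

class FakeTensor:
  def __init__(self, device: str = "CPU"): self.device = device
  def to(self, device: str) -> "FakeTensor": return FakeTensor(device)

for device in ("CPU", "GPU", "METAL"):
  setattr(FakeTensor, device.lower(), partialmethod(FakeTensor.to, device))

assert FakeTensor().gpu().device == "GPU"   # t.gpu() behaves like t.to("GPU")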