Mirror of https://github.com/tinygrad/tinygrad.git
remove lines (unused code) (#2319)
* remove lines
* uhh, i'm tired
* that function never worked
* types for ast_parse
@@ -143,7 +143,7 @@ class Kernel:
        [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

  # TODO: is there a better way to write this?
-  def acc_offsets(self, i):
+  def acc_offsets(self, i) -> List[int]:
    if self.upcasted == 0: return [0]
    upcasted_i = self.upcasted_axis(i)
    acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
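For readers unfamiliar with the stride arithmetic that acc_offsets builds on, here is a minimal standalone sketch of row-major stride computation and of zeroing a stride so every step along a reduce axis lands on the same accumulator slot. This is an illustrative toy, not tinygrad's strides_for_shape, which may handle extra cases (e.g. size-1 axes).

```python
from typing import List, Tuple

def contiguous_strides(shape: Tuple[int, ...]) -> List[int]:
  # row-major strides: each axis's stride is the product of all faster-varying sizes
  strides, acc = [], 1
  for s in reversed(shape):
    strides.insert(0, acc)
    acc *= s
  return strides

shape = (4, 2, 3)
print(contiguous_strides(shape))  # [6, 3, 1]
# zero a reduce axis's stride so every iteration along it hits the same accumulator offset
print([0 if i == 0 else st for i, st in enumerate(contiguous_strides(shape))])  # [0, 3, 1]
```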
@@ -279,6 +279,7 @@ class Kernel:
    for i,x in enumerate(rets[:len(self.sts)]): self.sts[i] = self.sts[i].reshape(tuple([y[0] for y in x]))

  # ******************** GPU simplifiers ********************

  def _limit_size(self, x: Tuple[int], max_size: List) -> Tuple[int, ...]:
    new_shape,dims = list(x), len(x)
    for i in range(dims):
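Only the first lines of _limit_size are visible in this hunk. As a rough illustration of what a shape limiter with this signature can do, here is a hedged standalone sketch that halves an oversized dimension and folds the factor into a neighbour; the real method's policy may differ.

```python
from typing import List, Tuple

def limit_size_sketch(x: Tuple[int, ...], max_size: List[int]) -> Tuple[int, ...]:
  # illustrative guess at the general idea, not tinygrad's actual body
  new_shape, dims = list(x), len(x)
  for i in range(dims):
    while new_shape[i] > max_size[i]:
      assert new_shape[i] % 2 == 0, "this sketch only handles even splits"
      new_shape[i] //= 2                # halve the oversized dimension...
      new_shape[(i + 1) % dims] *= 2    # ...and fold the factor into a neighbour
  return tuple(new_shape)

print(limit_size_sketch((2048, 4), [1024, 1024]))  # (1024, 8)
```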
@@ -223,7 +223,7 @@ class Linearizer(Kernel):

    # parse AST
    loaded_buffers = {}
-    acc = []
+    acc: List[UOp] = []
    self.load_cache: Dict[str, UOp] = {}

    # reduce op
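The acc change is purely a typing fix: a bare empty list gives the type checker no element type to verify later appends against. A tiny standalone illustration (ToyUOp is a stand-in, not tinygrad's UOp):

```python
from typing import List

class ToyUOp: ...           # stand-in for tinygrad's UOp

acc: List[ToyUOp] = []      # annotated: mypy knows the element type up front
acc.append(ToyUOp())        # ok
# acc.append(42)            # mypy would flag this; a bare `acc = []` would not catch it
```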
@@ -349,7 +349,7 @@ class Linearizer(Kernel):
      loaded_buffers[self.bufs[-1]] = self.global_load(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, barrier=barrier)

      # there's no AST here (and there's no shape for the reduce LazyOp)
-      self.ast_parse(LazyOp(self.reduceop.op, (self.bufs[-1],)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx) # type: ignore
+      self.ast_parse(LazyOp(self.reduceop.op, (LazyOp(BufferOps.MEM, (), self.bufs[-1]),)), acc, self.acc_offsets(-1), loaded_buffers, do_reduce=True, loop_ctx=loop_ctx)

      # end the late reduce loop
      self.load_cache.clear()
@@ -471,11 +471,10 @@ class Linearizer(Kernel):
    if cachable: self.saved_exprs[key] = ret
    return ret

-  def ast_parse(self, x, acc, offs, loaded_buffers, do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
-    if x.__class__ is not LazyOp: return loaded_buffers[x] # for LOCAL_BUFFER
+  def ast_parse(self, x:LazyOp, acc: List[UOp], offs:Optional[List[int]], loaded_buffers:Dict[Union[MemBuffer, ConstBuffer, LocalBuffer], List[UOp]], do_reduce=False, loop_ctx=tuple()) -> List[UOp]:
    if x.op in BufferOps: return loaded_buffers[x.arg]
-    if x.op == UnaryOps.NOOP: return self.ast_parse(x.src[0], acc, offs, loaded_buffers)
-    if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(x.src[0], acc, offs, loaded_buffers)]
+    if x.op == UnaryOps.NOOP: return self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)
+    if x.op == UnaryOps.CAST: return [self.uop(UOps.CAST, x.arg[0], (u,), x.arg) if not isinstance(x.arg[0], ImageDType) else u for u in self.ast_parse(cast(LazyOp, x.src[0]), acc, offs, loaded_buffers)]
    if x.op in ReduceOps and not do_reduce:
      assert offs is None, "not available if we aren't doing reduce"
      return acc
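To make the dispatch in ast_parse easier to follow, here is a self-contained toy model of the recursion: buffer leaves return their preloaded values, NOOP forwards to its single source, and other ops recurse into all sources. ToyOp and the op names are simplified stand-ins, not tinygrad's real LazyOp/UOp types.

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple

@dataclass(frozen=True)
class ToyOp:
  op: str
  src: Tuple["ToyOp", ...] = ()
  arg: Any = None

def parse(x: ToyOp, loaded: Dict[Any, list]) -> list:
  if x.op == "MEM": return loaded[x.arg]        # buffer leaf: values were loaded up front
  if x.op == "NOOP": return parse(x.src[0], loaded)
  values = [parse(v, loaded) for v in x.src]    # recurse into every source
  if x.op == "ADD": return [a + b for a, b in zip(*values)]
  raise NotImplementedError(x.op)

loaded = {"buf0": [1.0, 2.0], "buf1": [10.0, 20.0]}
tree = ToyOp("ADD", (ToyOp("MEM", arg="buf0"), ToyOp("NOOP", (ToyOp("MEM", arg="buf1"),))))
print(parse(tree, loaded))  # [11.0, 22.0]
```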
@@ -484,12 +483,12 @@ class Linearizer(Kernel):
      x = LazyOp(TernaryOps.MULACC, x.src[0].src, x.arg)
    if x.op == ReduceOps.SUM and x.src[0].__class__ is LazyOp and x.src[0].op == UnaryOps.CAST and x.src[0].src[0].__class__ is LazyOp and x.src[0].src[0].op == BinaryOps.MUL:
      x = LazyOp(TernaryOps.MULACC, x.src[0].src[0].src, x.arg)
-    values = [self.ast_parse(v, acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
+    values = [self.ast_parse(cast(LazyOp, v), acc, offs, loaded_buffers, loop_ctx=loop_ctx) for v in x.src]
    ops = {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, TernaryOps.MULACC:TernaryOps.MULACC}
    if x.op in ops:
      ret = []
      input_acc = acc[:]
-      for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), offs):
+      for idx, val, off in zip([[i] for i in range(len(values[0]))], zip(*values), cast(List[int], offs)):
        acc[off] = self.uop(UOps.ALU, dtypes.float32, val+(acc[off],), ops[x.op])
        ret.append((idx, acc[off]))
      for off in range(len(acc)):
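The first context lines here are the MULACC fusion this hunk sits inside: a SUM whose source is a MUL (possibly through a CAST) is rewritten into one fused multiply-accumulate node before parsing. A hedged standalone sketch of that rewrite on a toy node type, not tinygrad's real classes:

```python
from dataclasses import dataclass
from typing import Any, Tuple

@dataclass(frozen=True)
class ToyOp:
  op: str
  src: Tuple["ToyOp", ...] = ()
  arg: Any = None

def fuse_mulacc(x: ToyOp) -> ToyOp:
  if x.op == "SUM" and x.src[0].op == "MUL":
    return ToyOp("MULACC", x.src[0].src, x.arg)            # SUM(MUL(a,b)) -> MULACC(a,b)
  if x.op == "SUM" and x.src[0].op == "CAST" and x.src[0].src[0].op == "MUL":
    return ToyOp("MULACC", x.src[0].src[0].src, x.arg)     # SUM(CAST(MUL(a,b))) -> MULACC(a,b)
  return x

a, b = ToyOp("MEM", arg="a"), ToyOp("MEM", arg="b")
print(fuse_mulacc(ToyOp("SUM", (ToyOp("MUL", (a, b)),))).op)  # MULACC
```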
@@ -35,12 +35,6 @@ class Node:
  def substitute(self, var_vals: Dict[VariableOrNum, Node]) -> Node: raise RuntimeError(self.__class__.__name__)
  def unbind(self) -> Tuple[Node, Optional[int]]: return self.substitute({v: v.unbind()[0] for v in self.vars() if v.val is not None}), None

-  @property
-  def val(self):
-    ret = self.substitute({x:NumNode(x.val) for x in self.vars()})
-    assert isinstance(ret, NumNode), f"val must be NumNode, it's {ret}"
-    return ret.b
-
  @functools.cached_property
  def key(self) -> str: return self.render(ctx="DEBUG")
  @functools.cached_property
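The removed val property evaluated a symbolic expression by substituting every variable with a numeric node. Here is a toy model of that substitute-and-evaluate idea, with simplified stand-ins for Variable/NumNode rather than tinygrad's real symbolic classes:

```python
from typing import Dict

class Num:
  def __init__(self, b: int): self.b = b
  def substitute(self, env: Dict[str, "Num"]) -> "Num": return self

class Var:
  def __init__(self, name: str): self.name = name
  def substitute(self, env: Dict[str, Num]) -> Num: return env[self.name]  # swap in a constant

class Add:
  def __init__(self, l, r): self.l, self.r = l, r
  def substitute(self, env: Dict[str, Num]) -> Num:
    # once all variables are numbers, the expression collapses to a single value
    return Num(self.l.substitute(env).b + self.r.substitute(env).b)

expr = Add(Var("i"), Num(3))
print(expr.substitute({"i": Num(4)}).b)  # 7
```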
@@ -125,12 +125,6 @@ class Tensor:
    assert self.dtype.np is not None, f"no numpy dtype for {self.dtype}"
    return self.detach().cast(dtypes.from_np(self.dtype.np)).contiguous().to('CPU').realize().lazydata.realized.toCPU().reshape(self.shape)

-  # TODO: if things are realized this won't work
-  def to_(self, device:str):
-    assert self.lazydata.realized is None
-    self.lazydata.device = device
-    if self.grad: self.grad.to_(device)
-
  def to(self, device:str) -> Tensor:
    ret = Tensor(self.lazydata, device)
    if self.grad: ret.grad = self.grad.to(device)
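With the in-place to_ gone, device movement goes through to, which builds a new Tensor on the target device and moves the gradient along with it. A hedged usage sketch (device availability depends on your build; 'CPU' is the device the numpy() path above already relies on):

```python
from tinygrad.tensor import Tensor

t = Tensor([[1.0, 2.0], [3.0, 4.0]])
t_cpu = t.to("CPU")     # returns a new Tensor bound to the target device; t itself is untouched
print(t_cpu.numpy())    # realizes the data and copies it back as a numpy array
```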
@@ -223,10 +217,8 @@ class Tensor:
    cdf = p.cumsum(1)
    cdf /= cdf[:, -1].unsqueeze(1)
    unif_samples = Tensor.rand(num_samples, p.shape[0], 1)
-    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2)
-    indices = indices.permute((1, 0))
-    if self.ndim == 1:
-      indices = indices.squeeze(0)
+    indices = (unif_samples.expand((-1, -1, p.shape[1])) >= cdf).sum(2).permute((1, 0))
+    if self.ndim == 1: indices = indices.squeeze(0)
    return indices.cast(dtypes.int32)

  # ***** toposort and backward pass *****
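These multinomial lines implement inverse-CDF sampling: build the cumulative distribution, draw uniforms, and count how many CDF entries each uniform meets or exceeds; that count is the sampled category index. A standalone numpy sketch of the same trick (numpy is used here only for illustration):

```python
import numpy as np

def multinomial_sketch(p: np.ndarray, num_samples: int) -> np.ndarray:
  cdf = np.cumsum(p, axis=1)
  cdf = cdf / cdf[:, -1:]                      # normalize so the last entry of each row is 1
  unif = np.random.rand(num_samples, p.shape[0], 1)
  indices = (unif >= cdf[None]).sum(axis=2)    # number of CDF entries each uniform clears
  return indices.transpose(1, 0)               # (batch, num_samples)

print(multinomial_sketch(np.array([[0.1, 0.2, 0.7]]), 5))
```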
@@ -794,9 +786,7 @@ class Tensor:
  def is_floating_point(self) -> bool: return dtypes.is_float(self.dtype)

# register functions to move between devices
-for device in Device._buffers:
-  setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))
-  setattr(Tensor, f"{device.lower()}_", partialmethod(Tensor.to_, device))
+for device in Device._buffers: setattr(Tensor, f"{device.lower()}", partialmethod(Tensor.to, device))

if IMAGE:
  # if IMAGE>0 we install these replacement functions in Tensor (hack!)
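The surviving one-liner relies on functools.partialmethod to install per-device shortcut methods. A standalone sketch of the pattern on a toy class (Thing and its to method are made up for illustration):

```python
from functools import partialmethod

class Thing:
  def to(self, device: str) -> str:
    return f"moved to {device}"

# install .cpu() / .gpu() shortcuts that just call .to("CPU") / .to("GPU")
for device in ["CPU", "GPU"]:
  setattr(Thing, device.lower(), partialmethod(Thing.to, device))

print(Thing().cpu())  # moved to CPU
print(Thing().gpu())  # moved to GPU
```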