reduce number of lines (#645)

commit c10131ddf5
parent 7989f79820
Author: Cyril Roumégous
Date: 2023-03-06 00:42:32 +01:00
Committed by: GitHub

16 changed files with 40 additions and 81 deletions
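
The bulk of the hunks below are mechanical: bodies of one-statement if blocks move onto the condition's line, and runs of related assignments become tuple assignments. A minimal before/after sketch of the pattern in Python (variable names are illustrative, not taken from the diff):

# before: four lines
buf_count = -1
op_count = -1
if buf_count < 0:
    buf_count = 0

# after: two lines, identical behavior
buf_count, op_count = -1, -1
if buf_count < 0: buf_count = 0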

View File

@@ -7,8 +7,7 @@ from tinygrad.shape import ShapeTracker, View, strides_for_shape
def get_first_reduce(shapes):
for i in range(len(shapes[0])):
if not all_same([x[i] for x in shapes]):
return i
if not all_same([x[i] for x in shapes]): return i
return len(shapes[0]) # off the end
# this will be removed soon anyway
@@ -89,9 +88,7 @@ class ASTKernel:
self.full_buf_index : int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0
def print(self):
buf_count = -1
op_count = -1
cache = {}
buf_count, op_count, cache = -1, -1, {}
def print_ast(x, name=None):
nonlocal buf_count, op_count
if x not in cache:
@@ -114,8 +111,7 @@ class ASTKernel:
def printbufs(self, prefix="", print_shapetrackers=False):
print(f"first_reduce: {self.first_reduce} shape_len: {self.shape_len} group_for_reduce: {self.group_for_reduce}")
if print_shapetrackers:
for st in self.sts:
print(st)
for st in self.sts: print(st)
for i in range(len(self.sts)):
print(prefix, self.buftokens[i], f"early:{'T' if i < len(self.bufs) and self.bufs[i] in self.earlybufs else 'F'}", self.sts[i].shape, self.sts[i].views[-1].strides, len(self.sts[i].views), type(self.bufs[i]._buf) if self.bufs[i] is not None else "FAKE")
@@ -158,10 +154,8 @@ class ASTKernel:
# more can merge than this
mergeable = all(can_merge) and i != self.first_reduce
for j in range(len(shapes)):
if mergeable:
rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else:
rets[j].append((shapes[j][i], strides[j][i]))
if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
else: rets[j].append((shapes[j][i], strides[j][i]))
for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))
self.first_reduce = get_first_reduce([x.shape for x in self.sts])
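
For context, the mergeable branch above folds the current dimension into the previously kept one: the kept size is multiplied by the current size and the kept stride is replaced by the current (inner) stride. A minimal standalone sketch of that update, assuming plain contiguity as the merge condition (the real can_merge test is computed outside the lines shown):

def merge_adjacent(shape, strides):
    # collapse neighboring dims where outer_stride == inner_size * inner_stride
    rets = [(shape[0], strides[0])]
    for size, stride in zip(shape[1:], strides[1:]):
        last_size, last_stride = rets[-1]
        if last_stride == size * stride: rets[-1] = (last_size * size, stride)  # same update as the mergeable branch
        else: rets.append((size, stride))
    return rets

print(merge_adjacent((4, 5), (5, 1)))  # [(20, 1)]   contiguous dims fold into one
print(merge_adjacent((4, 5), (6, 1)))  # [(4, 6), (5, 1)]   padded outer dim, no merge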

View File

@@ -322,12 +322,10 @@ class GPUCodegen(ASTKernel):
function_name = ("re_S" if self.reduceop else "ew_S") + '_'.join([str(x) for x in self.full_shape])
# painfully name the function
if prg in GPUCodegen.kernel_name_cache:
function_name = GPUCodegen.kernel_name_cache[prg]
if prg in GPUCodegen.kernel_name_cache: function_name = GPUCodegen.kernel_name_cache[prg]
else:
GPUCodegen.kernel_cnt[function_name] += 1
if GPUCodegen.kernel_cnt[function_name]:
function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
if GPUCodegen.kernel_cnt[function_name]: function_name = f"{function_name}{'_N'+str(GPUCodegen.kernel_cnt[function_name])}"
GPUCodegen.kernel_name_cache[prg] = function_name
return ASTRunner(function_name, prg.replace("KERNEL_NAME_PLACEHOLDER", function_name), self.bufs_to_delete,
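
The naming logic above reuses a cached name when the generated program text has been seen before, and otherwise bumps a per-base-name counter and appends it as a suffix. A rough standalone sketch of that bookkeeping (function and variable names here are illustrative, not tinygrad's):

from collections import defaultdict

kernel_cnt = defaultdict(int)   # how many distinct programs share a base name
kernel_name_cache = {}          # program text -> final kernel name

def name_kernel(prg, base):
    if prg in kernel_name_cache: return kernel_name_cache[prg]
    kernel_cnt[base] += 1
    name = f"{base}_N{kernel_cnt[base]}"   # counter starts at 1, so a suffix is always appended
    kernel_name_cache[prg] = name
    return name

print(name_kernel("__kernel void a() {}", "ew_S16"))  # ew_S16_N1
print(name_kernel("__kernel void a() {}", "ew_S16"))  # ew_S16_N1 (cache hit)
print(name_kernel("__kernel void b() {}", "ew_S16"))  # ew_S16_N2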

View File

@@ -112,8 +112,7 @@ class LLVMCodegen(ASTKernel):
func.attributes.add('"no-nans-fp-math"="true"')
# construct the structure of the loops
loop_entry = [ir.IRBuilder(func.append_basic_block(name="entry"))]
loop_exit = []
loop_entry, loop_exit = [ir.IRBuilder(func.append_basic_block(name="entry"))], []
for i,_ in enumerate(full_shape): loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}")))
for i,_ in enumerate(full_shape): loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}")))
loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit")))
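
The collapsed line above only merges the two list initializations; the block layout is unchanged: an entry block, one block per loop level, matching exit blocks in reverse order, and a final exit block. A minimal llvmlite sketch of just that layout, assuming llvmlite is installed and using placeholder module/function names:

import llvmlite.ir as ir

module = ir.Module(name="sketch")
func = ir.Function(module, ir.FunctionType(ir.VoidType(), []), name="exec")
full_shape = (4, 8)  # stand-in loop bounds

loop_entry, loop_exit = [ir.IRBuilder(func.append_basic_block(name="entry"))], []
for i,_ in enumerate(full_shape): loop_entry.append(ir.IRBuilder(func.append_basic_block(name=f"loop_{i}")))
for i,_ in enumerate(full_shape): loop_exit.append(ir.IRBuilder(func.append_basic_block(name=f"loopexit_{len(full_shape)-1-i}")))
loop_exit.append(ir.IRBuilder(func.append_basic_block(name="exit")))
print([b.name for b in func.basic_blocks])  # entry, loop_0, loop_1, loopexit_1, loopexit_0, exit
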
@@ -174,8 +173,7 @@ class LLVMCodegen(ASTKernel):
if self.reduceop:
reduce_input = ast_parse(loop_exit[-1], self.reduceop.src[0], -1)
phis = [LLVMCodegen.start_for_op[self.reduceop.op]] # type: ignore
if kernel_output_dim > 1:
phis = [kernel_output_type(phis * kernel_output_dim)]
if kernel_output_dim > 1: phis = [kernel_output_type(phis * kernel_output_dim)]
for i in range(store_loop+1, len(loop_entry)):
val = loop_entry[i].phi(kernel_output_type, f"reduce_phi_{i}")
val.add_incoming(phis[-1], loop_entry[i-1]._block)

View File

@@ -18,10 +18,8 @@ G = nx.DiGraph() if nx is not None else None
cnts : Dict[OpType, int] = defaultdict(int)
if GRAPH:
def save_graph_exit():
for k,v in cnts.items():
print(k, v)
if PRUNEGRAPH:
prune_graph()
for k,v in cnts.items(): print(k, v)
if PRUNEGRAPH: prune_graph()
print("saving", G)
nx.drawing.nx_pydot.write_dot(G, f'{GRAPHPATH}.dot')
# -Gnslimit=100 can make it finish, but you won't like results
@@ -61,8 +59,7 @@ def log_op(ret : DeviceBuffer, ast : LazyOp, show_graph : Optional[bool] = None)
G.add_edge(nm(x), nm(ret), label=get_sop(op))
if 'label' not in G.nodes[nm(x)]:
G.nodes[nm(x)]['label'] = str(x.shape)
if nm(ret) not in G.nodes:
G.add_node(nm(ret))
if nm(ret) not in G.nodes: G.add_node(nm(ret))
G.nodes[nm(ret)]['label'] = str(set(x.shape for x in inp))+"\n"+str(ret.shape) if optype == ReduceOps else str(ret.shape)
G.nodes[nm(ret)]['fillcolor'] = (top_colors[optype] + ('80' if dashed else str())) if optype in top_colors else "#ffffff"

View File

@@ -15,5 +15,4 @@ def mnum(i) -> str: return str(i) if i >= 0 else f"m{-i}"
@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))
DEBUG = getenv("DEBUG", 0)
IMAGE = getenv("IMAGE", 0)
DEBUG, IMAGE = getenv("DEBUG", 0), getenv("IMAGE", 0)
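
getenv coerces the environment value to the type of the default and memoizes the result, so DEBUG and IMAGE are read once per process as ints. A small usage sketch that reuses the definition shown above:

import functools, os

@functools.lru_cache(maxsize=None)
def getenv(key, default=0): return type(default)(os.getenv(key, default))

os.environ["DEBUG"] = "2"
print(getenv("DEBUG", 0), type(getenv("DEBUG", 0)))  # 2 <class 'int'>
os.environ["DEBUG"] = "5"
print(getenv("DEBUG", 0))                            # still 2: lru_cache keeps the first answer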

View File

@@ -8,10 +8,10 @@ from tinygrad.ops import GlobalCounters, CompiledBuffer, RawBuffer
class TinyJit:
def __init__(self, fxn:Callable):
self.fxn = fxn
self.cnt = 0
self.fxn : Callable = fxn
self.cnt : int = 0
self.jit_cache : List[Tuple[Callable, Any]] = [] # TODO: Any should be List[RawBuffer], but this fails
self.ret = None
self.ret : Any = None
self.input_replace : Dict[Tuple[int, int], Union[int, str]]= {}
def __call__(self, *args, **kwargs) -> Any:

View File

@@ -62,8 +62,7 @@ def _ast_binaryops(self:LazyBuffer) -> LazyOp:
# reshape all the late ops into the output shape
# NOTE: these RESHAPEs will return self if they don't change the shape
for x in real_srcs.keys():
if real_srcs[x] is None:
real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
if real_srcs[x] is None: real_srcs[x] = x.movement_op(MovementOps.RESHAPE, intermediate_shape)
ast = map_buffers(real_srcs, self.op)
return LazyOp(MovementOps.RESHAPE, (ast, ), self.shape) if intermediate_shape != self.shape else ast
@@ -105,17 +104,14 @@ class LazyBuffer:
# TODO: does children have to be a ref count instead of a set? can a Buffer be a double child?
self.children : weakref.WeakSet[LazyBuffer] = weakref.WeakSet()
# NOTE: op should be read only after construction of LazyBuffer
for x in get_buffers(op):
x.children.add(self)
if not LAZY:
self.realize()
for x in get_buffers(op): x.children.add(self)
if not LAZY: self.realize()
def __repr__(self): return f"<LB {self.shape} op:{self.op.op if self.realized is None else 'realized'}>"
# this produces a device buffer
def realize(self:LazyBuffer, required_device=None) -> DeviceBuffer:
if required_device is not None:
assert required_device == self.device
assert required_device is None or required_device == self.device
if self.realized is None:
# get real ops first
if self.op.op == LoadOps.FROMCPU:
@@ -162,8 +158,7 @@ class LazyBuffer:
def contiguous(self:LazyBuffer) -> LazyBuffer: return LazyBuffer(self.device, self.shape, LoadOps, LazyOp(LoadOps.CONTIGUOUS, (self,)))
def reduce_op(self:LazyBuffer, op:ReduceOps, new_shape:Tuple[int, ...]) -> LazyBuffer:
if self.shape == tuple(new_shape):
return self
if self.shape == tuple(new_shape): return self
reduce = list(enumerate(zip(self.shape, new_shape)))
# move the reduce axes to the end
x = self.movement_op(MovementOps.PERMUTE, tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n]))
@@ -224,16 +219,13 @@ class LazyBuffer:
out.append(curr)
if len(new_shape) == len(out) and all(prod(i) == j and len(i) >= 1 for i,j in zip(out, new_shape)):
return out
contraction = get_contraction(self.op.src[0].shape, self.shape)
if contraction is not None:
numbered = []
start = 0
if contraction := get_contraction(self.op.src[0].shape, self.shape):
numbered, start = [], 0
for c in contraction:
numbered.append(list(range(start, start+len(c))))
start += len(c)
new_arg = []
for p in arg:
new_arg += numbered[p]
for p in arg: new_arg += numbered[p]
self.op.src[0].children.discard(self) # this changes nothing?
return self.op.src[0].movement_op(MovementOps.PERMUTE, tuple(new_arg)) \
.movement_op(MovementOps.RESHAPE, ShapeTracker(self.st).movement_op(op, arg).shape)
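
Note that the contraction hunk above does more than collapse lines: it swaps an is-not-None test for a walrus plus truthiness test, and the two agree unless get_contraction can return an empty list, which is falsy but not None. A tiny illustration of that edge case with a stub in place of the real function:

def get_contraction_stub(found):
    return [] if found else None   # stand-in: an "empty contraction" instead of the real nested lists

contraction = get_contraction_stub(True)
print(contraction is not None)   # True  -> the old form would enter the branch
print(bool(contraction))         # False -> the new walrus form would skip it
if contraction := get_contraction_stub(True): print("taken")
else: print("skipped")           # prints: skipped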

View File

@@ -6,8 +6,7 @@ class Optimizer:
def __init__(self, params : List[Tensor]):
# if it's None, but being put into an optimizer, set it to True
for x in params:
if x.requires_grad is None:
x.requires_grad = True
if x.requires_grad is None: x.requires_grad = True
self.params : List[Tensor] = [x for x in params if x.requires_grad]
self.buffers : List[Tensor] = [x for x in params if not x.requires_grad] # buffers are still realized
@@ -20,8 +19,7 @@ class Optimizer:
param.grad.assign(param.grad.clip(-(amount**2), (amount**2)))
def zero_grad(self):
for param in self.params:
param.grad = None
for param in self.params: param.grad = None
def realize(self, extra=None):
# TODO: corealize
@@ -83,9 +81,7 @@ def get_parameters(obj) -> List[Tensor]:
if isinstance(obj, Tensor):
parameters.append(obj)
elif isinstance(obj, (list, tuple)):
for x in obj:
parameters.extend(get_parameters(x))
for x in obj: parameters.extend(get_parameters(x))
elif hasattr(obj, '__dict__'):
for v in obj.__dict__.values():
parameters.extend(get_parameters(v))
for v in obj.__dict__.values(): parameters.extend(get_parameters(v))
return parameters
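
get_parameters recurses through Tensors, lists/tuples, and object __dict__ values, so nested modules flatten into a single parameter list. A schematic sketch with a stand-in Tensor class (the real function of course operates on tinygrad Tensors):

class FakeTensor:  # stand-in for tinygrad.tensor.Tensor
    def __init__(self, name): self.name = name

def get_parameters(obj):
    parameters = []
    if isinstance(obj, FakeTensor): parameters.append(obj)
    elif isinstance(obj, (list, tuple)):
        for x in obj: parameters.extend(get_parameters(x))
    elif hasattr(obj, '__dict__'):
        for v in obj.__dict__.values(): parameters.extend(get_parameters(v))
    return parameters

class Layer:
    def __init__(self): self.weight, self.bias = FakeTensor("w"), FakeTensor("b")
class Model:
    def __init__(self): self.layers = [Layer(), Layer()]

print([t.name for t in get_parameters(Model())])  # ['w', 'b', 'w', 'b']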

View File

@@ -124,8 +124,7 @@ class ASTRunner:
def lower(self, bufs) -> List[RawBuffer]: return [x.raw() for i,x in enumerate(bufs) if x is not None and i not in self.bufs_to_delete]
def __call__(self, bufs):
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(bufs)
et = self.clprg(self.global_size, self.local_size, *bufs, wait=DEBUG>=2)
if et is not None: GlobalCounters.time_sum_s += et
if et := self.clprg(self.global_size, self.local_size, *bufs, wait=DEBUG>=2): GlobalCounters.time_sum_s += et
if DEBUG >= 1:
print(f"**** {GlobalCounters.kernel_count:4d} {self.name:20s} args {len(bufs):5d} kernels {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS)"))
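
The timing hunk above likewise replaces an is-not-None check with a walrus plus truthiness check: an elapsed time of exactly 0.0 is no longer added to time_sum_s (adding zero changed nothing anyway), and the walrus still binds et for the DEBUG print below it. A small demonstration of that binding with a stub kernel call:

time_sum_s = 0.0

def clprg_stub(): return 0.0   # pretend the kernel reported a zero elapsed time

if et := clprg_stub(): time_sum_s += et   # 0.0 is falsy, so the branch is skipped...
print(et, time_sum_s)                     # ...but et is still bound: prints 0.0 0.0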

View File

@@ -35,6 +35,4 @@ class ClangCodegen(GPUCodegen):
lang = GPULanguage(buffer_suffix="restrict")
class ClangBuffer(CompiledBuffer):
raw_buffer_type = RawMallocBuffer
codegen_type = ClangCodegen
runtime_type = ClangProgram
raw_buffer_type, codegen_type, runtime_type = RawMallocBuffer, ClangCodegen, ClangProgram
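
This backend class (and the CUDA, LLVM, and Metal ones below) now sets its three class attributes with a single tuple assignment, which behaves exactly like three separate class-level assignments. A quick equivalence check with dummy values:

class SeparateAttrs:
    raw_buffer_type = "raw"
    codegen_type = "codegen"
    runtime_type = "runtime"

class TupleAttrs:
    raw_buffer_type, codegen_type, runtime_type = "raw", "codegen", "runtime"

print(SeparateAttrs.codegen_type == TupleAttrs.codegen_type)  # True: tuple unpacking works in a class body too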

View File

@@ -40,6 +40,4 @@ class CUDACodegen(GPUCodegen):
lid = [f'threadIdx.{chr(120+i)}' for i in range(3)])
class CUDABuffer(CompiledBuffer):
raw_buffer_type = RawCUDABuffer
codegen_type = CUDACodegen
runtime_type = CUDAProgram
raw_buffer_type, codegen_type, runtime_type = RawCUDABuffer, CUDACodegen, CUDAProgram

View File

@@ -27,8 +27,7 @@ class LLVM:
# TODO: this makes compile times so much faster
if getenv("LLVMOPT"):
llvm.set_option(str(), '-force-vector-interleave=4') # this makes sum the same speed as torch, it also doubles the (slow) conv speed
if DEBUG >= 4:
llvm.set_option(str(), '--debug-only=loop-vectorize')
if DEBUG >= 4: llvm.set_option(str(), '--debug-only=loop-vectorize')
#llvm.set_option(str(), '--debug')
# does this do anything?
@@ -64,6 +63,4 @@ class LLVMProgram:
if wait: return time.monotonic()-st
class LLVMBuffer(CompiledBuffer):
raw_buffer_type = RawMallocBuffer
codegen_type = LLVMCodegen
runtime_type = LLVMProgram
raw_buffer_type, codegen_type, runtime_type = RawMallocBuffer, LLVMCodegen, LLVMProgram

View File

@@ -84,6 +84,4 @@ class MetalCodegen(GPUCodegen):
extra_args = ['uint3 gid [[thread_position_in_grid]]', 'uint3 lid [[thread_position_in_threadgroup]]'])
class MetalBuffer(CompiledBuffer):
raw_buffer_type = RawMetalBuffer
codegen_type = MetalCodegen
runtime_type = MetalProgram
raw_buffer_type, codegen_type, runtime_type = RawMetalBuffer, MetalCodegen, MetalProgram

View File

@@ -29,8 +29,7 @@ class View:
ret = [Variable.num(self.offset+offset)]
acc = 1
for d,s in self.shape_strides[::-1]:
if d != 1 and s != 0:
ret.append(((idx//acc)%d)*s)
ret.append(((idx//acc)%d)*s)
acc *= d
return Variable.sum(ret)
@@ -63,8 +62,7 @@ ViewTypes = Union[View, ZeroView]
@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
strides = [1]
for d in shape[::-1][:-1]:
strides = [d*strides[0]] + strides
for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))
@functools.lru_cache(maxsize=None)
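
strides_for_shape builds row-major strides right to left, then zeroes the stride of every size-1 dimension so such dims never move the index. A worked example reusing the definition shown above:

import functools
from typing import Tuple

@functools.lru_cache(maxsize=None)
def strides_for_shape(shape:Tuple[int, ...]) -> Tuple[int, ...]:
    strides = [1]
    for d in shape[::-1][:-1]: strides = [d*strides[0]] + strides
    return tuple(st if s != 1 else 0 for st, s in zip(strides, shape))

print(strides_for_shape((2, 3, 4)))  # (12, 4, 1)
print(strides_for_shape((2, 1, 3)))  # (3, 0, 1): the size-1 dim gets stride 0
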
@@ -73,8 +71,7 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
return View(tuple(shape), strides_for_shape(shape))
def merge_views(vm2:View, vm1:View) -> Optional[View]:
new_strides = []
new_offset = vm2.expr_node(Variable.num(vm1.offset))
new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
for s,st in zip(vm1.shape, vm1.strides):
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
@@ -147,8 +144,7 @@ class ShapeTracker:
assert prod(self.shape) == prod(new_shape), f"can't reshape {self.shape} -> {new_shape}"
view = View(new_shape, strides_for_shape(new_shape))
if self.contiguous:
self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
if self.contiguous: self.views[-1] = view # NOTE: if it's contiguous it can't have an offset
else:
# NOTE: the last view in self.views is never a ZeroView
if (merged_view := merge_views(cast(View, self.views[-1]), view)) is not None: self.views[-1] = merged_view

View File

@@ -4,7 +4,7 @@ from typing import List, Dict, Callable, Type, Union
from tinygrad.helpers import partition, all_same
# NOTE: Python has different behavior for negative mod and floor div than c
# symbolic matches the Python behavior, but the code is outputs is agnostic, and will never have negative numbers in div or mod
# symbolic matches the Python behavior, but the code output is agnostic, and will never have negative numbers in div or mod
def create_node(typ:Type[Node], *args):
ret = typ(*args)

View File

@@ -157,8 +157,7 @@ class Tensor:
visited.add(node)
if node._ctx:
for i in node._ctx.parents:
if i not in visited:
_deepwalk(i, visited, nodes)
if i not in visited: _deepwalk(i, visited, nodes)
nodes.append(node)
return nodes
return _deepwalk(self, set(), [])
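
deepwalk is a depth-first, post-order walk of the autograd graph: a node's parents are visited before the node itself is appended, so the returned list is a topological order that backward can traverse in reverse. A minimal sketch with stand-in node and context classes (names are illustrative, not tinygrad's):

class Ctx:   # stand-in for a Function context holding parent tensors
    def __init__(self, *parents): self.parents = parents

class Node:  # stand-in for a Tensor in the autograd graph
    def __init__(self, name, ctx=None): self.name, self._ctx = name, ctx

def deepwalk(node):
    def _deepwalk(node, visited, nodes):
        visited.add(node)
        if node._ctx:
            for i in node._ctx.parents:
                if i not in visited: _deepwalk(i, visited, nodes)
            nodes.append(node)
        return nodes
    return _deepwalk(node, set(), [])

a, b = Node("a"), Node("b")
c = Node("c", Ctx(a, b))
d = Node("d", Ctx(c, a))
print([n.name for n in deepwalk(d)])  # ['c', 'd']: only nodes with a _ctx are collected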