qazal
2024-08-16 16:02:54 +03:00
parent 5756508789
commit 2d2f465552
10 changed files with 41 additions and 34 deletions

View File

@@ -37,9 +37,10 @@ print("******** second, the Device ***********")
 DEVICE = "CLANG" # NOTE: you can change this!
 
 import struct
-from tinygrad.dtype import dtypes
+from tinygrad.dtype import PtrDType, dtypes
 from tinygrad.device import Buffer, Device
-from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps, MetaOps
+from tinygrad.ops import BinaryOps, MetaOps
+from tinygrad.codegen.uops import UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker
 
 # allocate some buffers + load in values
@@ -49,15 +50,19 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
 # NOTE: a._buf is the same as the return from MallocAllocator.alloc
 
 # describe the computation
-ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
-ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
-alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
-st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
-k = LazyOp(MetaOps.KERNEL, (st_0,))
+buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
+buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
+ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *UOp.from_st(ShapeTracker.from_shape((1,)))))
+ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *UOp.from_st(ShapeTracker.from_shape((1,)))))
+alu = ld_1 + ld_2
+output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
+idx, valid = UOp.from_st(ShapeTracker.from_shape((1,)))
+st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
+s = UOp(UOps.SINK, None, (st_0,))
 
 # convert the computation to a "linearized" format (print the format)
 from tinygrad.engine.realize import get_kernel, CompiledRunner
-kernel = get_kernel(Device[DEVICE].renderer, k).linearize()
+kernel = get_kernel(Device[DEVICE].renderer, s).linearize()
 kernel.uops.print()
 
 # compile a program (and print the source)
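
The new-style AST built in this hunk is itself a UOp graph: LOAD takes (buffer, index, valid), STORE takes (buffer, index, value, valid), and everything hangs under a single SINK. A minimal sketch for inspecting it, assuming the `s` built above and UOp's `sparents` helper (the same one this commit uses in UOp.variables() further down):

    # walk every node of the UOp graph rooted at the SINK `s` and print it
    for u in s.sparents:
      print(u.op, u.dtype, u.arg)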

View File

@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Tuple
 from extra.models.resnet import ResNet50
 from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
@@ -83,7 +83,7 @@ if __name__ == "__main__":
     rawbufs = bufs_from_lin(Kernel(si.ast))
 
     # "linearize" the op into uops in different ways
-    lins:List[Kernel] = []
+    lins: List[Tuple[Kernel, str]] = []
 
     # always try hand coded opt
     lin = Kernel(si.ast, opts=device.renderer)
@@ -109,10 +109,10 @@ if __name__ == "__main__":
 
     # benchmark the programs
     choices = []
-    for (lin, nm) in lins:
+    for lin, nm in lins:
       tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
       ops = (prg:=lin.to_program()).op_estimate
-      gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
+      gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
       choices.append((tm, gflops, lin, prg, nm))
 
     sorted_choices = sorted(choices, key=lambda x: x[0])
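
The gflops line above works because op_estimate can be symbolic when the kernel has variable dimensions; sym_infer substitutes concrete values for each Variable. A small sketch with a hypothetical variable N (using tinygrad's symbolic module, as imported elsewhere in this commit):

    from tinygrad.shape.symbolic import Variable, sym_infer

    N = Variable("N", 1, 4096)         # hypothetical symbolic dim
    ops = N * 6                        # hypothetical symbolic op count
    print(sym_infer(ops, {N: N.min}))  # k.min as in the hunk above -> 6
    print(sym_infer(ops, {N: 1024}))   # -> 6144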

View File

@@ -7,7 +7,6 @@ from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.device import Buffer, Device, CompileError
 from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
-from tinygrad.ops import LazyOp
 
 class MCTSNode:
   def __init__(self, kernel:Kernel, parent=None):
@@ -87,14 +86,14 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
     return ret
 
   rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
+  var_vals = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
   dev = Device[lin.opts.device]
   root = MCTSNode(lin)
   st = time.perf_counter()
   best, best_idx, best_tm = lin, 0, math.inf
   seen_libs: Dict[bytes, MCTSNode] = {}
-  seen_asts: Dict[LazyOp, MCTSNode] = {}
+  seen_asts: Dict[bytes, MCTSNode] = {}
   compile_time, runtime_time = 0.0, 0.0
   for i in range(amt):
     node = sample_tree(root, best_tm) # sample and expand
@@ -102,12 +101,12 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
     node.i = i # when was node explored
     opt_ast = node.kernel.get_optimized_ast()
-    if (sibling_node:=seen_asts.get(opt_ast, None)) is not None:
+    if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
       # early check for same optimized AST hit
       remove_node(node)
       tm = sibling_node.t
     else:
-      seen_asts[opt_ast] = node
+      seen_asts[opt_ast.key] = node
 
       # lowering (50% of the time)
       p = node.kernel.to_program(name_override="test")
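
Keying seen_asts by opt_ast.key (a bytes digest) instead of the AST object makes the duplicate check depend only on content, mirroring seen_libs above. A standalone analogue of the pattern, with a hypothetical content_key standing in for UOp.key:

    import hashlib
    from typing import Dict

    def content_key(ast_repr: str) -> bytes:
      # hypothetical stand-in for UOp.key: a stable digest of the node's content
      return hashlib.sha256(ast_repr.encode()).digest()

    seen: Dict[bytes, float] = {}            # like seen_asts: digest -> result
    def time_once(ast_repr: str) -> float:
      k = content_key(ast_repr)
      if k in seen: return seen[k]           # early hit: skip lowering/benchmarking
      seen[k] = tm = float(len(ast_repr))    # dummy "measurement" for the sketch
      return tm

    assert time_once("ADD(a,b)") == time_once("ADD(a,b)")  # deduped by content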

View File

@@ -3,7 +3,7 @@ from enum import Enum, auto
 from collections import defaultdict
 from typing import List, Tuple, DefaultDict
 from extra.optimization.helpers import load_worlds, ast_str_to_ast
-from tinygrad.ops import BufferOps, LazyOp
+from extra.ops import BufferOps, LazyOp
 from tinygrad.helpers import prod, tqdm
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.symbolic import sym_infer, Node
@@ -145,4 +145,4 @@ if __name__ == "__main__":
   for ast_str in tqdm(ast_strs):
     test_rebuild_bufferop_st(ast_str_to_ast(ast_str))
 
-  print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")
\ No newline at end of file
+  print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")

View File

@@ -86,7 +86,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No
 
   if var_vals is None:
     # TODO: handle symbolic max case
-    var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
+    var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.variables()}
 
   if ground_truth is None and not has_bf16:
     unoptimized = Kernel(lin.ast)

View File

@@ -71,7 +71,7 @@ class Kernel:
     def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
     self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is UOps.REDUCE_AXIS])
 
-    self.vars: List[Variable] = dedup([x.arg for x in self.ast.vars()])
+    self.vars: List[Variable] = self.ast.variables()
     self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
 
     # get earlybufs, before any reduceops
@@ -481,7 +481,7 @@ class Kernel:
       self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
 
   def required_optimizations(self) -> Kernel:
-    if self.bufs[0].dtype.__class__ is ImageDType:
+    if isinstance(self.membufs[0].dtype, ImageDType):
      unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
      assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
      if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
@@ -692,7 +692,7 @@ class Kernel:
                        for i,s in enumerate(self.full_shape))
       srcs = []
       for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
-        st_load = [self.sts[self.bufs.index(op)].real_strides() for op in src.parents if op.op is UOps.LOAD]
+        st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
         local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
         idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape).expand(ex_shape))
         membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
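
The required_optimizations hunk above also swaps an exact-class check for isinstance, which matters as soon as an image dtype is subclassed. A quick plain-Python illustration (hypothetical classes, not tinygrad's):

    class ImageDType: pass
    class ImageDTypeV2(ImageDType): pass   # hypothetical subclass

    d = ImageDTypeV2()
    print(d.__class__ is ImageDType)       # False: exact-class check misses subclasses
    print(isinstance(d, ImageDType))       # True: subclasses still count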

View File

@@ -161,21 +161,21 @@ class IndependentLowerer:
   def _to_uop(self, x:UOp) -> UOp:
     if x.op in BUFFER_UOPS:
-      idx, valid = st_to_uops(x.src[-1].arg, self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs, cast(DType,x.dtype))
+      idx, valid = st_to_uops(x.src[-1].arg, self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs,
+                              cast(DType, x.dtype if x.op is UOps.CONST else x.src[0].dtype))
       # TODO: check has_valid in UPat, not here
       has_valid = valid.op is not UOps.CONST or valid.arg is not True
       if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
       buf = x.src[0]
       if x.op is UOps.LOAD:
         barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[1]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
-        load_dtype = cast(DType,x.dtype).scalar()
         if idx.dtype == dtypes.int.vec(3):
           # this should all simplify if there's consts for id4. if not, w/e
           idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
-          vec_load = UOp(UOps.LOAD, load_dtype.vec(4), (buf, idx) + ((UOp.const(load_dtype.vec(4), 0), valid) if has_valid else ()) + barrier)
-          return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load_dtype, (vec_load,), i)),
-                                  range(4), UOp.const(load_dtype, float('nan')))
-        return UOp(UOps.LOAD, load_dtype, (buf, idx) + ((UOp.const(load_dtype, 0), valid) if has_valid else ()) + barrier)
+          vec_load = UOp(UOps.LOAD, dt:=cast(DType, x.dtype).vec(4), (buf, idx) + ((UOp.const(dt, 0), valid) if has_valid else ()) + barrier)
+          return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, x.dtype, (vec_load,), i)),
+                                  range(4), UOp.const(x.dtype, float('nan')))
+        return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
 
       # NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
       if x.src[0].op is UOps.DEFINE_GLOBAL:
         for oidx, ridx in zip(self.idxs, self.ridxs):
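
The functools.reduce in this hunk picks one lane out of the 4-wide image load by folding nested where ops over the four lane indices. A plain-Python analogue of the same fold, with hypothetical scalar values standing in for the UOps:

    import functools

    vec_load = [1.0, 2.0, 3.0, 4.0]  # stands in for the float4 LOAD
    id4 = 2                          # stands in for the last int3 index component

    # mirror of: reduce(lambda ret, i: id4.ne(i).where(ret, GEP(vec_load, i)), range(4), nan)
    ret = functools.reduce(lambda ret, i: ret if id4 != i else vec_load[i],
                           range(4), float('nan'))
    assert ret == 3.0                # lane id4 of the vector load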

View File

@@ -105,6 +105,9 @@ class UOp:
     # NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
     return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
   def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
+  def variables(self) -> List[Variable]:
+    st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
+    return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
   def const_factor(self) -> int:
     """largest known int that divides self"""
     if self.op is UOps.CONST: return self.arg
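
The distinction: vars() returns the DEFINE_VAR UOps themselves, while the new variables() also collects symbolic variables hiding inside the ShapeTrackers of BUFFER_UOPS and returns plain Variable objects, deterministically sorted by expr. The call-site change this enables, as the search hunks below show (a sketch; `ast` is a kernel AST as elsewhere in this commit):

    # before: unwrap DEFINE_VAR uops by hand, missing ShapeTracker-only variables
    var_vals = {k.arg: (k.arg.max + k.arg.min) // 2 for k in ast.vars()}
    # after: plain Variable objects, ShapeTracker variables included
    var_vals = {k: (k.max + k.min) // 2 for k in ast.variables()}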

View File

@@ -96,8 +96,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
   if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
     assert buf in outputs, f"{buf.op} must be writable"
     return in_ops[0]
-  if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, buf.arg.scalar(), in_ops))
-  if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, buf.arg.scalar(), in_ops))
+  if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_ops))
+  if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_ops))
   return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_ops, buf.op))
 
 def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
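
Both branches now reuse the dtype already computed in this function instead of re-deriving it from buf.arg. As a reminder of what the two ops mean, a plain-Python illustration (struct only, no tinygrad):

    import struct

    x = 1.0
    print(int(x))  # CAST converts the value: 1.0 -> 1
    print(hex(struct.unpack("<I", struct.pack("<f", x))[0]))  # BITCAST reinterprets the bits: -> 0x3f800000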

View File

@@ -93,7 +93,7 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
   bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
   for x in lin.bufs:
     if x.src[0].op is UOps.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
-  rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
+  rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
   for k,lx in bufsts.items():
     buf_size = prod(dtype.shape) if isinstance(dtype:=cast(DType,lx[0].src[0].dtype), ImageDType) else max(y.src[-1].arg.real_size() for y in lx)
     if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
@@ -141,7 +141,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
   try:
     rawbufs = _ensure_buffer_alloc(rawbufs)
-    var_vals: Dict[Variable, int] = {k.arg:(k.arg.max+k.arg.min)//2 for k in lin.ast.vars()}
+    var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
     exiting, st = False, time.perf_counter()
     dev = Device[lin.opts.device]
     while not exiting:
@@ -199,7 +199,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
   assert dev.compiler is not None
 
   rawbufs = _ensure_buffer_alloc(rawbufs)
-  var_vals: Dict[Variable, int] = {k.arg:(k.arg.max+k.arg.min)//2 for k in lin.ast.vars()}
+  var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
   p = lin.to_program()
   tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
                       max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
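
Both hunks bind each symbolic variable to the midpoint of its legal range before benchmarking. Concretely, with a hypothetical variable:

    from tinygrad.shape.symbolic import Variable

    N = Variable("N", 1, 4096)   # hypothetical symbolic dim
    print((N.max + N.min) // 2)  # benchmarks run with N = 2048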