mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-10 07:28:15 -05:00
green
This commit is contained in:
@@ -37,9 +37,10 @@ print("******** second, the Device ***********")
|
||||
DEVICE = "CLANG" # NOTE: you can change this!
|
||||
|
||||
import struct
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.dtype import PtrDType, dtypes
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import LazyOp, BufferOps, MemBuffer, BinaryOps, MetaOps
|
||||
from tinygrad.ops import BinaryOps, MetaOps
|
||||
from tinygrad.codegen.uops import UOp, UOps
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
|
||||
# allocate some buffers + load in values
|
||||
@@ -49,15 +50,19 @@ b = Buffer(DEVICE, 1, dtypes.int32).allocate().copyin(memoryview(bytearray(struc
|
||||
# NOTE: a._buf is the same as the return from MallocAllocator.alloc
|
||||
|
||||
# describe the computation
|
||||
ld_1 = LazyOp(BufferOps.LOAD, (), MemBuffer(1, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
ld_2 = LazyOp(BufferOps.LOAD, (), MemBuffer(2, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
alu = LazyOp(BinaryOps.ADD, (ld_1, ld_2))
|
||||
st_0 = LazyOp(BufferOps.STORE, (alu,), MemBuffer(0, dtypes.int32, ShapeTracker.from_shape((1,))))
|
||||
k = LazyOp(MetaOps.KERNEL, (st_0,))
|
||||
buf_1 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 1)
|
||||
buf_2 = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 2)
|
||||
ld_1 = UOp(UOps.LOAD, dtypes.int32, (buf_1, *UOp.from_st(ShapeTracker.from_shape((1,)))))
|
||||
ld_2 = UOp(UOps.LOAD, dtypes.int32, (buf_2, *UOp.from_st(ShapeTracker.from_shape((1,)))))
|
||||
alu = ld_1 + ld_2
|
||||
output_buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.int32), (), 0)
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape((1,)))
|
||||
st_0 = UOp(UOps.STORE, None, (output_buf, idx, alu, valid))
|
||||
s = UOp(UOps.SINK, None, (st_0,))
|
||||
|
||||
# convert the computation to a "linearized" format (print the format)
|
||||
from tinygrad.engine.realize import get_kernel, CompiledRunner
|
||||
kernel = get_kernel(Device[DEVICE].renderer, k).linearize()
|
||||
kernel = get_kernel(Device[DEVICE].renderer, s).linearize()
|
||||
kernel.uops.print()
|
||||
|
||||
# compile a program (and print the source)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import List
|
||||
from typing import List, Tuple
|
||||
from extra.models.resnet import ResNet50
|
||||
from extra.mcts_search import mcts_search
|
||||
from examples.mlperf.helpers import get_mlperf_bert_model
|
||||
@@ -83,7 +83,7 @@ if __name__ == "__main__":
|
||||
rawbufs = bufs_from_lin(Kernel(si.ast))
|
||||
|
||||
# "linearize" the op into uops in different ways
|
||||
lins:List[Kernel] = []
|
||||
lins: List[Tuple[Kernel, str]] = []
|
||||
|
||||
# always try hand coded opt
|
||||
lin = Kernel(si.ast, opts=device.renderer)
|
||||
@@ -109,10 +109,10 @@ if __name__ == "__main__":
|
||||
|
||||
# benchmark the programs
|
||||
choices = []
|
||||
for (lin, nm) in lins:
|
||||
for lin, nm in lins:
|
||||
tm = time_linearizer(lin, rawbufs, allow_test_size=False, cnt=10, disable_cache=True)
|
||||
ops = (prg:=lin.to_program()).op_estimate
|
||||
gflops = sym_infer(ops, {k:k.min for k in lin.ast.vars()})*1e-9/tm
|
||||
gflops = sym_infer(ops, {k:k.min for k in lin.ast.variables()})*1e-9/tm
|
||||
choices.append((tm, gflops, lin, prg, nm))
|
||||
|
||||
sorted_choices = sorted(choices, key=lambda x: x[0])
|
||||
|
||||
@@ -7,7 +7,6 @@ from tinygrad.helpers import DEBUG, getenv, CACHELEVEL, diskcache_get, diskcache
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import Buffer, Device, CompileError
|
||||
from tinygrad.engine.search import _ensure_buffer_alloc, get_kernel_actions, _time_program
|
||||
from tinygrad.ops import LazyOp
|
||||
|
||||
class MCTSNode:
|
||||
def __init__(self, kernel:Kernel, parent=None):
|
||||
@@ -87,14 +86,14 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
|
||||
return ret
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.vars()}
|
||||
var_vals = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
|
||||
dev = Device[lin.opts.device]
|
||||
root = MCTSNode(lin)
|
||||
|
||||
st = time.perf_counter()
|
||||
best, best_idx, best_tm = lin, 0, math.inf
|
||||
seen_libs: Dict[bytes, MCTSNode] = {}
|
||||
seen_asts: Dict[LazyOp, MCTSNode] = {}
|
||||
seen_asts: Dict[bytes, MCTSNode] = {}
|
||||
compile_time, runtime_time = 0.0, 0.0
|
||||
for i in range(amt):
|
||||
node = sample_tree(root, best_tm) # sample and expand
|
||||
@@ -102,12 +101,12 @@ def mcts_search(lin:Kernel, rawbufs:List[Buffer], amt:int) -> Kernel:
|
||||
node.i = i # when was node explored
|
||||
|
||||
opt_ast = node.kernel.get_optimized_ast()
|
||||
if (sibling_node:=seen_asts.get(opt_ast, None)) is not None:
|
||||
if (sibling_node:=seen_asts.get(opt_ast.key, None)) is not None:
|
||||
# early check for same optimized AST hit
|
||||
remove_node(node)
|
||||
tm = sibling_node.t
|
||||
else:
|
||||
seen_asts[opt_ast] = node
|
||||
seen_asts[opt_ast.key] = node
|
||||
|
||||
# lowering (50% of the time)
|
||||
p = node.kernel.to_program(name_override="test")
|
||||
|
||||
@@ -3,7 +3,7 @@ from enum import Enum, auto
|
||||
from collections import defaultdict
|
||||
from typing import List, Tuple, DefaultDict
|
||||
from extra.optimization.helpers import load_worlds, ast_str_to_ast
|
||||
from tinygrad.ops import BufferOps, LazyOp
|
||||
from extra.ops import BufferOps, LazyOp
|
||||
from tinygrad.helpers import prod, tqdm
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.symbolic import sym_infer, Node
|
||||
@@ -145,4 +145,4 @@ if __name__ == "__main__":
|
||||
for ast_str in tqdm(ast_strs):
|
||||
test_rebuild_bufferop_st(ast_str_to_ast(ast_str))
|
||||
|
||||
print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")
|
||||
print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")
|
||||
|
||||
2
test/external/fuzz_linearizer.py
vendored
2
test/external/fuzz_linearizer.py
vendored
@@ -86,7 +86,7 @@ def compare_linearizer(lin: Kernel, rawbufs=None, var_vals=None, ground_truth=No
|
||||
|
||||
if var_vals is None:
|
||||
# TODO: handle symbolic max case
|
||||
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.vars()}
|
||||
var_vals = {v: random.randint(v.min, v.max if isinstance(v.max, int) else v.min) for v in lin.ast.variables()}
|
||||
|
||||
if ground_truth is None and not has_bf16:
|
||||
unoptimized = Kernel(lin.ast)
|
||||
|
||||
@@ -71,7 +71,7 @@ class Kernel:
|
||||
def ordered_parents(op:UOp) -> List[UOp]: return dedup([item for x in op.src for item in ordered_parents(x)] + [op])
|
||||
self.reduceops = dedup([x for x in ordered_parents(self.ast) if x.op is UOps.REDUCE_AXIS])
|
||||
|
||||
self.vars: List[Variable] = dedup([x.arg for x in self.ast.vars()])
|
||||
self.vars: List[Variable] = self.ast.variables()
|
||||
self.bufs: List[UOp] = [x for x in self.ast.parents if x.op in BUFFER_UOPS]
|
||||
|
||||
# get earlybufs, before any reduceops
|
||||
@@ -481,7 +481,7 @@ class Kernel:
|
||||
self.tensor_core_opts.fix_axes(axis) # fix up axes in TC opts if required after simplify_ones()
|
||||
|
||||
def required_optimizations(self) -> Kernel:
|
||||
if self.bufs[0].dtype.__class__ is ImageDType:
|
||||
if isinstance(self.membufs[0].dtype, ImageDType):
|
||||
unit_stride_axes_mul_4 = [i for i in self.sts[0].unit_stride_axes(ignore_valid=True) if self.sts[0].shape[i]%4 == 0]
|
||||
assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[0]}"
|
||||
if len(unit_stride_axes_mul_4) and all(x < self.first_upcast for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes: # noqa: E501
|
||||
@@ -692,7 +692,7 @@ class Kernel:
|
||||
for i,s in enumerate(self.full_shape))
|
||||
srcs = []
|
||||
for i,(src,fix_st_fxn) in enumerate(zip(rsrc.src, [fix_st1, fix_st2])):
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in src.parents if op.op is UOps.LOAD]
|
||||
st_load = [self.sts[self.bufs.index(op)].real_strides() for op in rsrc.parents if op.op is UOps.LOAD]
|
||||
local_shape = tuple(s if max(cast(int, x[i]) for x in st_load) != 0 else 1 for i,s in enumerate(ex_shape))
|
||||
idx, valid = UOp.from_st(ShapeTracker.from_shape(local_shape).expand(ex_shape))
|
||||
membuf = UOp(UOps.DEFINE_LOCAL, PtrDType(tc.dtype_in), (), (f"temp{-(-1-i)}", idx.arg.real_size()))
|
||||
|
||||
@@ -161,21 +161,21 @@ class IndependentLowerer:
|
||||
|
||||
def _to_uop(self, x:UOp) -> UOp:
|
||||
if x.op in BUFFER_UOPS:
|
||||
idx, valid = st_to_uops(x.src[-1].arg, self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs, cast(DType,x.dtype))
|
||||
idx, valid = st_to_uops(x.src[-1].arg, self.ridxs if x.op is UOps.LOAD and x.src[0].op is UOps.DEFINE_LOCAL else self.idxs,
|
||||
cast(DType, x.dtype if x.op is UOps.CONST else x.src[0].dtype))
|
||||
# TODO: check has_valid in UPat, not here
|
||||
has_valid = valid.op is not UOps.CONST or valid.arg is not True
|
||||
if x.op is UOps.CONST: return valid.where(UOp.const(x.dtype, x.arg), UOp.const(x.dtype, 0))
|
||||
buf = x.src[0]
|
||||
if x.op is UOps.LOAD:
|
||||
barrier = (UOp(UOps.BARRIER, None, (self.to_uop(x.src[1]),)),) if x.src[0].op is UOps.DEFINE_LOCAL else ()
|
||||
load_dtype = cast(DType,x.dtype).scalar()
|
||||
if idx.dtype == dtypes.int.vec(3):
|
||||
# this should all simplify if there's consts for id4. if not, w/e
|
||||
idx, id4 = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (idx.src[0], idx.src[1])), idx.src[2]
|
||||
vec_load = UOp(UOps.LOAD, load_dtype.vec(4), (buf, idx) + ((UOp.const(load_dtype.vec(4), 0), valid) if has_valid else ()) + barrier)
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, load_dtype, (vec_load,), i)),
|
||||
range(4), UOp.const(load_dtype, float('nan')))
|
||||
return UOp(UOps.LOAD, load_dtype, (buf, idx) + ((UOp.const(load_dtype, 0), valid) if has_valid else ()) + barrier)
|
||||
vec_load = UOp(UOps.LOAD, dt:=cast(DType, x.dtype).vec(4), (buf, idx) + ((UOp.const(dt, 0), valid) if has_valid else ()) + barrier)
|
||||
return functools.reduce(lambda ret, i: id4.ne(i).where(ret, UOp(UOps.GEP, x.dtype, (vec_load,), i)),
|
||||
range(4), UOp.const(x.dtype, float('nan')))
|
||||
return UOp(UOps.LOAD, x.dtype, (buf, idx) + ((UOp.const(x.dtype, 0), valid) if has_valid else ()) + barrier)
|
||||
# NOTE: only store the local reduceop in the first thread (this is wrong for non group for reduces!)
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL:
|
||||
for oidx, ridx in zip(self.idxs, self.ridxs):
|
||||
|
||||
@@ -105,6 +105,9 @@ class UOp:
|
||||
# NOTE: UOps.DEFINE_GLOBAL and UOps.DEFINE_LOCAL don't have shape
|
||||
return tuple(max(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {UOps.DEFINE_GLOBAL, UOps.DEFINE_LOCAL}]))
|
||||
def vars(self) -> Set[UOp]: return set([x for x in self.sparents if x.op is UOps.DEFINE_VAR])
|
||||
def variables(self) -> List[Variable]:
|
||||
st_vars: List[Set[Variable]] = [x.src[-1].arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
|
||||
return sorted(set.union(*st_vars, set([x.arg for x in self.sparents if x.op is UOps.DEFINE_VAR])), key=lambda v: v.expr)
|
||||
def const_factor(self) -> int:
|
||||
"""largest known int that divides self"""
|
||||
if self.op is UOps.CONST: return self.arg
|
||||
|
||||
@@ -96,8 +96,8 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
|
||||
if buf.op in {MetaOps.CONTIGUOUS, MetaOps.ASSIGN}:
|
||||
assert buf in outputs, f"{buf.op} must be writable"
|
||||
return in_ops[0]
|
||||
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, buf.arg.scalar(), in_ops))
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, buf.arg.scalar(), in_ops))
|
||||
if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_ops))
|
||||
if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_ops))
|
||||
return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_ops, buf.op))
|
||||
|
||||
def _permute_reduce(input_st:ShapeTracker, axis:Tuple[int, ...]) -> Tuple[ShapeTracker, Tuple[sint, ...]]:
|
||||
|
||||
@@ -93,7 +93,7 @@ def bufs_from_lin(lin:Kernel, allocate:bool=True) -> List[Buffer]:
|
||||
bufsts: DefaultDict[int, List[UOp]] = defaultdict(list)
|
||||
for x in lin.bufs:
|
||||
if x.src[0].op is UOps.DEFINE_GLOBAL: bufsts[x.src[0].arg].append(x)
|
||||
rawbufs:List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
rawbufs: List[Optional[Buffer]] = [None]*len(bufsts)
|
||||
for k,lx in bufsts.items():
|
||||
buf_size = prod(dtype.shape) if isinstance(dtype:=cast(DType,lx[0].src[0].dtype), ImageDType) else max(y.src[-1].arg.real_size() for y in lx)
|
||||
if buf_size == 0: buf_size = 1 # create a size 1 buffer if no cell is accessed in kernel. # TODO: remove from kernel input in this case.
|
||||
@@ -141,7 +141,7 @@ def beam_search(lin:Kernel, rawbufs:List[Buffer], amt:int, allow_test_size=True,
|
||||
|
||||
try:
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: Dict[Variable, int] = {k.arg:(k.arg.max+k.arg.min)//2 for k in lin.ast.vars()}
|
||||
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
|
||||
exiting, st = False, time.perf_counter()
|
||||
dev = Device[lin.opts.device]
|
||||
while not exiting:
|
||||
@@ -199,7 +199,7 @@ def time_linearizer(lin:Kernel, rawbufs:List[Buffer], allow_test_size=True, max_
|
||||
assert dev.compiler is not None
|
||||
|
||||
rawbufs = _ensure_buffer_alloc(rawbufs)
|
||||
var_vals: Dict[Variable, int] = {k.arg:(k.arg.max+k.arg.min)//2 for k in lin.ast.vars()}
|
||||
var_vals: Dict[Variable, int] = {k:(k.max+k.min)//2 for k in lin.ast.variables()}
|
||||
p = lin.to_program()
|
||||
tms = _time_program(p, dev.compiler.compile(p.src), var_vals, rawbufs,
|
||||
max_global_size=max_global_size if allow_test_size else None, clear_l2=clear_l2, cnt=cnt, name=to_function_name(lin.name))
|
||||
|
||||
Reference in New Issue
Block a user