Merge branch 'master' into retinanet_mlperf
.github/workflows/test.yml (vendored, 4 changes)
@@ -423,6 +423,8 @@ jobs:
       run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
     - name: Test Additional ONNX Ops (CPU)
       run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
+    - name: Test Quantize ONNX
+      run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
     - name: Run CLOUD=1 Test
       run: |
         CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
@@ -467,7 +469,7 @@ jobs:
   testdsp:
     name: Linux (DSP)
     runs-on: ubuntu-24.04
-    timeout-minutes: 10
+    timeout-minutes: 15
     steps:
     - name: Checkout Code
       uses: actions/checkout@v4
@@ -928,9 +928,9 @@ def train_bert():
   # ** hyperparameters **
   BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
   EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
-  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96))
+  max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(BS/96))

-  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
+  train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS)
   warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
   max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS)  # EVAL_BS * MAX_EVAL_STEPS >= 10000
   eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS))  # Round down
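Plugging the defaults into the EVAL_STEP_FREQ formula makes the cadence concrete (my own arithmetic, assuming 8 GPUs with fp16 so BS = 88):

    import math
    BS = 11 * 8  # 88: the fp16 default on 8 GPUs
    samples = math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000  # 150000 samples
    print(int(samples / BS))  # eval every 1704 steps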
@@ -1,5 +1,6 @@
 import numpy as np
 from tinygrad.helpers import getenv
+from tinygrad.dtype import _to_np_dtype
 from tinygrad import dtypes, Tensor

 dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
@@ -13,12 +14,17 @@ K = getenv("K", N)
 CNT = getenv("CNT", 10)
 ATOL = getenv("ATOL", 1e-4)
 RTOL = getenv("RTOL", 3e-2)
+INT_LOW = getenv("INT_LOW", 0)
+INT_HIGH = getenv("INT_HIGH", 10)

 if __name__ == "__main__":
   def init_matrix(rows, cols):
+    rng = np.random.default_rng()
+    # NOTE: numpy does not support bfloat16
+    if (np_dtype := _to_np_dtype(dtype_in)) is None: np_dtype = np.float32
     if dtype_in in dtypes.ints:
-      return Tensor.randint((rows, cols), dtype=dtype_in).realize()
-    return Tensor.rand(rows, cols, dtype=dtype_in).realize()
+      return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=np_dtype)).realize()
+    return Tensor(rng.random((rows, cols), dtype=np.float32).astype(np_dtype)).cast(dtype_in).realize()

   a, b = init_matrix(M, K), init_matrix(K, N)
   for i in range(CNT):
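The init_matrix rewrite routes all generation through numpy; since numpy has no bfloat16 (_to_np_dtype returns None for it), data is drawn as float32 and cast on the tinygrad side. A minimal sketch of that path (assumes tinygrad is installed):

    import numpy as np
    from tinygrad import Tensor, dtypes
    x = Tensor(np.random.default_rng().random((4, 4), dtype=np.float32)).cast(dtypes.bfloat16).realize()
    print(x.dtype)  # dtypes.bfloat16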
extra/hip_large_kernel.py (new file, 25 lines)
@@ -0,0 +1,25 @@
+from tinygrad.device import Device, Buffer
+from tinygrad.dtype import dtypes, _to_np_dtype
+
+dev = Device.default
+mbin = dev.compiler.compile("""
+typedef long unsigned int size_t;
+extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
+extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, 1))) write_ones(signed char* data0) {
+  int gidx0 = __ockl_get_group_id(0); /* 16 */
+  int gidx1 = __ockl_get_group_id(1); /* 1026048 */
+  *(data0+(gidx0+gidx1*1)) = 1;
+}
+""")
+dev.compiler.disassemble(mbin)
+buf0 = Buffer(Device.DEFAULT, 1*65537, dtypes.uint8).ensure_allocated()
+
+prg = dev.runtime("write_ones", mbin)
+prg(buf0._buf, global_size=(1,65537,1), local_size=(1,1,1), wait=True)
+
+import numpy as np
+def to_np(buf): return np.frombuffer(buf.as_buffer().cast(buf.dtype.base.fmt), dtype=_to_np_dtype(buf.dtype.base))
+
+big = to_np(buf0)
+print(big)
+print((big-1).nonzero())
@@ -728,7 +728,11 @@ def get_onnx_ops():
   def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
     out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
     y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
-    return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
+    if out_dtype == dtypes.uchar:
+      # this appears to work in practice, at least for uchar out_dtype. it folds with the quantize stuff
+      return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous()
+    else:
+      return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()

   def DynamicQuantizeLinear(x: Tensor):
     # only support uint8
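A quick numpy check of the uchar fast path (my own sketch, not part of the diff): for non-negative inputs, truncating after adding 0.4999999 matches round() almost everywhere; exactly at .5 it rounds down where rint rounds half to even, which is why the comment above only claims it works in practice.

    import numpy as np
    v = np.array([0.2, 0.5, 1.5, 2.5, 254.6])
    print(np.floor(v + 0.4999999).astype(np.uint8))  # [0 0 1 2 255] offset then truncate, like .int() on non-negatives
    print(np.rint(v).astype(np.uint8))               # [0 0 2 2 255] round half to even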
@@ -1,6 +1,7 @@
 import pickle, sys
 from dataclasses import replace
-from tinygrad import Device
+from tinygrad import Device, Context
+from tinygrad.device import Buffer
 from tinygrad.helpers import getenv
 from tinygrad.engine.jit import TinyJit
 from tinygrad.engine.realize import CompiledRunner
@@ -8,10 +9,11 @@ from tinygrad.renderer import ProgramSpec
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps

 if __name__ == "__main__":
-  with open(sys.argv[1], "rb") as f:
-    fxn: TinyJit = pickle.load(f)
-    print(f"{f.tell()/1e6:.2f}M loaded")
-  print(type(fxn))
+  with Context(DEBUG=0):
+    with open(sys.argv[1], "rb") as f:
+      fxn: TinyJit = pickle.load(f)
+      print(f"{f.tell()/1e6:.2f}M loaded")
+    print(type(fxn))

   knum = 1
   for ei in fxn.captured.jit_cache:
@@ -21,15 +23,33 @@ if __name__ == "__main__":
     p: ProgramSpec = ei.prg.p
     k = Kernel(p.ast, Device["DSP"].renderer)
     if not getenv("NOOPT"):
-      if knum == 2:
+      if knum in [6,7,9,11]:
+        k.apply_opt(Opt(OptOps.PADTO, 1, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
+      elif knum in [5,8]:
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
+        k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
+        k.apply_opt(Opt(OptOps.PADTO, 2, 128))
+        k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
+      elif knum == 2:
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
         k.apply_opt(Opt(OptOps.PADTO, 2, 128))
         k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
         #k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
       elif knum == 1:
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0))
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
         #k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
         k.apply_opt(Opt(OptOps.PADTO, 2, 128))
         k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
       elif knum == 3:
         k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
         k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
       else:
         k.hand_coded_optimizations()
+    #if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
     p2 = k.to_program()
-    new_ei = replace(ei, prg=CompiledRunner(p2))
+    new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 1024+b.size*2, b.dtype).view(b.size, b.dtype, 512) for b in ei.bufs])
     new_ei.run()
     knum += 1
@@ -25,23 +25,29 @@ class TestArange(unittest.TestCase):
     return p.estimates.ops

   def test_complexity(self, opts=None, limit=None):
-    # add 1 to avoid divide by 0. arange is 0 flops now!
-    f1 = self._get_flops(256, opts) + 1
-    f2 = self._get_flops(2560, opts) + 1
+    f1 = self._get_flops(256, opts)
+    f2 = self._get_flops(2560, opts)
     print(f"{f1=}, {f2=}")
-    assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
+    # add 1 to avoid divide by 0. arange is 0 flops now!
+    assert (f1 < 6000 and f2 < 6000) or ((f2+1) / (f1+1) < 16), f"bad complexity, flops {(f2+1) / (f1+1):.1f}X while inputs 10X"
+    if limit is not None and not getenv("PTX"):
+      # PTX counts index ALU in flops
+      assert f1 <= limit, f"{f1=}, {limit=}"

-  def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
-  def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
-  def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
-  def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
-  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
+  def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
+  def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
+  def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
+  def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0)
+  def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0)

-  @unittest.skip("doesn't work yet")
-  def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)])
   if Device.default.renderer.has_local:
     # TODO: fix limit
     def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
     def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)

+    def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
+  @unittest.skip("doesn't work yet")
+  def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])

   def test_all_opts(self, opts=None, exclude=None):
     k = Kernel(Tensor.arange(256).schedule()[-1].ast)
@@ -820,6 +820,7 @@ class TestAutoCastType(unittest.TestCase):
     np.testing.assert_allclose(t.grad.numpy(), [1, 0])

   @unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
+  @unittest.skipIf(CI and Device.DEFAULT == "AMD", "very slow")
   @unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
   @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
   def test_mean_half_precision_underflow(self):
@@ -2,8 +2,9 @@ import numpy as np
 import unittest
 from dataclasses import replace
 from tinygrad import Tensor, Context, Device, dtypes
+from tinygrad.ops import Ops
 from tinygrad.codegen.kernel import Kernel, Opt, OptOps
-from tinygrad.engine.realize import CompiledRunner, ExecItem
+from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item

 N = 512
@@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
   ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
   for _ in range(run_count): ei.run(wait=True)

+def get_quantized_model(sz):
+  from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
+  class FakeDataReader(CalibrationDataReader):
+    def __init__(self): self.cnt = 0
+    def get_next(self) -> dict:
+      self.cnt += 1
+      if self.cnt == 100: return None
+      return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
+  out_file = "/tmp/test_out.onnx"
+  quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
+                  FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
+                  activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
+                  extra_options={"ActivationSymmetric": False})
+  return out_file
+
+@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
+class TestQuantizeOnnxCPU(unittest.TestCase):
+  def test_quant_128(self, sz=128):
+    try:
+      import onnx
+    except ImportError:
+      raise unittest.SkipTest()
+    from extra.onnx import OnnxRunner
+    out_file = get_quantized_model(sz)
+    onnx_model = onnx.load(out_file)
+    run_onnx = OnnxRunner(onnx_model)
+    inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
+    with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
+      sched = run_onnx({"input":inp})["output"].schedule()
+      ei = lower_schedule_item(sched[-2])
+    daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC]
+    assert all(u.dtype.scalar() is dtypes.int for u in daccs)
+
 @unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
 class TestQuantizeOnnx(unittest.TestCase):
   def test_quant_128(self): self.test_quant(128)
   def test_quant(self, sz=512):
-    from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
     from examples.benchmark_onnx import load_onnx_model
-    class FakeDataReader(CalibrationDataReader):
-      def __init__(self): self.cnt = 0
-      def get_next(self) -> dict:
-        self.cnt += 1
-        if self.cnt == 100: return None
-        return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
-    out_file = "/tmp/test_out.onnx"
-    # divide is ~1500-2000 without reduce_range, 750-900 with it
-    quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
-                    FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
-                    activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
-                    extra_options={"ActivationSymmetric": False})
+    out_file = get_quantized_model(sz)
     run_onnx_jit, _ = load_onnx_model(out_file)
    with Context(DONT_REALIZE_EXPAND=1):
       run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))
test/test_rewrite_tracked_childen.py (new file, 51 lines)
@@ -0,0 +1,51 @@
+import unittest
+from tinygrad import Tensor
+from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
+
+class TestRewriteTrackedChildren(unittest.TestCase):
+  def test_children_in_context(self):
+    def print_children(ctx:RewriteContext, sink:UOp):
+      view_w_child = sink.src[0].src[0].src[0]
+      assert view_w_child.op is Ops.VIEW
+      assert set([x.arg for x in ctx.children[view_w_child]]) == set((2,3))
+      ctx.update_children()
+      assert set([x.arg for x in ctx.children[view_w_child]]) == set((3,4))
+      # this is the 3
+      assert len(ctx.children[sink.src[0].src[1]]) == 1
+      assert next(iter(ctx.children[sink.src[0].src[1]])).op is Ops.ADD
+      # this is the 4
+      assert len(ctx.children[sink.src[0].src[0]]) == 1
+      assert next(iter(ctx.children[sink.src[0].src[0]])).op is Ops.ADD
+    rewrite = PatternMatcher([
+      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
+      (UPat(Ops.SINK, name="sink"), print_children)
+    ])
+    a = Tensor(2)
+    b = Tensor(3)
+    c = a + b
+    sink = c.lazydata.sink()
+    sink = graph_rewrite(sink, rewrite, track_children=True)
+
+  def test_simple_child(self):
+    rewrite = PatternMatcher([
+      (UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
+    ])
+    a = Tensor(2)
+    b = Tensor(3)
+    c = a + b
+    sink = c.lazydata
+    view_w_child = a.lazydata.src[0]
+    print([x().arg for x in view_w_child.children])
+    print([x.arg for x in sink.get_children_map()[view_w_child]])
+    self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3)))
+    # children can either be added to or removed from the map with graph_rewrite
+    # added to is easy to detect, just hook the UOp constructor
+    # when are children removed?
+    # * if a rewrite rule returns a UOp, the matched node is removed from the graph
+    sink = graph_rewrite(sink, rewrite)
+    print([x().arg for x in view_w_child.children])
+    print([x.arg for x in sink.get_children_map()[view_w_child]])
+    self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
+
+if __name__ == '__main__':
+  unittest.main()
@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
 from tinygrad.codegen.symbolic import symbolic_simple
 from tinygrad.spec import type_verify, shape_spec
-from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
+from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
 from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -98,6 +98,12 @@ class TestSchedule(unittest.TestCase):
     a.realize()
     assert not a.lazydata.is_realized

+  def test_simplify_padded_const(self):
+    a = Tensor.empty(1022).cummax(axis=0)
+    sched = check_schedule(a, 5)
+    ast = sched[0].ast
+    self.assertLessEqual(len([u for u in ast.toposort if u.op is Ops.WHERE]), 6)
+
   def test_basic_binop_fusion(self):
     a = Tensor.empty(10)
     b = Tensor.empty(10)
@@ -1851,44 +1857,31 @@ class TestIndexing(unittest.TestCase):
   def test_recursive_swizzle(self):
     a = Tensor([1,2,3,4]).realize()
     for _ in range(24): a = a + a
-    ast = a.schedule()[0].ast
-    swizzle = ast.src[0].src[2].reshape((4, 1))
-    new_uop = swizzle_rewrite(swizzle)
+    new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
     self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
     self.assertEqual(swizzle_cnt(new_uop), 0)

   def test_no_rewrite_elementwise(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
-    ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
-    rsink = graph_rewrite(sink, view_right)
-    self.assertEqual(rsink.key, sink.key)
+    a = Tensor.empty(32, 32)
+    b = Tensor.empty(32, 32)
+    sink = (a+b).schedule()[0].ast
+    self.assertEqual(swizzle_cnt(sink), 0)

   def test_simple_store_reshape(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
-    r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
-    r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
-    rsink = graph_rewrite(sink, view_right)
-    # this AST first needs to swizzle, but it doesn't have implicit movementops
-    self.assertEqual(swizzle_cnt(sink), 1)
-    verify_ast(rsink)
+    a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
+    ast = a.schedule()[0].ast
+    self.assertEqual(ast.shape, (32, 1))
+    self.assertEqual(a.lazydata.shape, (1, 32))

   def test_no_reshape_reduceop(self):
-    bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
-    ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
-    r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
-    sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
-    rsink = graph_rewrite(sink, view_right)
-    verify_ast(sink)
-    self.assertEqual(sink.key, rsink.key)
+    a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
+    ast = a.schedule()[0].ast
+    self.assertEqual(ast.shape, (32, 1))
+    self.assertEqual(a.lazydata.shape, (32,))

 @track_rewrites(named=True)
 def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
-def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
+def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER])

 class TestSwizzle(unittest.TestCase):
   def test_swizzle_simple(self):
@@ -1965,6 +1958,16 @@ class TestSwizzle(unittest.TestCase):
     t = a_reduce+b_reduce
     with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))

+  def test_unsafe_pad(self):
+    x = Tensor.full((2,2), 1.0).contiguous()
+    y = x*x.sum((1,)).reciprocal()
+    t = y.pad(((0,1),None)).contiguous()
+    swizzled = swizzle_rewrite(t.lazydata)
+    sched = check_schedule(swizzled.sink(), 3)
+    output_buffer = sched[-1].bufs[0]
+    run_schedule(sched)
+    self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])
+
 def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
 zero_pm = UPat(Ops.CONST, arg=0)
 class TestView(unittest.TestCase):
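The expected buffer in test_unsafe_pad can be checked by hand (my own numpy sketch): each row of ones is divided by its row sum of 2, and the padded third row must stay exactly zero instead of being produced by the unsafe reciprocal.

    import numpy as np
    x = np.ones((2, 2), dtype=np.float32)
    y = x * (1.0 / x.sum(1, keepdims=True))  # every element becomes 0.5
    t = np.pad(y, ((0, 1), (0, 0)))          # one extra row, all zeros
    print(t.ravel().tolist())                # [0.5, 0.5, 0.5, 0.5, 0.0, 0.0]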
@@ -12,6 +12,7 @@ from tinygrad.renderer import Renderer
 # ***** load/store grouping *****

 def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
+  if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
   # first, extract all the relevant offsets
   offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
   for i in range(vec.dtype.count):
@@ -45,7 +46,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
     global_offset += len(grp)
   assert None not in idxs, f"some idxs are missing {idxs}"
   # this base thing is for image, we want the CAT to be a normal pointer
-  return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
+  post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
+  return post_cat.gep(tuple(cast(list[int], idxs)))

 def cat_after_store(cat:UOp, data:UOp):
   # TODO: this is written in many places
@@ -73,11 +75,11 @@ load_store_folding = PatternMatcher([
    lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
   # GEP on data of STORE
   (UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
-  # put CAT after LOAD
-  (UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
+  # put PTRCAT after LOAD
+  (UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
    lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
-  # put CAT after STORE
-  (UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
+  # put PTRCAT after STORE
+  (UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
 ])

 # ***** image load valid simplification *****
@@ -143,7 +145,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   if (sz:=ls.src[0].dtype.count) == 1: return None
   lengths = []
   buf = idx.src[0]
-  if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
+  must_divide = True
+  if ctx is not None and ctx.device == "DSP":
+    lengths = [128,64,32,16,8,4]
+    must_divide = False
+  elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
     pass
   elif isinstance(buf.dtype, ImageDType):
     lengths = [4]
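On the DSP the divisibility requirement is dropped and a ladder of fold lengths is supplied instead, so the loop below covers the vector width greedily; a standalone sketch of the idea (mine, not the exact kernel code):

    def greedy_split(sz, lengths=(128, 64, 32, 16, 8, 4)):
      out, off = [], 0
      while off < sz:
        fl = next((l for l in lengths if off + l <= sz), 1)  # largest length that still fits
        out.append((off, fl)); off += fl
      return out
    print(greedy_split(200))  # [(0, 128), (128, 64), (192, 8)]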
@@ -158,7 +164,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
   for fold_length in lengths:
     if global_offset+fold_length > sz: continue
     oidx = idx.src[1] + global_offset
-    if oidx.simplify().divides(fold_length) is None: continue
+    if must_divide and oidx.simplify().divides(fold_length) is None: continue
     lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
     if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
     if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
@@ -102,9 +102,6 @@ class Kernel:
   @property
   def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])

-  # TODO: these need more tests or it might silently be no-op
-  def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
-
   def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
     upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
     assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
@@ -461,7 +458,8 @@ class Kernel:

     if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
       # are we grouping? (requires local shape support)
-      if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
+      if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
+          self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
         # TODO: use 1024 if it's allowed in a smarter way
         for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
           if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
@@ -503,10 +501,12 @@ class Kernel:
       for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))

     # potentially do more upcasts of non reduce axes based on a heuristic
+    is_dsp = self.opts is not None and self.opts.device == "DSP"
     upcasted_axis: set[int] = set()
     while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
       xb_choices = []
-      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
+      # consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
+      for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
         # if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
         if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
           xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501
@@ -2,11 +2,12 @@
 import functools, itertools, operator, math
 from dataclasses import dataclass
 from typing import cast
-from tinygrad.dtype import dtypes, PtrDType
-from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
+from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
+from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp
 from tinygrad.renderer import Renderer
-from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
+from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
 from tinygrad.codegen.expander import expand_rewrite
 from tinygrad.codegen.symbolic import symbolic

 # returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
 def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -156,9 +157,65 @@ pm_lowerer = PatternMatcher([
   # rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
   (UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
   (UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
+  (UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
 ])

+# **** this is the "quantization preprocessor", it makes ONNX quantized models, and probably also others, actually use ints ****
+
+def view_to_mask(x:UOp):
+  from tinygrad.shape.shapetracker import ShapeTracker, View
+  st = cast(ShapeTracker, x.st)
+  if len(st.views) > 1: return None
+  if st.views[-1].mask is None: return None
+  return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),))
+
+FP = (1 << 16)
+pm_quant = symbolic+PatternMatcher([
+  # cast after add/mul
+  (UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+  (UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
+   lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
+  # MUL after reduce
+  (UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
+  # CAST after reduce (doesn't work if it's a size change)
+  (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
+   lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
+  # x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
+  (UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
+   lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
+  # mul 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
+  # mul (with plus) 0 * c1 is 0
+  (UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
+   (UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
+    UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
+   lambda ld,v,c1: ld*c1),
+  # fixed point mult, replace (x.float()*c1+c2).int() with an int expression
+  ((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int),
+   lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
+  # where move
+  (UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
+   (yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
+  ((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
+  (UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
+   (x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
+  ((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
+    UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
+   x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
+  # don't care
+  (UPat(Ops.STORE, name="x"), lambda x:
+   x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
+   if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
+  (UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
+   lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
+  (UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
+  (UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
+])
+
 def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
+  if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
   sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
   # expand_rewrite turns this into a vectorized program
   return expand_rewrite(sink)
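The fixed-point rule above replaces a float multiply-add followed by an int cast with pure integer arithmetic; a standalone numeric check (my own sketch) with FP = 1 << 16:

    FP = 1 << 16
    x, c1, c2 = 37, 0.0123, 4.5                      # x plays the int-typed UOp, c1/c2 the float constants
    print(int(x * c1 + c2))                          # 4: the float path
    print((x * int(c1 * FP) + int(c2 * FP)) // FP)   # 4: the integer path, accurate to ~1/65536 of the scale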
@@ -113,6 +113,8 @@ sym = symbolic_simple+PatternMatcher([

 # **** UOp realization

+DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
+
 @dataclass(frozen=True)
 class GrouperContext:
   assigns: dict[UOp, UOp]  # maps realized buffers to assigns
@@ -133,11 +135,11 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:

 do_realize = PatternMatcher([
   # always realize SINK parents
-  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
+  (UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)),
   # always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
   (UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
   # realize before expand or unsafe pad ops
-  (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
+  (UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),)), realize_before_view),
   # realize before COPY
   (UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
 ])
@@ -221,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
     if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
   return ctx.realizes

-# break the SINK into kernels
+# **** create kernels

 @dataclass(frozen=True)
 class Kernel:
@@ -241,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
   return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)

 DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
+
 def append_to_kernel(ctx:KernelContext, x:UOp):
   new_srcs: list[UOp] = []
   metadata = dict.fromkeys(x.arg.metadata)
@@ -266,32 +269,7 @@ create_kernels = merge_views+PatternMatcher([
   (UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
 ])

-DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
-
-# **** fix kernel AST
-
-# ** create buffer ops + enumerate buffers
-
-add_buffer_ops = PatternMatcher([
-  # LOAD
-  (UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
-  # STORE (except for COPY/BUFFER_VIEW)
-  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
-  # partial assign can store to a non-contiguous ShapeTracker
-  (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
-   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
-  # otherwise the store is contiguous
-  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
-   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
-  # if the last child is a VIEW we merge the ShapeTrackers and store the base
-  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
-   lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
-  # remove CONTIGUOUS/DEVICE from kernel AST
-  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
-  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
-])
-
-# ** push views to buffer ops
+# **** swizzler

 def apply_swizzle(u:UOp) -> UOp:
   with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp):
   assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
   return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))

-def elementwise_view_right(root:UOp) -> UOp|None:
+def elementwise_view_right(root:UOp):
   if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
   assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
   # place view after applying the elementwise op
@@ -323,7 +301,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
   # reshape to match downstream shapes
   return root.replace(src=tuple(new_src)).reshape(root.shape)

-def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
+def merge_double_reduce(root:UOp, first_reduce:UOp):
   assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
   assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
   return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
@@ -340,21 +318,42 @@ view_right = merge_views+PatternMatcher([
   (UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
 ])

-# ** unbind variables
+# **** unbind variables

-def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
+def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
   st = unwrap(x.st).simplify()
   if any(x.op is Ops.BIND for x in st.vars()):
     st, var_vals = st.unbind()
-    ctx.update(var_vals)
-  return st.to_uop() if st != x.st else None
+    ctx[0].update(var_vals)
+  return x.replace(arg=st) if st != x.st else None

 def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
   ctx[var.replace(src=())] = val.arg
   return var
 unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])

-# ** fix_kernel_ops
+# **** fix kernel AST

+add_buffer_ops = PatternMatcher([
+  # LOAD
+  (UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
+  # STORE (except for COPY/BUFFER_VIEW)
+  (UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
+  # partial assign can store to a non-contiguous ShapeTracker
+  (UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
+  # otherwise the store is contiguous
+  (UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
+   lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
+  # VALID
+  (UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
+  # if the last child is a VIEW we merge the ShapeTrackers and store the base
+  (UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
+   lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
+  # remove CONTIGUOUS/DEVICE from kernel AST
+  (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
+  (UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
+])
+
 def check_load_st(glbl:UOp, view:UOp):
   if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
@@ -369,8 +368,6 @@ def check_load_st(glbl:UOp, view:UOp):
 fix_kernel_ops = PatternMatcher([
   # BIND in shapetracker becomes DEFINE_VAR
   (UPat(Ops.VIEW, name="x"), unbind_shapetracker),
   # remove unmasked valid
   (UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
   # no ImageDType after load
   (UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
   # if this kernel also assigns to the loaded buffer, ensure we can index it correctly
@@ -387,13 +384,11 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
   ast = k.arg.ast.substitute(parents_rep)
   # unbind_vars + push views to edges
   ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
-  # add buffer ops
-  ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
+  # add buffer ops + fix_kernel_ops
+  ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
   if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
   # create subbuffer (TODO: this does not belong here)
   if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
-  # fix_kernel_ops
-  ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
   return k.replace(arg=Kernel(ast, k.arg.metadata))

 PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}
@@ -113,6 +113,8 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
 PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
 CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
 DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
+QUANTIZE = ContextVar("QUANTIZE", 0)

 @dataclass(frozen=True)
 class Metadata:
@@ -117,7 +117,7 @@ class Ops(FastEnum):
   REDUCE_AXIS = auto()

   # helper ops
-  GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
+  GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702

   # UnaryOps
   CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
@@ -154,6 +154,7 @@ class Ops(FastEnum):

   # CUSTOMI is inline
   CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
+  IGNORE = auto()

 class GroupOp:
   Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -280,6 +281,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
       return nodes
     return _toposort(self, cache=set())

+  # returns map of UOps to their children in the graph rooted by self
+  def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
+    ret: dict[UOp, dict[UOp, None]] = {}
+    for u in self.toposort:
+      for s in u.src: ret.setdefault(s, {})[u] = None
+    return ret
+
   @functools.cached_property
   def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
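A minimal sketch of what get_children_map returns (assumes tinygrad is importable): each UOp in the graph maps to the UOps that use it as a source, which is what the new track_children rewrites consume.

    from tinygrad import Tensor
    sink = (Tensor(2) + Tensor(3)).lazydata.sink()
    for parent, kids in sink.get_children_map().items():
      print(parent.op, "->", [k.op for k in kids])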
@@ -895,10 +903,23 @@ def launch_viz(env_str:str, data:str):
 # *** simple graph rewrite engine ***

 class RewriteContext:
-  def __init__(self, pm, ctx=None):
+  def __init__(self, pm, ctx=None, children=None):
     self.pm: PatternMatcher = pm
-    self.ctx = ctx
+    self.ctx = self if children is not None else ctx
     self.replace: dict[UOp, UOp] = {}
+    self.children = children
+  # TODO: is this function always right?
+  def update_children(self):
+    # add any new children from UOps that were replaced
+    for u in self.replace.values():
+      for s in u.src: self.children.setdefault(s, {})[u] = None
+    # find any children that were replaced and replace them
+    for k,v in self.children.items():
+      new_child: dict[UOp, None] = {}
+      for tv in v:
+        while (nv:=self.replace.get(tv, None)) is not None and nv is not tv: tv = nv
+        new_child[tv] = None
+      self.children[k] = new_child
   def top_down_rewrite(self, n:UOp) -> UOp:
     if (rn := self.replace.get(n)) is not None: return rn
     new_src = tuple([self.top_down_rewrite(x) for x in n.src])
@@ -913,15 +934,16 @@ class RewriteContext:
     self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
     return ret

-def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
+def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> UOp:
   if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
     tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
-  return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
+  rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
+  return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)

-def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
+def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]:
   if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
     tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
-  rewrite_ctx = RewriteContext(pm, ctx)
+  rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
   return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}

 def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
@@ -958,6 +980,9 @@ merge_views = PatternMatcher([
   # merge unmasked const views
   (UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
    lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
+  # merge view on load/store/valid
+  (UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
+   lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
   # remove view if it's a contiguous and the shapes match
   (UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
   # remove mask if there's a zero in the masked dim
@@ -967,13 +992,11 @@ merge_views = PatternMatcher([
   (UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
 ])

-# push VIEW to parents
 view_left = merge_views+PatternMatcher([
   # VIEW(CONST) becomes VALID
   (UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
-  # VIEW before elementwise/buffer ops
+  # do not push masked view before unsafe pad ops
+  (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
+   lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
+  # view before elementwise ops
   (UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
    lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
   (UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
    lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
 ])
@@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage):
   if sys.platform == 'win32':
     kernel_prefix = "__attribute__((ms_abi)) "
   def render_vector_prefix(self, dt:DType) -> str:
-    # round (down) to power of two
-    alignment = 2**int(math.log2(dt.itemsize))
+    # round (down) to power of two (this is actually the default clang behavior)
+    alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
     return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"

   def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
@@ -397,11 +397,18 @@ def cast_float_to_bf16(x: UOp) -> UOp:
 class AMDRenderer(CStyleLanguage):
   device = "AMD"
   shared_max = 65536
+  # NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
+  global_max = (2147483647, 65535, 65535)
   # https://gpuopen.com/learn/wmma_on_rdna3/
   tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
     opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
     for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]

+  def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900
+    # TODO: fix tensor cores for gfx1201
+    self.tensor_cores, self.arch = AMDRenderer.tensor_cores if arch != "gfx1201" else [], arch
+  def __reduce__(self): return self.__class__, (self.arch,)
+
   # language options
   ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
   ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)
@@ -685,7 +685,8 @@ class AMDDevice(HCQCompiled):
     self.dev_iface = PCIIface(self, self.device_id) if AMDDevice.driverless else KFDIface(self, self.device_id)
     self.target = int(self.dev_iface.props['gfx_target_version'])
     self.arch = "gfx%d%x%x" % (self.target // 10000, (self.target // 100) % 100, self.target % 100)
-    if self.target < 100300 or self.target >= 120000: raise RuntimeError(f"Unsupported arch: {self.arch}")
+    if self.target < 100300 or self.target >= 130000: raise RuntimeError(f"Unsupported arch: {self.arch}")
     if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}")

     self.max_cu_id = self.dev_iface.props['simd_count'] // self.dev_iface.props['simd_per_cu'] - 1
     self.max_wave_id = self.dev_iface.props['max_waves_per_simd'] * self.dev_iface.props['simd_per_cu'] - 1
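For reference, "gfx%d%x%x" unpacks gfx_target_version as major*10000 + minor*100 + step, rendering minor and step in hex; a quick check (my own sketch) shows the widened bound now accepts everything below gfx13xx:

    for target in [100301, 110002, 120000]:
      print("gfx%d%x%x" % (target // 10000, (target // 100) % 100, target % 100))
    # gfx1031, gfx1102, gfx1200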
@@ -704,7 +705,7 @@ class AMDDevice(HCQCompiled):

     self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000)

-    super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
+    super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
                      AMDSignal, AMDComputeQueue, AMDCopyQueue)

     # Scratch setup
@@ -16,17 +16,22 @@ dsp_pm = PatternMatcher([
|
||||
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
|
||||
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
|
||||
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
|
||||
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 else None),
|
||||
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None),
|
||||
])
|
||||
|
||||
dsp_pm_late = PatternMatcher([
(UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
])

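The only change to the three binop rules is the x.op is not Ops.CUSTOMI guard: returning None is how a PatternMatcher lambda declines a match, so a rule that could otherwise keep matching its own rewritten output (plausibly when both operands of a commutative op are VECTORIZEs) now reaches a fixed point. A toy model of the decline-by-None convention, all names invented:

def apply_until_stable(expr, rule, limit=8):
  for _ in range(limit):
    out = rule(expr)
    if out is None: return expr  # rule declined: fixed point reached
    expr = out
  raise RuntimeError("rewrite did not converge")

guarded = lambda e: None if e.startswith("CUSTOMI") else f"CUSTOMI({e})"
print(apply_until_stable("ADD(x, VECTORIZE(y))", guarded))  # one rewrite, then the guard declines
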
# NOTE: this just increases readability of the generated code
dsp_string = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
])

class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
@@ -34,6 +39,7 @@ class DSPRenderer(ClangRenderer):
kernel_prefix = "__attribute__((noinline)) "
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
string_rewrite = dsp_string+ClangRenderer.string_rewrite
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",

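code_for_op entries are plain string builders keyed on Ops, dispatching on dtype where the backend needs it; the LOG2 entry above picks the long-double builtin only for float64. A self-contained re-creation with stand-in dtype strings (not tinygrad's dtypes objects):

log2 = lambda x, dtype: f"__builtin_log2l({x})" if dtype == "float64" else f"__builtin_log2f({x})"
assert log2("a", "float64") == "__builtin_log2l(a)"
assert log2("a", "float32") == "__builtin_log2f(a)"
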
@@ -14,7 +14,7 @@ class HIPDevice(Compiled):
self.device_id = int(device.split(":")[1]) if ":" in device else 0
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
super().__init__(device, HIPAllocator(self), HIPRenderer(), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
def synchronize(self):
check(hip.hipSetDevice(self.device_id))
check(hip.hipDeviceSynchronize())

@@ -1,5 +1,5 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0
from tinygrad.runtime.support.allocator import TLSFAllocator
@@ -391,10 +391,10 @@ class AMDev:
gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)

def _ip_module(self, prefix:str, hwip):
def _ip_module(self, prefix:str, hwip, prever_prefix:str=""):
version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{prefix}_{prever_prefix}{ver[0]}_{ver[1]}_{ver[2]}")
except ImportError: pass
raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")

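Besides swapping __import__ for the cleaner importlib.import_module, the hunk threads through prever_prefix, which lets SMU modules named like smu_v13_0_0 reuse the same loader. The fallback order it walks can be sketched standalone (module names here are purely illustrative):

def candidates(prefix, ip_version, prever_prefix=""):
  ver = [ip_version // 10000, (ip_version // 100) % 100, ip_version % 100]
  for v in [ver, ver[:2] + [0], ver[:1] + [0, 0]]:
    yield f"tinygrad.runtime.autogen.am.{prefix}_{prever_prefix}{v[0]}_{v[1]}_{v[2]}"

print(list(candidates("smu", 130005, prever_prefix="v")))
# exact version first, then step zeroed, then minor and step zeroed:
# [...smu_v13_0_5, ...smu_v13_0_0, ...smu_v13_0_0]
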
@@ -1,6 +1,6 @@
import ctypes, time, contextlib
from typing import Literal
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
from tinygrad.runtime.autogen.am import am
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG

class AM_IP:
@@ -61,7 +61,14 @@ class AM_GMC(AM_IP):
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1,
dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1,
range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1,
valid_protection_fault_enable_interrupt=1, valid_protection_fault_enable_default=1,
read_protection_fault_enable_interrupt=1, read_protection_fault_enable_default=1,
write_protection_fault_enable_interrupt=1, write_protection_fault_enable_default=1,
execute_protection_fault_enable_interrupt=1, execute_protection_fault_enable_default=1,
enable_context=1, page_table_depth=(3 - page_table.lv))

def init_hub(self, ip:Literal["MM", "GC"]):
# Init system apertures
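
The rewritten CNTL write replaces the opaque 0x1fffe00 constant with the 0x1800000 base plus every fault-enable bit spelled out by name; the register wrapper ORs each named field, shifted to its bit offset, into the positional raw value. A toy model of that write(raw, **fields) pattern, with invented offsets:

FIELD_OFFSETS = {"enable_context": 0, "page_table_depth": 1}  # invented bit positions
def build(raw=0, **fields):
  for name, val in fields.items(): raw |= val << FIELD_OFFSETS[name]
  return raw

print(hex(build(0x1800000, enable_context=1, page_table_depth=3)))  # 0x1800007
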
@@ -106,37 +113,38 @@ class AM_SMU(AM_IP):
class AM_SMU(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.smu_mod = self.adev._ip_module("smu", am.MP1_HWIP, prever_prefix='v')
self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)

def init(self):
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)

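SMU mailbox messages carry 32-bit payloads, so the 64-bit MC address of the driver table goes over as two messages. hi32/lo32 come from tinygrad.helpers; a minimal stand-in with the same split:

hi32 = lambda x: (x >> 32) & 0xffffffff
lo32 = lambda x: x & 0xffffffff
addr = 0x12_3456_7890  # example address
assert (hi32(addr) << 32) | lo32(addr) == addr
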
def is_smu_alive(self):
with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
with contextlib.suppress(RuntimeError): self._send_msg(self.smu_mod.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0

def mode1_reset(self):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms

def read_table(self, table_t, cmd):
self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)
def read_metrics(self): return self.read_table(self.smu_mod.SmuMetricsExternal_t, self.smu_mod.TABLE_SMU_METRICS)

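read_table maps the ctypes struct directly over the driver-table memory with from_buffer, so field reads alias that buffer rather than copying it. A self-contained sketch with an invented two-field layout:

import ctypes

class FakeMetrics(ctypes.Structure):  # invented layout, not SmuMetricsExternal_t
  _fields_ = [("temperature", ctypes.c_uint32), ("gfx_clock_mhz", ctypes.c_uint32)]

buf = bytearray((61).to_bytes(4, "little") + (2100).to_bytes(4, "little"))
m = FakeMetrics.from_buffer(buf)  # no copy: m aliases buf, like the driver table above
print(m.temperature, m.gfx_clock_mhz)  # 61 2100
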
def set_clocks(self, level):
if not hasattr(self, 'clcks'):
self.clcks = {}
for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]:
cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]:
cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]

for clck, vals in self.clcks.items():
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)

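GetDpmFreqByIndex packs the clock id into the high half of its argument and the level index into the low half, with 0xff meaning "return the level count"; the &0x7fffffff appears to strip a flag bit from the reply. A packing sketch with a placeholder clock id:

PPCLK_EXAMPLE = 2  # placeholder id, not a real smu constant
count_query = (PPCLK_EXAMPLE << 16) | 0xff
level_query = (PPCLK_EXAMPLE << 16) | 3
print(hex(count_query), hex(level_query))  # 0x200ff 0x20003
reply = 0x80000834  # example reply with the top bit set
print(hex(reply & 0x7fffffff))  # 0x834 == 2100, e.g. a frequency in MHz
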
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
@@ -206,7 +214,7 @@ class AM_GFX(AM_IP):
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=1, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
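
The queue_size field in that cp_hqd_pq_control build is log2-encoded: for a power-of-two ring, (ring_size//4).bit_length()-2 works out to log2(size in dwords) - 1. With an example 0x800000-byte ring:

ring_size = 0x800000  # example ring size in bytes
dwords = ring_size // 4  # 2**21 dwords
print(dwords.bit_length() - 2)  # 20 == log2(2**21) - 1
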
@@ -290,7 +298,7 @@ class AM_IH(AM_IP):
self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)

self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(((am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2), enable=1)
self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(offset=(am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2, enable=1)

self.adev.regIH_STORM_CLIENT_LIST_CNTL.update(client18_is_storm_client=1)
self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1)