Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-03-17 08:24:35 -07:00
23 changed files with 409 additions and 164 deletions

View File

@@ -423,6 +423,8 @@ jobs:
run: LLVM=1 python -m pytest -n=auto test/external/external_test_onnx_backend.py --durations=20
- name: Test Additional ONNX Ops (CPU)
run: CPU=1 PYTHONPATH=. python3 test/external/external_test_onnx_ops.py
- name: Test Quantize ONNX
run: CPU=1 PYTHONPATH=. python3 test/test_quantize_onnx.py
- name: Run CLOUD=1 Test
run: |
CLOUDDEV=CPU CLOUD=1 python3 test/test_tiny.py
@@ -467,7 +469,7 @@ jobs:
testdsp:
name: Linux (DSP)
runs-on: ubuntu-24.04
timeout-minutes: 10
timeout-minutes: 15
steps:
- name: Checkout Code
uses: actions/checkout@v4

View File

@@ -928,9 +928,9 @@ def train_bert():
# ** hyperparameters **
BS = config["GLOBAL_BATCH_SIZE"] = getenv("BS", 11 * len(GPUS) if dtypes.default_float in (dtypes.float16, dtypes.bfloat16) else 8 * len(GPUS))
EVAL_BS = config["EVAL_BS"] = getenv("EVAL_BS", 1 * len(GPUS))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.0002 * math.sqrt(BS/96))
max_lr = config["OPT_BASE_LEARNING_RATE"] = getenv("OPT_BASE_LEARNING_RATE", 0.000175 * math.sqrt(BS/96))
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3630000 // BS)
train_steps = config["TRAIN_STEPS"] = getenv("TRAIN_STEPS", 3300000 // BS)
warmup_steps = config["NUM_WARMUP_STEPS"] = getenv("NUM_WARMUP_STEPS", 1)
max_eval_steps = config["MAX_EVAL_STEPS"] = getenv("MAX_EVAL_STEPS", (10000 + EVAL_BS - 1) // EVAL_BS) # EVAL_BS * MAX_EVAL_STEPS >= 10000
eval_step_freq = config["EVAL_STEP_FREQ"] = getenv("EVAL_STEP_FREQ", int((math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000) / BS)) # Round down
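A quick sanity check of the EVAL_STEP_FREQ round-down (assuming BS=88, i.e. the fp16/bf16 default of 11 per GPU on 8 GPUs):

import math
BS = 88  # assumed: 8 GPUs * 11 (the fp16/bf16 default above)
samples_between_evals = math.floor(0.05 * (230.23 * BS + 3000000) / 25000) * 25000  # round down to a 25k sample boundary
print(samples_between_evals, int(samples_between_evals / BS))  # 150000 1704 -> eval every 1704 steps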

View File

@@ -1,5 +1,6 @@
import numpy as np
from tinygrad.helpers import getenv
from tinygrad.dtype import _to_np_dtype
from tinygrad import dtypes, Tensor
dtype_in = dtypes.half if getenv("HALF") else dtypes.bfloat16 if getenv("BFLOAT16") else dtypes.float
@@ -13,12 +14,17 @@ K = getenv("K", N)
CNT = getenv("CNT", 10)
ATOL = getenv("ATOL", 1e-4)
RTOL = getenv("RTOL", 3e-2)
INT_LOW = getenv("INT_LOW", 0)
INT_HIGH = getenv("INT_HIGH", 10)
if __name__ == "__main__":
def init_matrix(rows, cols):
rng = np.random.default_rng()
# NOTE: numpy does not support bfloat16
if (np_dtype := _to_np_dtype(dtype_in)) is None: np_dtype = np.float32
if dtype_in in dtypes.ints:
return Tensor.randint((rows, cols), dtype=dtype_in).realize()
return Tensor.rand(rows, cols, dtype=dtype_in).realize()
return Tensor(rng.integers(INT_LOW, INT_HIGH, (rows, cols), dtype=np_dtype)).realize()
return Tensor(rng.random((rows, cols), dtype=np.float32).astype(np_dtype)).cast(dtype_in).realize()
a, b = init_matrix(M, K), init_matrix(K, N)
for i in range(CNT):
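A minimal standalone version of the bfloat16 workaround above (numpy has no bfloat16, so the data is generated as float32 and cast on the tinygrad side):

import numpy as np
from tinygrad import Tensor, dtypes
rng = np.random.default_rng()
t = Tensor(rng.random((4, 4), dtype=np.float32)).cast(dtypes.bfloat16).realize()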

extra/hip_large_kernel.py Normal file
View File

@@ -0,0 +1,25 @@
from tinygrad.device import Device, Buffer
from tinygrad.dtype import dtypes, _to_np_dtype
dev = Device.default
mbin = dev.compiler.compile("""
typedef long unsigned int size_t;
extern "C" __attribute__((device, const)) size_t __ockl_get_group_id(unsigned int);
extern "C" __attribute__((global)) void __attribute__((amdgpu_flat_work_group_size(1, 1))) write_ones(signed char* data0) {
int gidx0 = __ockl_get_group_id(0); /* 16 */
int gidx1 = __ockl_get_group_id(1); /* 1026048 */
*(data0+(gidx0+gidx1*1)) = 1;
}
""")
dev.compiler.disassemble(mbin)
buf0 = Buffer(Device.DEFAULT, 1*65537, dtypes.uint8).ensure_allocated()
prg = dev.runtime("write_ones", mbin)
prg(buf0._buf, global_size=(1,65537,1), local_size=(1,1,1), wait=True)
import numpy as np
def to_np(buf): return np.frombuffer(buf.as_buffer().cast(buf.dtype.base.fmt), dtype=_to_np_dtype(buf.dtype.base))
big = to_np(buf0)
print(big)
print((big-1).nonzero())
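If the (1,65537,1) launch reached every workgroup, big is all ones and (big-1).nonzero() prints empty index arrays; any index that survives marks a group the dispatch never ran.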

View File

@@ -728,7 +728,11 @@ def get_onnx_ops():
def QuantizeLinear(x:Tensor, y_scale:Tensor, y_zero_point:Tensor|int=0, axis:int=1, block_size:int=0, output_dtype:int=0, saturate=1):
out_dtype = y_zero_point.dtype if isinstance(y_zero_point, Tensor) else dtype_parse(output_dtype) if output_dtype else dtypes.uint8
y_scale, y_zero_point = _prepare_quantize(x, y_scale, y_zero_point, axis, block_size)
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
if out_dtype == dtypes.uchar:
# this appears to work in practice, at least for uchar out_dtype. it folds with the quantize rewrite rules
return _clamp_cast((x / y_scale + 0.4999999 + y_zero_point).int(), out_dtype).contiguous()
else:
return _clamp_cast(((x / y_scale).round() + y_zero_point), out_dtype).contiguous()
def DynamicQuantizeLinear(x: Tensor):
# only support uint8
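A quick numpy check of the +0.4999999 trick (standalone sketch): it agrees with round-half-to-even everywhere except exact .5 ties, which matches the "appears to work in practice" caveat above:

import numpy as np
v = np.array([2.3, 2.5, 2.7, 3.5])
print(np.rint(v))                        # round half to even: [2. 2. 3. 4.]
print((v + 0.4999999).astype(np.int32))  # truncating add:     [2 2 3 3] (the 3.5 tie differs)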

View File

@@ -1,6 +1,7 @@
import pickle, sys
from dataclasses import replace
from tinygrad import Device
from tinygrad import Device, Context
from tinygrad.device import Buffer
from tinygrad.helpers import getenv
from tinygrad.engine.jit import TinyJit
from tinygrad.engine.realize import CompiledRunner
@@ -8,10 +9,11 @@ from tinygrad.renderer import ProgramSpec
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
if __name__ == "__main__":
with open(sys.argv[1], "rb") as f:
fxn: TinyJit = pickle.load(f)
print(f"{f.tell()/1e6:.2f}M loaded")
print(type(fxn))
with Context(DEBUG=0):
with open(sys.argv[1], "rb") as f:
fxn: TinyJit = pickle.load(f)
print(f"{f.tell()/1e6:.2f}M loaded")
print(type(fxn))
knum = 1
for ei in fxn.captured.jit_cache:
@@ -21,15 +23,33 @@ if __name__ == "__main__":
p: ProgramSpec = ei.prg.p
k = Kernel(p.ast, Device["DSP"].renderer)
if not getenv("NOOPT"):
if knum == 2:
if knum in [6,7,9,11]:
k.apply_opt(Opt(OptOps.PADTO, 1, 128))
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
elif knum in [5,8]:
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
elif knum == 2:
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
#k.apply_opt(Opt(op=OptOps.UPCAST, axis=1, arg=4))
elif knum == 1:
k.apply_opt(Opt(op=OptOps.UNROLL, axis=2, arg=0))
k.apply_opt(Opt(op=OptOps.UNROLL, axis=1, arg=0))
#k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=0))
k.apply_opt(Opt(OptOps.PADTO, 2, 128))
k.apply_opt(Opt(OptOps.UPCAST, 2, 128))
elif knum == 3:
k.apply_opt(Opt(op=OptOps.UNROLL, axis=0, arg=4))
k.apply_opt(Opt(OptOps.UPCAST, 1, 128))
else:
k.hand_coded_optimizations()
#if knum in [5]: k.apply_opt(Opt(OptOps.UPCAST, 1, 2))
p2 = k.to_program()
new_ei = replace(ei, prg=CompiledRunner(p2))
new_ei = replace(ei, prg=CompiledRunner(p2), bufs=[Buffer("DSP", 1024+b.size*2, b.dtype).view(b.size, b.dtype, 512) for b in ei.bufs])
new_ei.run()
knum += 1

View File

@@ -25,23 +25,29 @@ class TestArange(unittest.TestCase):
return p.estimates.ops
def test_complexity(self, opts=None, limit=None):
# add 1 to avoid divide by 0. arange is 0 flops now!
f1 = self._get_flops(256, opts) + 1
f2 = self._get_flops(2560, opts) + 1
f1 = self._get_flops(256, opts)
f2 = self._get_flops(2560, opts)
print(f"{f1=}, {f2=}")
assert (f1 < 6000 and f2 < 6000) or (f2 / f1 < 16), f"bad complexity, flops {f2/f1:.1f}X while inputs 10X"
# add 1 to avoid divide by 0. arange is 0 flops now!
assert (f1 < 6000 and f2 < 6000) or ((f2+1) / (f1+1) < 16), f"bad complexity, flops {(f2+1) / (f1+1):.1f}X while inputs 10X"
if limit is not None and not getenv("PTX"):
# PTX counts index ALU in flops
assert f1 <= limit, f"{f1=}, {limit=}"
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=1)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=1)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=1)
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=1)
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=1)
def test_complexity_w_upcast(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4)], limit=0)
def test_complexity_w_unroll2(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 2)], limit=0)
def test_complexity_w_unroll4(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 4)], limit=0)
def test_complexity_w_unroll8(self): return self.test_complexity([Opt(OptOps.UNROLL, 0, 8)], limit=0)
def test_complexity_w_upcast_and_unroll(self): return self.test_complexity([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(op=OptOps.PADTO, axis=1, arg=32)])
if Device.default.renderer.has_local:
# TODO: fix limit
def test_complexity_w_group(self): return self.test_complexity([Opt(OptOps.GROUP, 0, 16)], limit=81920)
def test_complexity_w_group_top(self): return self.test_complexity([Opt(OptOps.GROUPTOP, 0, 16)], limit=106496)
def test_complexity_w_local(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16)], limit=0)
@unittest.skip("doesn't work yet")
def test_complexity_w_local_and_padto(self): return self.test_complexity([Opt(OptOps.LOCAL, 0, 16), Opt(OptOps.PADTO, axis=1, arg=32)])
def test_all_opts(self, opts=None, exclude=None):
k = Kernel(Tensor.arange(256).schedule()[-1].ast)

View File

@@ -820,6 +820,7 @@ class TestAutoCastType(unittest.TestCase):
np.testing.assert_allclose(t.grad.numpy(), [1, 0])
@unittest.skipIf(Device.DEFAULT == "PYTHON", "very slow")
@unittest.skipIf(CI and Device.DEFAULT == "AMD", "very slow")
@unittest.skipIf(Device.DEFAULT == "WEBGPU", "Binding size is larger than the maximum storage buffer binding size")
@unittest.skipUnless(is_dtype_supported(dtypes.half), "need half")
def test_mean_half_precision_underflow(self):

View File

@@ -2,8 +2,9 @@ import numpy as np
import unittest
from dataclasses import replace
from tinygrad import Tensor, Context, Device, dtypes
from tinygrad.ops import Ops
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
from tinygrad.engine.realize import CompiledRunner, ExecItem
from tinygrad.engine.realize import CompiledRunner, ExecItem, lower_schedule_item
N = 512
@@ -44,24 +45,46 @@ def sexec(out:Tensor, opts:list[Opt], replace_src=None, run_count=3):
ei = ExecItem(CompiledRunner(prg), [x.ensure_allocated() for x in si.bufs], si.metadata)
for _ in range(run_count): ei.run(wait=True)
def get_quantized_model(sz):
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
class FakeDataReader(CalibrationDataReader):
def __init__(self): self.cnt = 0
def get_next(self) -> dict:
self.cnt += 1
if self.cnt == 100: return None
return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
out_file = "/tmp/test_out.onnx"
quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": False})
return out_file
@unittest.skipIf(Device.DEFAULT != "CPU", "only tests for CPU")
class TestQuantizeOnnxCPU(unittest.TestCase):
def test_quant_128(self, sz=128):
try:
import onnx
except ImportError:
raise unittest.SkipTest()
from extra.onnx import OnnxRunner
out_file = get_quantized_model(sz)
onnx_model = onnx.load(out_file)
run_onnx = OnnxRunner(onnx_model)
inp = Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32))
with Context(DONT_REALIZE_EXPAND=1, QUANTIZE=1):
sched = run_onnx({"input":inp})["output"].schedule()
ei = lower_schedule_item(sched[-2])
daccs = [u for u in ei.prg.p.uops if u.op is Ops.DEFINE_ACC]
assert all(u.dtype.scalar() is dtypes.int for u in daccs)
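The DEFINE_ACC check is the point of the test: with QUANTIZE=1 the GEMM accumulators in the lowered program are int, i.e. the quantized matmul runs in integer arithmetic instead of dequantizing to float first.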
@unittest.skipIf(Device.DEFAULT != "DSP", "only tests for DSP")
class TestQuantizeOnnx(unittest.TestCase):
def test_quant_128(self): self.test_quant(128)
def test_quant(self, sz=512):
from onnxruntime.quantization import quantize_static, QuantFormat, QuantType, CalibrationDataReader
from examples.benchmark_onnx import load_onnx_model
class FakeDataReader(CalibrationDataReader):
def __init__(self): self.cnt = 0
def get_next(self) -> dict:
self.cnt += 1
if self.cnt == 100: return None
return {"input": np.random.uniform(size=(sz, sz)).astype(np.float32)}
out_file = "/tmp/test_out.onnx"
# divide is ~1500-2000 without reduce_range, 750-900 with it
quantize_static(create_gemm_model("/tmp/test_in.onnx", sz, sz, sz), out_file,
FakeDataReader(), quant_format=QuantFormat.QDQ, per_channel=False, reduce_range=False,
activation_type=QuantType.QUInt8, weight_type=QuantType.QInt8,
extra_options={"ActivationSymmetric": False})
out_file = get_quantized_model(sz)
run_onnx_jit, _ = load_onnx_model(out_file)
with Context(DONT_REALIZE_EXPAND=1):
run_onnx_jit(input=Tensor(np.random.uniform(size=(sz, sz)).astype(np.float32)))

View File

@@ -0,0 +1,51 @@
import unittest
from tinygrad import Tensor
from tinygrad.ops import PatternMatcher, Ops, UPat, graph_rewrite, RewriteContext, UOp
class TestRewriteTrackedChildren(unittest.TestCase):
def test_children_in_context(self):
def print_children(ctx:RewriteContext, sink:UOp):
view_w_child = sink.src[0].src[0].src[0]
assert view_w_child.op is Ops.VIEW
assert set([x.arg for x in ctx.children[view_w_child]]) == set((2,3))
ctx.update_children()
assert set([x.arg for x in ctx.children[view_w_child]]) == set((3,4))
# this is the 3
assert len(ctx.children[sink.src[0].src[1]]) == 1
assert next(iter(ctx.children[sink.src[0].src[1]])).op is Ops.ADD
# this is the 4
assert len(ctx.children[sink.src[0].src[0]]) == 1
assert next(iter(ctx.children[sink.src[0].src[0]])).op is Ops.ADD
rewrite = PatternMatcher([
(UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
(UPat(Ops.SINK, name="sink"), print_children)
])
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata.sink()
sink = graph_rewrite(sink, rewrite, track_children=True)
def test_simple_child(self):
rewrite = PatternMatcher([
(UPat(Ops.CONST, arg=2, name="x"), lambda x: x.replace(arg=4)),
])
a = Tensor(2)
b = Tensor(3)
c = a + b
sink = c.lazydata
view_w_child = a.lazydata.src[0]
print([x().arg for x in view_w_child.children])
print([x.arg for x in sink.get_children_map()[view_w_child]])
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((2,3)))
# children can either be added to or removed from the map with graph_rewrite
# added to is easy to detect, just hook the UOp constructor
# when are children removed?
# * if a rewrite rule returns a UOp, the matched node is removed from the graph
sink = graph_rewrite(sink, rewrite)
print([x().arg for x in view_w_child.children])
print([x.arg for x in sink.get_children_map()[view_w_child]])
self.assertSetEqual(set([x.arg for x in sink.get_children_map()[view_w_child]]), set((3,4)))
if __name__ == '__main__':
unittest.main()

View File

@@ -14,7 +14,7 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, merge_views, GroupOp
from tinygrad.codegen.symbolic import symbolic_simple
from tinygrad.spec import type_verify, shape_spec
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, unwrap, all_same, temp
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, sym
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -98,6 +98,12 @@ class TestSchedule(unittest.TestCase):
a.realize()
assert not a.lazydata.is_realized
def test_simplify_padded_const(self):
a = Tensor.empty(1022).cummax(axis=0)
sched = check_schedule(a, 5)
ast = sched[0].ast
self.assertLessEqual(len([u for u in ast.toposort if u.op is Ops.WHERE]), 6)
def test_basic_binop_fusion(self):
a = Tensor.empty(10)
b = Tensor.empty(10)
@@ -1851,44 +1857,31 @@ class TestIndexing(unittest.TestCase):
def test_recursive_swizzle(self):
a = Tensor([1,2,3,4]).realize()
for _ in range(24): a = a + a
ast = a.schedule()[0].ast
swizzle = ast.src[0].src[2].reshape((4, 1))
new_uop = swizzle_rewrite(swizzle)
new_uop = swizzle_rewrite(a.lazydata.reshape((4, 1)))
self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1)))
self.assertEqual(swizzle_cnt(new_uop), 0)
def test_no_rewrite_elementwise(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)]
ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((32, 32)).to_uop()))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((32, 32)).to_uop(), ld1+ld2)),))
rsink = graph_rewrite(sink, view_right)
self.assertEqual(rsink.key, sink.key)
a = Tensor.empty(32, 32)
b = Tensor.empty(32, 32)
sink = (a+b).schedule()[0].ast
self.assertEqual(swizzle_cnt(sink), 0)
def test_simple_store_reshape(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(()))
r = r + r.const_like(2).replace(src=(unwrap(r.st).to_uop(),))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
# this AST first needs to swizzle, but it doesn't have implicit movementops
self.assertEqual(swizzle_cnt(sink), 1)
verify_ast(rsink)
a = Tensor.empty(32, 32).sum(axis=1)+Tensor.empty(1,32)
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (1, 32))
def test_no_reshape_reduceop(self):
bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)]
ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop()))
r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1)))
sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),))
rsink = graph_rewrite(sink, view_right)
verify_ast(sink)
self.assertEqual(sink.key, rsink.key)
a = Tensor.empty(32, 32).sum(axis=(1,)).contiguous()
ast = a.schedule()[0].ast
self.assertEqual(ast.shape, (32, 1))
self.assertEqual(a.lazydata.shape, (32,))
@track_rewrites(named=True)
def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right)
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0])
def swizzle_cnt(u:UOp) -> int: return len([x for x in u.toposort if x.op is Ops.VIEW and len(x.src) != 0 and x.src[0].op is not Ops.BUFFER])
class TestSwizzle(unittest.TestCase):
def test_swizzle_simple(self):
@@ -1965,6 +1958,16 @@ class TestSwizzle(unittest.TestCase):
t = a_reduce+b_reduce
with Context(DONT_GROUP_REDUCES=1, DONT_REALIZE_EXPAND=1): run_schedule(check_schedule(t, 1))
def test_unsafe_pad(self):
x = Tensor.full((2,2), 1.0).contiguous()
y = x*x.sum((1,)).reciprocal()
t = y.pad(((0,1),None)).contiguous()
swizzled = swizzle_rewrite(t.lazydata)
sched = check_schedule(swizzled.sink(), 3)
output_buffer = sched[-1].bufs[0]
run_schedule(sched)
self.assertListEqual(output_buffer.as_buffer().cast("f").tolist(), [0.5, 0.5, 0.5, 0.5, 0., 0.])
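Worked through: x is 2x2 ones, each row sums to 2, so y is 0.5 everywhere; the pad adds a masked row after the (unsafe-to-pad) reciprocal, and the padded region must still come out as exact zeros, hence [0.5]*4 + [0.0]*2.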
def store_val(si:ScheduleItem): return si.ast.src[0].src[2]
zero_pm = UPat(Ops.CONST, arg=0)
class TestView(unittest.TestCase):

View File

@@ -12,6 +12,7 @@ from tinygrad.renderer import Renderer
# ***** load/store grouping *****
def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
if getenv("UNSAFE_DISABLE_MASK", 0): mask = None
# first, extract all the relevant offsets
offsets_rootsrc: defaultdict[Any, dict[int, list[int]]] = defaultdict(dict)
for i in range(vec.dtype.count):
@@ -45,7 +46,8 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
global_offset += len(grp)
assert None not in idxs, f"some idxs are missing {idxs}"
# this base thing is for image, we want the CAT to be a normal pointer
return UOp(Ops.CAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret)).gep(tuple(cast(list[int], idxs)))
post_cat = UOp(Ops.PTRCAT, ptrdtype.base.ptr(size=ptrdtype.size, local=ptrdtype.local).vec(vec.dtype.count), tuple(ret))
return post_cat.gep(tuple(cast(list[int], idxs)))
def cat_after_store(cat:UOp, data:UOp):
# TODO: this is written in many places
@@ -73,11 +75,11 @@ load_store_folding = PatternMatcher([
lambda gep, ld: ld.replace(dtype=ld.dtype.scalar().vec(gep.dtype.count), src=(gep.src[0],)+ld.src[1:]).gep(gep.arg)),
# GEP on data of STORE
(UPat(Ops.STORE, src=(UPat(Ops.GEP, name="gep"), UPat.var("st"))), gep_on_store),
# put CAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.CAT, name="cat"),), name="ld", allow_any_len=True),
# put PTRCAT after LOAD
(UPat(Ops.LOAD, src=(UPat(Ops.PTRCAT, name="cat"),), name="ld", allow_any_len=True),
lambda cat,ld: UOp(Ops.CAT, ld.dtype, tuple(ld.replace(dtype=x.dtype.base, src=(x,)+ld.src[1:]) for x in cat.src))),
# put CAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.CAT, name="cat"), UPat(name="data"))), cat_after_store),
# put PTRCAT after STORE
(UPat(Ops.STORE, src=(UPat(Ops.PTRCAT, name="cat"), UPat(name="data"))), cat_after_store),
])
# ***** image load valid simplification *****
@@ -143,7 +145,11 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
if (sz:=ls.src[0].dtype.count) == 1: return None
lengths = []
buf = idx.src[0]
if buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
must_divide = True
if ctx is not None and ctx.device == "DSP":
lengths = [128,64,32,16,8,4]
must_divide = False
elif buf.dtype.base != dtypes.float and buf.dtype.base != dtypes.half and not isinstance(buf.dtype, ImageDType):
pass
elif isinstance(buf.dtype, ImageDType):
lengths = [4]
@@ -158,7 +164,7 @@ def split_load_store(ctx:Renderer|None, ls:UOp, idx:UOp):
for fold_length in lengths:
if global_offset+fold_length > sz: continue
oidx = idx.src[1] + global_offset
if oidx.simplify().divides(fold_length) is None: continue
if must_divide and oidx.simplify().divides(fold_length) is None: continue
lidx = buf.index(oidx, idx.src[2] if len(idx.src) > 2 else None)
if fold_length > 1: lidx = lidx.cast(ptrdtype.base.vec(fold_length).ptr(size=ptrdtype.size, local=ptrdtype.local))
if ls.op is Ops.STORE: ret.append(ls.replace(src=(lidx,ls.src[1].gep(tuple(range(global_offset, global_offset+fold_length))))+ls.src[2:]))
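Roughly what the fold-length loop does on DSP (a simplified sketch of the selection, not the real UOp code): take the largest length that still fits at the current offset and advance; with must_divide=False the index divisibility check no longer gates the pick:

lengths, sz = [128, 64, 32, 16, 8, 4], 200  # assumed element count
off, picked = 0, []
while off < sz:
    fl = next((l for l in lengths if off + l <= sz), 1)
    picked.append(fl); off += fl
print(picked)  # [128, 64, 8]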

View File

@@ -102,9 +102,6 @@ class Kernel:
@property
def membufs(self) -> list[UOp]: return dedup([x.src[0] for x in self.bufs if x.op in {Ops.LOAD, Ops.STORE}])
# TODO: these need more tests or they might silently be no-ops
def float4_axis(self, i:int): return [x-self.first_upcast for x in self.sts[i].unit_stride_axes() if x >= self.first_upcast and self.sts[i].shape[x]%4 == 0] # noqa: E501
def upcasted_axis(self, i:int) -> list[tuple[int, Optional[sint], bool]]:
upcasted_shape, upcasted_stride = self.sts[i].shape[self.first_upcast:], self.sts[i].real_strides()[self.first_upcast:]
assert all_int(upcasted_shape), f"cannot upcast a symbolic amount {upcasted_shape=}"
@@ -461,7 +458,8 @@ class Kernel:
if self.opts.has_local and self.opts.has_shared and all_int(self.sts[0].shape[:self.first_reduce]):
# are we grouping? (requires local shape support)
if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048: # noqa: E501
if not [x for x in self.sts[0].unit_stride_axes() if x >= self.first_upcast and self.sts[0].shape[x]%4 == 0] and \
self.first_reduce <= 2 and self.first_reduce < self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
# TODO: use 1024 if it's allowed in a smarter way
for sz in ([256, 16] if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
if all(st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts):
@@ -503,10 +501,12 @@ class Kernel:
for axis in to_upcast[::-1]: self.apply_opt(Opt(OptOps.UPCAST, axis, 0))
# potentially do more upcasts of non reduce axes based on a heuristic
is_dsp = self.opts is not None and self.opts.device == "DSP"
upcasted_axis: set[int] = set()
while resolve(prod(self.sts[0].shape[:self.first_reduce]) >= 1024):
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# consider all the non reduce axes, and a 3 or 4 reduce. (128 on the DSP)
for axis, upcast_amount in itertools.product(range(self.first_reduce), ([128] if not len(upcasted_axis) else []) if is_dsp else [3,4]):
# if we haven't upcasted it, it's not symbolic, it mods, and buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and isinstance(self.full_shape[axis], int) and self.full_shape[axis]%upcast_amount == 0 and any(st.views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index, st in enumerate(self.sts)): # noqa: E501
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount)) # noqa: E501

View File

@@ -2,11 +2,12 @@
import functools, itertools, operator, math
from dataclasses import dataclass
from typing import cast
from tinygrad.dtype import dtypes, PtrDType
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop
from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop, GroupOp
from tinygrad.renderer import Renderer
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap, QUANTIZE
from tinygrad.codegen.expander import expand_rewrite
from tinygrad.codegen.symbolic import symbolic
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
@@ -156,9 +157,65 @@ pm_lowerer = PatternMatcher([
# rewrite LOAD/STORE VIEW to LOAD/STORE with indexed
(UPat((Ops.LOAD, Ops.STORE), src=(UPat(), UPat(Ops.VIEW)), allow_any_len=True, name="x"), lower_load_store),
(UPat(Ops.INDEX, src=(UPat.var("b"), UPat.var("idx"), UPat.const(dtypes.bool, True))), lambda b, idx: b.index(idx)),
(UPat(Ops.IGNORE, name="x"), lambda x: x.src[0]),
])
# **** this is the "quantization preprocessor": it makes ONNX quantized models (and probably others) actually use ints ****
def view_to_mask(x:UOp):
from tinygrad.shape.shapetracker import ShapeTracker, View
st = cast(ShapeTracker, x.st)
if len(st.views) > 1: return None
if st.views[-1].mask is None: return None
return ShapeTracker((View(st.shape, (0,)*len(st.shape), 0, st.views[-1].mask, False),))
FP = (1 << 16)
pm_quant = symbolic+PatternMatcher([
# cast after add/mul
(UPat.var("x").cast(dtypes.float32) + UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))+y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
(UPat.var("x").cast(dtypes.float32) * UPat.var("y").cast(dtypes.float32),
lambda x,y: (x.cast(least_upper_dtype(x.dtype, y.dtype))*y.cast(least_upper_dtype(x.dtype, y.dtype))).cast(dtypes.float32)),
# MUL after reduce
(UPat(Ops.REDUCE_AXIS, src=(UPat.var("x") * UPat.cvar("c"),), name="r"), lambda x,c,r: r.replace(src=(x,))*c),
# CAST after reduce (doesn't work if it's a size change)
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.CAST, src=(UPat.var("x"),)),), name="r"),
lambda x,r: r.replace(dtype=x.dtype, src=(x,)).cast(r.dtype) if dtypes.is_float(r.dtype) else None),
# x*c1 + y*c2 -> (x+y)*c1 (if c1 and c2 are close floats)
(UPat.var("x")*UPat.cvar("c1", dtype=dtypes.floats) + UPat.var("y")*UPat.cvar("c2", dtype=dtypes.floats),
lambda x,y,c1,c2: (x+y)*c1 if abs(c1.arg-c2.arg) < 1e-9 else None),
# mul 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int).cast(dtypes.float).named("ld"), lambda ld,v,c1: ld*c1),
# mul (with plus) 0 * c1 is 0
(UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar("c1"), UPat(Ops.CONST, arg=0)) *
(UPat(Ops.LOAD, src=(UPat(), UPat(Ops.VIEW, name="v"))).cast(dtypes.int) + \
UPat(Ops.VALID, src=(UPat(Ops.VIEW, name="v"),)).where(UPat.cvar(), UPat(Ops.CONST, arg=0))).cast(dtypes.float).named("ld"),
lambda ld,v,c1: ld*c1),
# fixed point mult, replace (x.float()*c1+c2).int() with an int expression
((UPat.var("x").cast(dtypes.float)*UPat.cvar("c1")+UPat.cvar("c2")).cast(dtypes.int),
lambda x,c1,c2: (x * (c1 * FP).cast(dtypes.int) + (c2 * FP).cast(dtypes.int)) // FP),
# where move
(UPat.var("valid").where(UPat.var("yes"), UPat(Ops.CONST, arg=0))*UPat.var("mul"), lambda valid, yes, mul:
(yes*mul*valid.where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))) if yes.op is not Ops.CONST or yes.arg != 1 else None),
((UPat.var("x")*UPat.cvar("c"))*(UPat.var().where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)).named("v")), lambda x,c,v: (x*v)*c),
(UPat.var("x").cast().named('c') * UPat.var('valid').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)), lambda x,c,valid:
(x*valid.where(UOp.const(x.dtype, 1), UOp.const(x.dtype, 0))).cast(c.dtype)),
((UPat.var('x') * UPat.var('v1').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0)) *
UPat.var('v2').where(UPat(Ops.CONST, arg=1), UPat(Ops.CONST, arg=0))).named("mul"), lambda x, mul, v1, v2:
x * (v1&v2).where(UOp.const(mul.dtype, 1), UOp.const(mul.dtype, 0))),
# don't care
(UPat(Ops.STORE, name="x"), lambda x:
x.replace(src=(x.src[0], UOp(Ops.IGNORE, src=(x.src[1],), arg=mm), UOp(Ops.IGNORE, x.src[2].dtype, src=(x.src[2],), arg=mm),)) \
if x.src[1].op is not Ops.IGNORE and (mm:=view_to_mask(x.src[1])) is not None else None),
(UPat(Ops.IGNORE, src=(UPat((*GroupOp.ALU, Ops.CAST), name="alu"),), name="ig"),
lambda ig,alu: alu.replace(src=tuple(UOp(Ops.IGNORE, x.dtype, (x,), ig.arg) for x in alu.src))),
(UPat(Ops.IGNORE, src=(UPat.cvar("c"),), name="ig"), lambda ig, c: c),
(UPat(Ops.IGNORE, src=(UPat(Ops.VALID, name="v"),), name="ig"), lambda ig, v: UOp.const(dtypes.bool, True) if v.src[0].arg == ig.arg else None),
])
def rewrite_shapetracker_with_index(ast:UOp, opts:Renderer) -> UOp:
if QUANTIZE and opts.device in {"CPU", "DSP"}: ast = graph_rewrite(ast, pm_quant, name="quantize")
sink = graph_rewrite(ast, pm_lowerer, ctx=get_index(ast, opts))
# expand_rewrite turns this into a vectorized program
return expand_rewrite(sink)
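The fixed-point rule in pm_quant, worked with concrete numbers (FP = 1 << 16; c1, c2 are illustrative constants, not from the source):

FP = 1 << 16
x, c1, c2 = 37, 0.02, 0.5
print(int(x * c1 + c2))                         # float path: int(1.24) = 1
print((x * int(c1 * FP) + int(c2 * FP)) // FP)  # int path: (37*1310 + 32768) // 65536 = 1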

View File

@@ -113,6 +113,8 @@ sym = symbolic_simple+PatternMatcher([
# **** UOp realization
DONT_PUSH_VIEWS = {Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR, Ops.DEVICE, Ops.ASSIGN, Ops.SINK, Ops.CONTIGUOUS, Ops.COPY}
@dataclass(frozen=True)
class GrouperContext:
assigns: dict[UOp, UOp] # maps realized buffers to assigns
@@ -133,11 +135,11 @@ def realize_before_view(ctx:GrouperContext, view:UOp, src:UOp) -> None:
do_realize = PatternMatcher([
# always realize SINK parents
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x.base, None) for x in s.src if x.base.op not in {Ops.CONST, Ops.BIND, Ops.BUFFER})),
(UPat(Ops.SINK, name="s"), lambda ctx,s: ctx.realizes.update((x, None) for x in s.src if x.op not in DONT_PUSH_VIEWS)),
# always realize ASSIGN/CONTIGUOUS/COPY/BUFFER_VIEW
(UPat({Ops.ASSIGN, Ops.CONTIGUOUS, Ops.COPY, Ops.BUFFER_VIEW}, name="tr"), realize),
# realize before expand or unsafe pad ops
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-{Ops.BUFFER, Ops.CONST, Ops.DEVICE}, name="src"),)), realize_before_view),
(UPat(Ops.VIEW, name="view", src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="src"),)), realize_before_view),
# realize before COPY
(UPat(Ops.COPY, src=(UPat(), UPat(GroupOp.All-{Ops.BUFFER, Ops.VIEW}, name="tr"))), realize),
])
@@ -221,7 +223,7 @@ def group_realizes(sink:UOp) -> dict[UOp, None]:
if len(ctx.children[top_reduce]) == 1: del ctx.realizes[top_reduce]
return ctx.realizes
# break the SINK into kernels
# **** create kernels
@dataclass(frozen=True)
class Kernel:
@@ -241,6 +243,7 @@ def create_kernel(ctx:KernelContext, x:UOp, b:UOp):
return UOp(Ops.ASSIGN, x.dtype, (buffer, kernel)).reshape(x.shape)
DONT_PLACE_IN_KERNEL = {Ops.KERNEL, Ops.ASSIGN, Ops.BUFFER}
def append_to_kernel(ctx:KernelContext, x:UOp):
new_srcs: list[UOp] = []
metadata = dict.fromkeys(x.arg.metadata)
@@ -266,32 +269,7 @@ create_kernels = merge_views+PatternMatcher([
(UPat(Ops.SINK, name="x"), lambda x:x.replace(src=tuple(s.base for s in x.src)) if any(s.op is Ops.VIEW for s in x.src) else None),
])
DONT_PUSH_VIEWS = {Ops.BUFFER, *GroupOp.Buffer, Ops.ASSIGN, Ops.SINK}
# **** fix kernel AST
# ** create buffer ops + enumerate buffers
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x: UOp(Ops.LOAD, x.dtype, (UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx.index(x)), x.st.to_uop()))),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
# partial assign can store to a non-contiguous ShapeTracker
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# if the last child is a VIEW we merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
])
# ** push views to buffer ops
# **** swizzler
def apply_swizzle(u:UOp) -> UOp:
with Context(TRACK_MATCH_STATS=0): return graph_rewrite(u, view_left)
@@ -314,7 +292,7 @@ def reduceop_view_right(src:UOp, v:UOp, r:UOp):
assert unwrap(v.st).contiguous and v.size == src.size, f"can't compute new axis for {src.shape} -> {r.shape}"
return src.r(r.arg[0], tuple(i for i,(s,u) in enumerate(zip(src.shape, r.shape)) if s != u)).view(ShapeTracker.from_shape(r.shape))
def elementwise_view_right(root:UOp) -> UOp|None:
def elementwise_view_right(root:UOp):
if not (swizzles:=[x for x in root.src if x.op is Ops.VIEW and x.base.op not in DONT_PUSH_VIEWS]): return None
assert all_same([x.base.size for x in swizzles]), f"swizzle inputs must have the same size {swizzles}"
# place view after applying the elementwise op
@@ -323,7 +301,7 @@ def elementwise_view_right(root:UOp) -> UOp|None:
# reshape to match downstream shapes
return root.replace(src=tuple(new_src)).reshape(root.shape)
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
def merge_double_reduce(root:UOp, first_reduce:UOp):
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
@@ -340,21 +318,42 @@ view_right = merge_views+PatternMatcher([
(UPat(Ops.REDUCE_AXIS, src=(UPat(Ops.REDUCE_AXIS, name="first_reduce"),), name="root"), merge_double_reduce),
])
# ** unbind variables
# **** unbind variables
def unbind_shapetracker(ctx:dict[Variable, int], x:UOp) -> UOp|None:
def unbind_shapetracker(ctx:tuple[dict[Variable, int], tuple[UOp, ...]], x:UOp):
st = unwrap(x.st).simplify()
if any(x.op is Ops.BIND for x in st.vars()):
st, var_vals = st.unbind()
ctx.update(var_vals)
return st.to_uop() if st != x.st else None
ctx[0].update(var_vals)
return x.replace(arg=st) if st != x.st else None
def unbind_variable(ctx:dict[Variable, int], bind:UOp, var:UOp, val:UOp):
ctx[var.replace(src=())] = val.arg
return var
unbind_vars = PatternMatcher([(UPat(Ops.BIND, name="bind", src=(UPat.var("var"), UPat.cvar("val"))), unbind_variable),])
# ** fix_kernel_ops
# **** fix kernel AST
add_buffer_ops = PatternMatcher([
# LOAD
(UPat(Ops.BUFFER, name="x"), lambda ctx,x:UOp.load(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), ctx[1].index(x)), x.st.to_uop(), dtype=x.dtype)),
# STORE (except for COPY/BUFFER_VIEW)
(UPat(Ops.SINK, src=(UPat((Ops.COPY, Ops.BUFFER_VIEW), name="x"),)), lambda x:x),
# partial assign can store to a non-contiguous ShapeTracker
(UPat(Ops.SINK, src=(UPat(Ops.ASSIGN, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), x.src[0].st.to_uop(), x.src[1]).sink()),
# otherwise the store is contiguous
(UPat(Ops.SINK, src=(UPat(GroupOp.All-{Ops.STORE}, name="x"),)),
lambda x: UOp.store(UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(x.size), (), 0), ShapeTracker.from_shape(x.shape).to_uop(), x).sink()),
# VALID
(UPat(Ops.VIEW, src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),), name="view"), lambda x,view: x.valid(view.arg)),
# if the last child is a VIEW we merge the ShapeTrackers and store the base
(UPat(Ops.STORE, src=(UPat.var("b"), UPat.var("st"), UPat(Ops.VIEW, src=(UPat(GroupOp.All-DONT_PUSH_VIEWS, name="x"),)))),
lambda x,b,st: UOp.store(b, (st.arg+x.st).to_uop(), x)),
# remove CONTIGUOUS/DEVICE from kernel AST
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),), name="view"), lambda view: view.replace(src=())),
])
def check_load_st(glbl:UOp, view:UOp):
if glbl.arg != 0 or (st:=unwrap(view.st)).contiguous: return
@@ -369,8 +368,6 @@ def check_load_st(glbl:UOp, view:UOp):
fix_kernel_ops = PatternMatcher([
# BIND in shapetracker becomes DEFINE_VAR
(UPat(Ops.VIEW, name="x"), unbind_shapetracker),
# remove unmasked valid
(UPat.where(UPat(Ops.VALID, name="valid"), UPat.cvar("x"), UPat()), lambda valid,x: x if all(v.mask is None for v in valid.st.views) else None),
# no ImageDType after load
(UPat(GroupOp.All-{Ops.DEFINE_GLOBAL}, name="x"), lambda x: x.replace(dtype=x.dtype.base) if isinstance(x.dtype, ImageDType) else None),
# if this kernel also assigns to the loaded buffer, ensure we can index it correctly
@@ -387,13 +384,11 @@ def fix_kernel_ast(k:UOp, var_vals:dict[Variable, int]) -> UOp:
ast = k.arg.ast.substitute(parents_rep)
# unbind_vars + push views to edges
ast = graph_rewrite(graph_rewrite(ast, unbind_vars+view_left, ctx=var_vals), view_right)
# add buffer ops
ast = graph_rewrite(ast, view_left+add_buffer_ops, bufs:=tuple(s.buf_uop for s in k.src), bottom_up=True)
# add buffer ops + fix_kernel_ops
ast = graph_rewrite(ast, merge_views+add_buffer_ops+fix_kernel_ops, ctx=(var_vals, bufs:=tuple(s.buf_uop for s in k.src)), bottom_up=True)
if ast.op is Ops.SINK and not all_same(dev:=[x.device for x in bufs]): raise RuntimeError(f"all buffers must be on the same device: {dev}")
# create subbuffer (TODO: this does not belong here)
if ast.op is Ops.BUFFER_VIEW: buffers[bufs[0]] = (base:=bufs[1].buffer).view(ast.size, ast.dtype, ast.arg[1]*base.dtype.itemsize)
# fix_kernel_ops
ast = graph_rewrite(ast, fix_kernel_ops, var_vals)
return k.replace(arg=Kernel(ast, k.arg.metadata))
PROCESS_REPLAY_CAPTURE:dict[str, bytes] = {}

View File

@@ -113,6 +113,7 @@ SPLIT_REDUCEOP, NO_MEMORY_PLANNER, RING = ContextVar("SPLIT_REDUCEOP", 1), Conte
PICKLE_BUFFERS, PROFILE, LRU = ContextVar("PICKLE_BUFFERS", 1), ContextVar("PROFILE", getenv("VIZ")), ContextVar("LRU", 1)
CACHELEVEL, IGNORE_BEAM_CACHE, DEVECTORIZE = ContextVar("CACHELEVEL", 2), ContextVar("IGNORE_BEAM_CACHE", 0), ContextVar("DEVECTORIZE", 1)
DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES = ContextVar("DONT_REALIZE_EXPAND", 0), ContextVar("DONT_GROUP_REDUCES", 0)
QUANTIZE = ContextVar("QUANTIZE", 0)
@dataclass(frozen=True)
class Metadata:

View File

@@ -117,7 +117,7 @@ class Ops(FastEnum):
REDUCE_AXIS = auto()
# helper ops
GEP = auto(); VECTORIZE = auto(); CAT = auto() # noqa: E702
GEP = auto(); VECTORIZE = auto(); CAT = auto(); PTRCAT = auto() # noqa: E702
# UnaryOps
CAST = auto(); BITCAST = auto(); EXP2 = auto(); LOG2 = auto(); SIN = auto(); SQRT = auto(); RECIP = auto(); NEG = auto() # noqa: E702
@@ -154,6 +154,7 @@ class Ops(FastEnum):
# CUSTOMI is inline
CUSTOM = auto(); CUSTOMI = auto() # noqa: E702
IGNORE = auto()
class GroupOp:
Unary = {Ops.EXP2, Ops.LOG2, Ops.SIN, Ops.SQRT, Ops.RECIP, Ops.NEG}
@@ -280,6 +281,13 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
return nodes
return _toposort(self, cache=set())
# returns map of UOps to their children in the graph rooted by self
def get_children_map(self) -> dict[UOp, dict[UOp, None]]:
ret: dict[UOp, dict[UOp, None]] = {}
for u in self.toposort:
for s in u.src: ret.setdefault(s, {})[u] = None
return ret
@functools.cached_property
def tuplize(self:UOp) -> tuple[int, Any, Optional[DType], tuple]: return (self.op.value, self.arg, self.dtype, tuple(x.tuplize for x in self.src))
@@ -895,10 +903,23 @@ def launch_viz(env_str:str, data:str):
# *** simple graph rewrite engine ***
class RewriteContext:
def __init__(self, pm, ctx=None):
def __init__(self, pm, ctx=None, children=None):
self.pm: PatternMatcher = pm
self.ctx = ctx
self.ctx = self if children is not None else ctx
self.replace: dict[UOp, UOp] = {}
self.children = children
# TODO: is this function always right?
def update_children(self):
# add any new children from UOps that were replaced
for u in self.replace.values():
for s in u.src: self.children.setdefault(s, {})[u] = None
# find any children that were replaced and replace them
for k,v in self.children.items():
new_child: dict[UOp, None] = {}
for tv in v:
while (nv:=self.replace.get(tv, None)) is not None and nv is not tv: tv = nv
new_child[tv] = None
self.children[k] = new_child
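The inner while chases replacement chains: a recorded child may itself have been rewritten more than once, so each entry follows self.replace until it lands on a live node.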
def top_down_rewrite(self, n:UOp) -> UOp:
if (rn := self.replace.get(n)) is not None: return rn
new_src = tuple([self.top_down_rewrite(x) for x in n.src])
@@ -913,15 +934,16 @@ class RewriteContext:
self.replace[n] = ret = last_n if new_src == last_n.src else self.bottom_up_rewrite(UOp(last_n.op, last_n.dtype, new_src, last_n.arg))
return ret
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> UOp:
def graph_rewrite(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> UOp:
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
return RewriteContext(pm, ctx).bottom_up_rewrite(sink) if bottom_up else RewriteContext(pm, ctx).top_down_rewrite(sink)
rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
return rewrite_ctx.bottom_up_rewrite(sink) if bottom_up else rewrite_ctx.top_down_rewrite(sink)
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None) -> dict[UOp, UOp]:
def graph_rewrite_map(sink:UOp, pm:PatternMatcher, ctx=None, bottom_up=False, name=None, track_children=False) -> dict[UOp, UOp]:
if TRACK_MATCH_STATS >= 2 and len(tracked_ctxs) != 0:
tracked_ctxs[-1].append(TrackedGraphRewrite(((frm:=sys._getframe(1)).f_code.co_filename, frm.f_lineno), sink, bottom_up, name=name))
rewrite_ctx = RewriteContext(pm, ctx)
rewrite_ctx = RewriteContext(pm, ctx, children=sink.get_children_map() if track_children else None)
return {k:(rewrite_ctx.bottom_up_rewrite(k) if bottom_up else rewrite_ctx.top_down_rewrite(k)) for k in list(sink.toposort)[::-1]}
def sint_to_uop(x:sint, dtype:DType=dtypes.int) -> UOp: return UOp.const(dtype, x) if isinstance(x, int) else x
@@ -958,6 +980,9 @@ merge_views = PatternMatcher([
# merge unmasked const views
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="const"),)),
lambda v,const: const.replace(src=(const.src[0].replace(arg=const.st+v.st),)) if all(x.mask is None for x in (const.st+v.st).views) else None),
# merge view on load/store/valid
(UPat(Ops.VIEW, name="v", src=(UPat((Ops.LOAD, Ops.STORE, Ops.VALID), name="b"),)),
lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
# remove view if it's a contiguous and the shapes match
(UPat(Ops.VIEW, name="v", src=(UPat(GroupOp.All-{Ops.DEVICE}, name="x"),)), lambda v,x: x if v.arg.contiguous and x.shape == v.shape else None),
# remove mask if there's a zero in the masked dim
@@ -967,13 +992,11 @@ merge_views = PatternMatcher([
(UPat(GroupOp.Movement, src=(UPat.var("x"),), name="mop"), lambda mop,x: x.view(mop.st)),
])
# push VIEW to parents
view_left = merge_views+PatternMatcher([
# VIEW(CONST) becomes VALID
(UPat(Ops.VIEW, name="vm", src=(UPat((Ops.CONST, Ops.DEFINE_VAR), name="x"),)), lambda vm,x: x.valid(vm.st)),
# VIEW before elementwise/buffer ops
# do not push masked view before unsafe pad ops
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.UnsafePad, name="e"),)),
lambda e,vm: e.contiguous().view(vm.st) if any(v.mask is not None for v in vm.st.views) else None),
# view before elementwise ops
(UPat(Ops.VIEW, name="vm", src=(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e"),)),
lambda e,vm: e.replace(src=tuple(s if s.st is None else s.view(vm.st) if s is s.base else s.base.view(s.st+vm.st) for s in e.src))),
(UPat(Ops.VIEW, name="vm", src=(UPat(GroupOp.Buffer, name="b"),)),
lambda b,vm: b.replace(src=tuple((s.st+vm.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
])

View File

@@ -197,8 +197,8 @@ class ClangRenderer(CStyleLanguage):
if sys.platform == 'win32':
kernel_prefix = "__attribute__((ms_abi)) "
def render_vector_prefix(self, dt:DType) -> str:
# round (down) to power of two
alignment = 2**int(math.log2(dt.itemsize))
# round (down) to power of two (this is actually the default clang behavior)
alignment = 2**int(math.log2(dt.itemsize)) if getenv("ALIGNED", 1) else 1
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({alignment}),vector_size({dt.itemsize})));"
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
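For example, a 4-wide float (itemsize 16) should render as something like typedef float float4 __attribute__((aligned(16),vector_size(16))); (the exact vector type name comes from render_dtype), while ALIGNED=0 forces aligned(1) to permit unaligned loads/stores.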
@@ -397,11 +397,18 @@ def cast_float_to_bf16(x: UOp) -> UOp:
class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536
# NOTE: this is only really needed on gfx12, even though gfx11 reports the same limitation
global_max = (2147483647, 65535, 65535)
# https://gpuopen.com/learn/wmma_on_rdna3/
tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
def __init__(self, arch:str): # gfx942 => MI300, gfx1100 => RX 7900
# TODO: fix tensor cores for gfx1201
self.tensor_cores, self.arch = AMDRenderer.tensor_cores if arch != "gfx1201" else [], arch
def __reduce__(self): return self.__class__, (self.arch,)
# language options
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
ocml = [(f"__ocml_{name}_f{n}", f"{dt}, {dt}" if "fmax" == name else dt, dt, atr)

View File

@@ -685,7 +685,8 @@ class AMDDevice(HCQCompiled):
self.dev_iface = PCIIface(self, self.device_id) if AMDDevice.driverless else KFDIface(self, self.device_id)
self.target = int(self.dev_iface.props['gfx_target_version'])
self.arch = "gfx%d%x%x" % (self.target // 10000, (self.target // 100) % 100, self.target % 100)
if self.target < 100300 or self.target >= 120000: raise RuntimeError(f"Unsupported arch: {self.arch}")
if self.target < 100300 or self.target >= 130000: raise RuntimeError(f"Unsupported arch: {self.arch}")
if DEBUG >= 1: print(f"AMDDevice: opening {self.device_id} with target {self.target} arch {self.arch}")
self.max_cu_id = self.dev_iface.props['simd_count'] // self.dev_iface.props['simd_per_cu'] - 1
self.max_wave_id = self.dev_iface.props['max_waves_per_simd'] * self.dev_iface.props['simd_per_cu'] - 1
@@ -704,7 +705,7 @@ class AMDDevice(HCQCompiled):
self.sdma_queue = self.create_queue(kfd.KFD_IOC_QUEUE_TYPE_SDMA, 0x800000)
super().__init__(device, AMDAllocator(self), AMDRenderer(), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
super().__init__(device, AMDAllocator(self), AMDRenderer(self.arch), AMDCompiler(self.arch), functools.partial(AMDProgram, self),
AMDSignal, AMDComputeQueue, AMDCopyQueue)
# Scratch setup

View File

@@ -16,17 +16,22 @@ dsp_pm = PatternMatcher([
lambda x: UOp(Ops.CUSTOM, dtypes.uchar.vec(128), src=tuple(x.gep(tuple(range(i, i+32))) for i in range(0, 128, 32)),
arg="__builtin_HEXAGON_V6_vpackhub_sat_128B(__builtin_HEXAGON_V6_vpackwh_sat_128B({3}, {2}), __builtin_HEXAGON_V6_vpackwh_sat_128B({1}, {0}))")),
(UPat(Ops.GEP, name="x"), lambda x: UOp(Ops.CUSTOM, x.dtype, x.src+x.src,
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 else None),
"__builtin_shufflevector({0}, {1}, "+','.join([str(y) for y in x.arg])+")") if len(x.arg) > 1 and x.src[0].dtype.count > 1 else None),
])
dsp_pm_late = PatternMatcher([
(UPat.var("x")+UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")*UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")//UPat(Ops.VECTORIZE, src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI, x.dtype, (y,), arg="{0}")),
(UPat.var("x")+UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x+UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")*UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x*UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat.var("x")//UPat(Ops.VECTORIZE,src=UPat.var("y")), lambda x,y: x//UOp(Ops.CUSTOMI,x.dtype,(y,),arg="{0}") if x.op is not Ops.CUSTOMI else None),
(UPat(Ops.DEFINE_ACC, src=(UPat(Ops.VECTORIZE, src=UPat(Ops.CONST, arg=0)),), dtype=dtypes.uchar.vec(128), name="d", allow_any_len=True),
lambda d: d.replace(src=(UOp(Ops.CUSTOMI, d.dtype, arg="__builtin_HEXAGON_V6_vd0_128B()"),)+d.src[1:])),
])
# NOTE: this just increases readability of the generated code
dsp_string = PatternMatcher([
(UPat(Ops.CONST, (dtypes.int8, dtypes.uint8), name="x"), lambda ctx,x: str(x.arg)),
])
class DSPRenderer(ClangRenderer):
device = "DSP"
supports_float4 = True
@@ -34,6 +39,7 @@ class DSPRenderer(ClangRenderer):
kernel_prefix = "__attribute__((noinline)) "
pre_matcher = dsp_pm
extra_matcher = dsp_pm_late+ClangRenderer.extra_matcher
string_rewrite = dsp_string+ClangRenderer.string_rewrite
type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" }
code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})",
Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})",

View File

@@ -14,7 +14,7 @@ class HIPDevice(Compiled):
self.device_id = int(device.split(":")[1]) if ":" in device else 0
self.arch = init_c_var(hip.hipDeviceProp_t(), lambda x: check(hip.hipGetDeviceProperties(x, self.device_id))).gcnArchName.decode()
self.time_event_st, self.time_event_en = [init_c_var(hip.hipEvent_t(), lambda x: hip.hipEventCreate(ctypes.byref(x), 0)) for _ in range(2)]
super().__init__(device, HIPAllocator(self), HIPRenderer(), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
super().__init__(device, HIPAllocator(self), HIPRenderer(self.arch), AMDCompiler(self.arch), functools.partial(HIPProgram, self))
def synchronize(self):
check(hip.hipSetDevice(self.device_id))
check(hip.hipDeviceSynchronize())

View File

@@ -1,5 +1,5 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib, fcntl, os
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, importlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0
from tinygrad.runtime.support.allocator import TLSFAllocator
@@ -391,10 +391,10 @@ class AMDev:
gc_info = am.struct_gc_info_v1_0.from_address(gc_addr:=ctypes.addressof(bhdr) + bhdr.table_list[am.GC].offset)
self.gc_info = getattr(am, f"struct_gc_info_v{gc_info.header.version_major}_{gc_info.header.version_minor}").from_address(gc_addr)
def _ip_module(self, prefix:str, hwip):
def _ip_module(self, prefix:str, hwip, prever_prefix:str=""):
version = [self.ip_versions[hwip]//10000, (self.ip_versions[hwip]//100)%100, self.ip_versions[hwip]%100]
for ver in [version, version[:2]+[0], version[:1]+[0, 0]]:
try: return __import__(f"tinygrad.runtime.autogen.am.{prefix}_{ver[0]}_{ver[1]}_{ver[2]}", fromlist=[f"{prefix}_{ver[0]}_{ver[1]}_{ver[2]}"])
try: return importlib.import_module(f"tinygrad.runtime.autogen.am.{prefix}_{prever_prefix}{ver[0]}_{ver[1]}_{ver[2]}")
except ImportError: pass
raise ImportError(f"am {self.devfmt}: failed to load {prefix} module with version {version}")

View File

@@ -1,6 +1,6 @@
import ctypes, time, contextlib
from typing import Literal
from tinygrad.runtime.autogen.am import am, smu_v13_0_0
from tinygrad.runtime.autogen.am import am
from tinygrad.helpers import to_mv, data64, lo32, hi32, DEBUG
class AM_IP:
@@ -61,7 +61,14 @@ class AM_GMC(AM_IP):
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_START_ADDR", "_LO32", "_HI32", self.vm_base >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_END_ADDR", "_LO32", "_HI32", self.vm_end >> 12)
self.adev.wreg_pair(f"reg{ip}VM_CONTEXT{vmid}_PAGE_TABLE_BASE_ADDR", "_LO32", "_HI32", page_table.paddr | 1)
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1fffe00, enable_context=1, page_table_depth=(3 - page_table.lv))
self.adev.reg(f"reg{ip}VM_CONTEXT{vmid}_CNTL").write(0x1800000, pde0_protection_fault_enable_interrupt=1, pde0_protection_fault_enable_default=1,
dummy_page_protection_fault_enable_interrupt=1, dummy_page_protection_fault_enable_default=1,
range_protection_fault_enable_interrupt=1, range_protection_fault_enable_default=1,
valid_protection_fault_enable_interrupt=1, valid_protection_fault_enable_default=1,
read_protection_fault_enable_interrupt=1, read_protection_fault_enable_default=1,
write_protection_fault_enable_interrupt=1, write_protection_fault_enable_default=1,
execute_protection_fault_enable_interrupt=1, execute_protection_fault_enable_default=1,
enable_context=1, page_table_depth=(3 - page_table.lv))
def init_hub(self, ip:Literal["MM", "GC"]):
# Init system apertures
@@ -106,37 +113,38 @@ class AM_GMC(AM_IP):
class AM_SMU(AM_IP):
def __init__(self, adev):
super().__init__(adev)
self.smu_mod = self.adev._ip_module("smu", am.MP1_HWIP, prever_prefix='v')
self.driver_table_paddr = self.adev.mm.palloc(0x4000, zero=not self.adev.partial_boot, boot=True)
def init(self):
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrHigh, hi32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetDriverDramAddrLow, lo32(self.adev.paddr2mc(self.driver_table_paddr)), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_EnableAllSmuFeatures, 0, poll=True)
def is_smu_alive(self):
with contextlib.suppress(RuntimeError): self._send_msg(smu_v13_0_0.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
with contextlib.suppress(RuntimeError): self._send_msg(self.smu_mod.PPSMC_MSG_GetSmuVersion, 0, timeout=100)
return self.adev.mmMP1_SMN_C2PMSG_90.read() != 0
def mode1_reset(self):
if DEBUG >= 2: print(f"am {self.adev.devfmt}: mode1 reset")
self._send_msg(smu_v13_0_0.PPSMC_MSG_Mode1Reset, 0, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_Mode1Reset, 0, poll=True)
time.sleep(0.5) # 500ms
def read_table(self, table_t, cmd):
self._send_msg(smu_v13_0_0.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_TransferTableSmu2Dram, cmd, poll=True)
return table_t.from_buffer(to_mv(self.adev.paddr2cpu(self.driver_table_paddr), ctypes.sizeof(table_t)))
def read_metrics(self): return self.read_table(smu_v13_0_0.SmuMetricsExternal_t, smu_v13_0_0.TABLE_SMU_METRICS)
def read_metrics(self): return self.read_table(self.smu_mod.SmuMetricsExternal_t, self.smu_mod.TABLE_SMU_METRICS)
def set_clocks(self, level):
if not hasattr(self, 'clcks'):
self.clcks = {}
for clck in [smu_v13_0_0.PPCLK_GFXCLK, smu_v13_0_0.PPCLK_UCLK, smu_v13_0_0.PPCLK_FCLK, smu_v13_0_0.PPCLK_SOCCLK]:
cnt = self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(smu_v13_0_0.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
for clck in [self.smu_mod.PPCLK_GFXCLK, self.smu_mod.PPCLK_UCLK, self.smu_mod.PPCLK_FCLK, self.smu_mod.PPCLK_SOCCLK]:
cnt = self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|0xff, read_back_arg=True)&0x7fffffff
self.clcks[clck] = [self._send_msg(self.smu_mod.PPSMC_MSG_GetDpmFreqByIndex, (clck<<16)|i, read_back_arg=True)&0x7fffffff for i in range(cnt)]
for clck, vals in self.clcks.items():
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(smu_v13_0_0.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMinByFreq, clck << 16 | (vals[level]), poll=True)
self._send_msg(self.smu_mod.PPSMC_MSG_SetSoftMaxByFreq, clck << 16 | (vals[level]), poll=True)
def _smu_cmn_poll_stat(self, timeout=10000): self.adev.wait_reg(self.adev.mmMP1_SMN_C2PMSG_90, mask=0xFFFFFFFF, value=1, timeout=timeout)
def _smu_cmn_send_msg(self, msg, param=0):
@@ -206,7 +214,7 @@ class AM_GFX(AM_IP):
cp_hqd_pq_rptr_report_addr_lo=lo32(rptr_addr), cp_hqd_pq_rptr_report_addr_hi=hi32(rptr_addr),
cp_hqd_pq_wptr_poll_addr_lo=lo32(wptr_addr), cp_hqd_pq_wptr_poll_addr_hi=hi32(wptr_addr),
cp_hqd_pq_doorbell_control=self.adev.regCP_HQD_PQ_DOORBELL_CONTROL.build(doorbell_offset=doorbell*2, doorbell_en=1),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=1, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_pq_control=self.adev.regCP_HQD_PQ_CONTROL.build(rptr_block_size=5, unord_dispatch=0, queue_size=(ring_size//4).bit_length()-2),
cp_hqd_ib_control=self.adev.regCP_HQD_IB_CONTROL.build(min_ib_avail_size=0x3), cp_hqd_hq_status0=0x20004000,
cp_mqd_control=self.adev.regCP_MQD_CONTROL.build(priv_state=1), cp_hqd_vmid=0,
cp_hqd_eop_base_addr_lo=lo32(eop_addr>>8), cp_hqd_eop_base_addr_hi=hi32(eop_addr>>8),
@@ -290,7 +298,7 @@ class AM_IH(AM_IP):
self.adev.reg(f"regIH_RB_WPTR{suf}").write(0)
self.adev.reg(f"regIH_RB_RPTR{suf}").write(0)
self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(((am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2), enable=1)
self.adev.reg(f"regIH_DOORBELL_RPTR{suf}").write(offset=(am.AMDGPU_NAVI10_DOORBELL_IH + ring_id) * 2, enable=1)
self.adev.regIH_STORM_CLIENT_LIST_CNTL.update(client18_is_storm_client=1)
self.adev.regIH_INT_FLOOD_CNTL.update(flood_cntl_enable=1)