mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
only the ops changes
This commit is contained in:
4
.github/workflows/test.yml
vendored
4
.github/workflows/test.yml
vendored
@@ -240,8 +240,8 @@ jobs:
|
||||
run: |
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker.py
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker_math.py
|
||||
- name: Repo line count <= 9999 lines
|
||||
run: MAX_LINE_COUNT=10136 python sz.py
|
||||
- name: Repo line count < 11000 lines
|
||||
run: MAX_LINE_COUNT=11000 python sz.py
|
||||
|
||||
testopencl:
|
||||
strategy:
|
||||
|
||||
@@ -77,7 +77,7 @@ assert out.as_buffer().cast('I')[0] == 5
|
||||
print("******** third, the LazyBuffer ***********")
|
||||
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.schedule import create_schedule_with_vars
|
||||
|
||||
# allocate some values + load in values
|
||||
a = UOp.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE)
|
||||
@@ -91,7 +91,7 @@ b = b.buf_uop_view()
|
||||
out = a.alu(Ops.ADD, b)
|
||||
|
||||
# schedule the computation as a list of kernels
|
||||
sched = create_schedule([out])
|
||||
sched, _ = create_schedule_with_vars([out])
|
||||
for si in sched: print(si.ast.op) # NOTE: the first two convert it to CLANG
|
||||
|
||||
# DEBUGGING: print the compute ast
|
||||
|
||||
@@ -6,7 +6,6 @@ from tinygrad import Tensor, Device, dtypes, nn
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.ops import Ops, sym_infer
|
||||
from tinygrad.device import Compiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
|
||||
from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
|
||||
|
||||
@@ -18,12 +17,12 @@ def get_sched_resnet():
|
||||
# run model twice to get only what changes, these are the kernels of the model
|
||||
for _ in range(2):
|
||||
out = mdl(Tensor.empty(BS, 3, 224, 224))
|
||||
targets = [out.lazydata]
|
||||
targets = [out]
|
||||
if getenv("BACKWARD"):
|
||||
optim.zero_grad()
|
||||
out.sparse_categorical_crossentropy(Tensor.empty(BS, dtype=dtypes.int)).backward()
|
||||
targets += [x.lazydata for x in optim.schedule_step()]
|
||||
sched = create_schedule(targets)
|
||||
targets += [x for x in optim.schedule_step()]
|
||||
sched = Tensor.schedule(*targets)
|
||||
print(f"schedule length {len(sched)}")
|
||||
return sched
|
||||
|
||||
@@ -42,17 +41,16 @@ def get_sched_bert():
|
||||
next_sentence_labels = Tensor.empty((BS, 1), dtype=dtypes.float32)
|
||||
|
||||
# run model twice to get only what changes, these are the kernels of the model
|
||||
seen = set()
|
||||
for _ in range(2):
|
||||
lm_logits, seq_relationship_logits = mdl(input_ids, attention_mask, masked_positions, segment_ids)
|
||||
targets = [lm_logits.lazydata, seq_relationship_logits.lazydata]
|
||||
targets = [lm_logits, seq_relationship_logits]
|
||||
if getenv("BACKWARD"):
|
||||
optim.zero_grad()
|
||||
loss = mdl.loss(lm_logits, seq_relationship_logits, masked_lm_ids, masked_lm_weights, next_sentence_labels)
|
||||
# ignore grad norm and loss scaler for now
|
||||
loss.backward()
|
||||
targets += [x.lazydata for x in optim.schedule_step()]
|
||||
sched = create_schedule(targets)
|
||||
targets += [x for x in optim.schedule_step()]
|
||||
sched = Tensor.schedule(targets)
|
||||
print(f"schedule length {len(sched)}")
|
||||
return sched
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from tinygrad import Device, nn, Tensor, dtypes, Variable
|
||||
Device.DEFAULT = "CLANG"
|
||||
from train_gpt2 import GPT, GPTConfig
|
||||
from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCounters, ansilen, to_function_name
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_kernel, run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.ops import Ops
|
||||
@@ -37,7 +36,7 @@ if __name__ == "__main__":
|
||||
tensors = optimizer.schedule_step()
|
||||
else:
|
||||
tensors = []
|
||||
sched = create_schedule([loss.lazydata] + [x.lazydata for x in tensors])
|
||||
sched = loss.schedule(*tensors)
|
||||
print(f"calls {i}:", len(sched))
|
||||
#run_schedule(sched[:])
|
||||
sched = memory_planner(sched)
|
||||
|
||||
@@ -30,14 +30,13 @@ except ImportError:
|
||||
|
||||
import os
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
# define the compute
|
||||
A = Tensor.rand(M, K, device="clang")
|
||||
B = Tensor.rand(K, N, device="clang")
|
||||
C = (A.reshape(M, 1, K) * B.permute(1,0).reshape(1, N, K)).sum(axis=2)
|
||||
|
||||
sched = create_schedule([C.lazydata])
|
||||
sched = C.schedule()
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.device import CompilerOptions
|
||||
lin = Kernel(sched[-1].ast, CompilerOptions(has_local=False, supports_float4=False))
|
||||
|
||||
@@ -4,7 +4,6 @@ from tinygrad import Tensor, Device
|
||||
import tinygrad.runtime.autogen.amd_gpu as amd_gpu
|
||||
import tinygrad.runtime.autogen.kfd as kfd
|
||||
import tinygrad.runtime.autogen.hsa as hsa
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.runtime.ops_amd import kio, AMDProgram
|
||||
from tinygrad.helpers import to_mv
|
||||
|
||||
@@ -49,7 +48,7 @@ if __name__ == "__main__":
|
||||
a = Tensor([0.,1.,2.], device="KFD").realize()
|
||||
b = a + 7
|
||||
b.lazydata.buffer.allocate()
|
||||
si = create_schedule([b.lazydata])[-1]
|
||||
si = b.schedule()[-1]
|
||||
runner = dev.get_runner(*si.ast)
|
||||
prg: AMDProgram = runner.clprg
|
||||
print("device initted")
|
||||
|
||||
3
test/external/external_test_hcq.py
vendored
3
test/external/external_test_hcq.py
vendored
@@ -2,7 +2,6 @@ import unittest, ctypes, struct, time, array
|
||||
from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import to_mv, CI
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
def _time_queue(q, d):
|
||||
@@ -21,7 +20,7 @@ class TestHCQ(unittest.TestCase):
|
||||
#TestHCQ.d1: AMDDevice = Device["AMD:1"]
|
||||
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestHCQ.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
si = self.b.schedule()[-1]
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
|
||||
TestHCQ.b.lazydata.buffer.allocate()
|
||||
# wow that's a lot of abstraction layers
|
||||
|
||||
4
test/external/external_test_nv.py
vendored
4
test/external/external_test_nv.py
vendored
@@ -1,7 +1,6 @@
|
||||
import unittest, struct, array, ctypes
|
||||
from tinygrad import Device, dtypes, Tensor
|
||||
from tinygrad.helpers import to_mv
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.runtime.ops_nv import NVDevice, HWQueue
|
||||
from tinygrad.engine.search import Opt, OptOps
|
||||
from test.test_linearizer_failures import helper_test_lin
|
||||
@@ -20,7 +19,7 @@ class TestNV(unittest.TestCase):
|
||||
TestNV.d0: NVDevice = Device["NV"]
|
||||
TestNV.a = Tensor([0.,1.], device="NV").realize()
|
||||
TestNV.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
si = self.b.schedule()[-1]
|
||||
TestNV.d0_runner = get_runner(TestNV.d0.device, si.ast)
|
||||
TestNV.b.lazydata.buffer.allocate()
|
||||
TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr)
|
||||
@@ -65,4 +64,3 @@ class TestNV(unittest.TestCase):
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
|
||||
3
test/external/fuzz_graph.py
vendored
3
test/external/fuzz_graph.py
vendored
@@ -4,7 +4,6 @@ from tinygrad.device import Buffer, Device
|
||||
from tinygrad.helpers import Context, getenv, from_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner
|
||||
from tinygrad.engine.jit import apply_graph_to_jit
|
||||
|
||||
@@ -19,7 +18,7 @@ def gen_prg(device, inputs_cnt):
|
||||
s = fst[0]
|
||||
for i in range(1, inputs_cnt): s = s.xor(fst[i])
|
||||
|
||||
si = create_schedule([s.lazydata])[-1]
|
||||
si = s.schedule()[-1]
|
||||
prg = get_runner(device, si.ast)
|
||||
cached_prgs[(device, inputs_cnt)] = prg
|
||||
return prg
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import unittest, math
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import CI
|
||||
import numpy as np
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
def _check_ast_count(desired_count:int, t:Tensor):
|
||||
# NOTE: this has side effect because everything can be scheduled only once
|
||||
schedule = create_schedule(t.lazydata.lbs)
|
||||
schedule = t.schedule()
|
||||
asts = [s for s in schedule if s.ast.op is Ops.SINK]
|
||||
assert len(asts) == desired_count, f"{len(asts)} != {desired_count}"
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ import unittest
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.tensor import Tensor
|
||||
from tinygrad.nn import Conv2d
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, View
|
||||
from tinygrad.helpers import prod
|
||||
from test.unit.test_shapetracker import shapetracker_getitem
|
||||
@@ -11,11 +10,10 @@ from test.unit.test_shapetracker import shapetracker_getitem
|
||||
class TestConvShapetracker(unittest.TestCase):
|
||||
def test_conv_3x3_one_view(self):
|
||||
conv = Conv2d(16, 32, (3, 3))
|
||||
|
||||
# first run to init the weights, they are scheduled.
|
||||
create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata])
|
||||
conv(Tensor.empty(1, 16, 10, 10)).schedule()
|
||||
# run it again to get the kernels
|
||||
sched = [si for si in create_schedule([conv(Tensor.empty(1, 16, 10, 10)).lazydata]) if si.ast.op is Ops.SINK]
|
||||
sched = [si for si in conv(Tensor.empty(1, 16, 10, 10)).schedule() if si.ast.op is Ops.SINK]
|
||||
assert len(sched) == 1, f"conv should only have one kernel, getting {len(sched)}"
|
||||
for st in [x.st_arg for x in sched[0].ast.toposort if x.op is Ops.LOAD]:
|
||||
assert len(st.views) == 1
|
||||
|
||||
@@ -6,7 +6,6 @@ import numpy as np
|
||||
from hypothesis import given, strategies as strat, settings, HealthCheck
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.ops import GroupOp
|
||||
from tinygrad.tensor import _to_np_dtype
|
||||
@@ -72,7 +71,7 @@ def universal_test(a, b, dtype, op):
|
||||
def universal_test_unary(a, dtype, op):
|
||||
if not isinstance(op, tuple): op = (op, op)
|
||||
out: Tensor = op[0](Tensor([a], dtype=dtype))
|
||||
sched = create_schedule([out.lazydata])
|
||||
sched = out.schedule()
|
||||
ast = sched[-1].ast
|
||||
run_schedule(sched)
|
||||
tensor_value = out.numpy()
|
||||
|
||||
@@ -2,7 +2,6 @@ import unittest
|
||||
import time
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item, run_schedule
|
||||
|
||||
class TestFusionOp(unittest.TestCase):
|
||||
@@ -17,7 +16,7 @@ class TestFusionOp(unittest.TestCase):
|
||||
def test_expand_fuse(self):
|
||||
bt = Tensor(np.ones((10, 1)), dtype=dtypes.float32)
|
||||
out = (bt*2).expand(10,10).sum(1)
|
||||
sched = create_schedule([out.lazydata])
|
||||
sched = out.schedule()
|
||||
run_schedule(sched)
|
||||
outd = out.tolist()
|
||||
assert all(x == 20.0 for x in outd)
|
||||
@@ -26,7 +25,7 @@ class TestFusionOp(unittest.TestCase):
|
||||
st = time.perf_counter()
|
||||
a = Tensor([1,2,3,4])
|
||||
for _ in range(24): a = a + a
|
||||
sched = create_schedule([a.lazydata])
|
||||
sched = a.schedule()
|
||||
ei = lower_schedule_item(sched[-1])
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
assert len(ei.prg.p.src.splitlines()) < 250
|
||||
@@ -35,13 +34,13 @@ class TestFusionOp(unittest.TestCase):
|
||||
st = time.perf_counter()
|
||||
a = Tensor([1,2,3,4])
|
||||
for _ in range(24): a = a + a
|
||||
sched1 = create_schedule([a.lazydata])
|
||||
sched1 = a.schedule()
|
||||
b = Tensor([1,2,3,4])
|
||||
for _ in range(24): b = b + b
|
||||
sched2 = create_schedule([b.lazydata])
|
||||
sched2 = b.schedule()
|
||||
c = Tensor([1,2,3,4])
|
||||
for _ in range(23): c = c + c
|
||||
sched3 = create_schedule([c.lazydata])
|
||||
sched3 = c.schedule()
|
||||
self.assertEqual(sched1[-1].ast, sched2[-1].ast)
|
||||
with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast)
|
||||
self.assertLess(time.perf_counter()-st, 2.0)
|
||||
|
||||
@@ -3,7 +3,6 @@ import unittest, ctypes
|
||||
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import Context, CI, dedup, from_mv
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.engine.realize import ExecItem, BufferXfer, get_runner, CompiledRunner
|
||||
@@ -21,7 +20,7 @@ def helper_exec_op(device, outbuf, inbufs):
|
||||
s = fst[0]
|
||||
for i in range(1, len(inbufs)): s = s.xor(fst[i])
|
||||
|
||||
si = create_schedule([s.lazydata])[-1]
|
||||
si = s.schedule()[-1]
|
||||
prg = get_runner(device, si.ast)
|
||||
cached_prgs[(device, len(inbufs))] = prg
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad import Device, Tensor, dtypes
|
||||
from tinygrad.helpers import CI, getenv
|
||||
from tinygrad.device import Buffer, BufferSpec
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner, CompiledRunner
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
from tinygrad import Variable
|
||||
@@ -17,7 +16,7 @@ class TestHCQ(unittest.TestCase):
|
||||
TestHCQ.d0 = Device[Device.DEFAULT]
|
||||
TestHCQ.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestHCQ.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
si = self.b.schedule()[-1]
|
||||
|
||||
TestHCQ.runner = get_runner(TestHCQ.d0.device, si.ast)
|
||||
TestHCQ.b.lazydata.buffer.allocate()
|
||||
@@ -159,7 +158,7 @@ class TestHCQ(unittest.TestCase):
|
||||
|
||||
a = Tensor.randint((3, 3, 3), dtype=dtypes.int, device=Device.DEFAULT).realize()
|
||||
b = a + 1
|
||||
si = create_schedule([b.lazydata])[-1]
|
||||
si = b.schedule()[-1]
|
||||
k = Kernel(si.ast, opts=TestHCQ.d0.renderer)
|
||||
for i in range(3): k.apply_opt(Opt(op=OptOps.LOCAL, axis=0, amt=3))
|
||||
|
||||
@@ -442,7 +441,7 @@ class TestHCQ(unittest.TestCase):
|
||||
def test_memory_barrier(self):
|
||||
a = Tensor([0, 1], device=Device.DEFAULT, dtype=dtypes.int8).realize()
|
||||
b = a + 1
|
||||
runner = get_runner(TestHCQ.d0.device, create_schedule([b.lazydata])[-1].ast)
|
||||
runner = get_runner(TestHCQ.d0.device, b.schedule()[-1].ast)
|
||||
|
||||
buf1 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(nolru=True)).ensure_allocated()
|
||||
buf2 = Buffer(Device.DEFAULT, 2, dtypes.int8, options=BufferSpec(cpu_access=True, nolru=True)).ensure_allocated()
|
||||
|
||||
@@ -4,7 +4,6 @@ import unittest
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.ops import Ops, UOp
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
class TestLazyBuffer(unittest.TestCase):
|
||||
def test_fromcpu_shape_tracker(self):
|
||||
@@ -74,14 +73,14 @@ class TestLazyBuffer(unittest.TestCase):
|
||||
b = Tensor.randn(2, 2).realize()
|
||||
add = (a+b).contiguous()
|
||||
out = add+2
|
||||
sched = create_schedule([out.lazydata])
|
||||
sched = out.schedule()
|
||||
self.assertEqual(len(sched), 2)
|
||||
run_schedule(sched)
|
||||
np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2)
|
||||
|
||||
def test_forced_realized_metaop(self):
|
||||
empty = Tensor.empty(1).contiguous()
|
||||
sched = create_schedule([empty.lazydata])
|
||||
sched = empty.schedule()
|
||||
self.assertEqual(len(sched), 1)
|
||||
self.assertIs(sched[0].ast.op, Ops.EMPTY)
|
||||
run_schedule(sched)
|
||||
@@ -90,14 +89,14 @@ class TestReduceOp(unittest.TestCase):
|
||||
def test_no_split_reduce_kernel(self):
|
||||
a = Tensor.rand(4, 4).realize()
|
||||
a = a.sum()
|
||||
sched = create_schedule([a.lazydata])
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 1
|
||||
self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
def test_split_reduce_kernel_dim0(self):
|
||||
a = Tensor.rand(256, 255).realize()
|
||||
a = a.sum()
|
||||
sched = create_schedule([a.lazydata])
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
@@ -105,7 +104,7 @@ class TestReduceOp(unittest.TestCase):
|
||||
def test_split_reduce_kernel_dim1(self):
|
||||
a = Tensor.rand(255, 256).realize()
|
||||
a = a.sum()
|
||||
sched = create_schedule([a.lazydata])
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 2
|
||||
for s in sched:
|
||||
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
|
||||
|
||||
@@ -12,14 +12,14 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
# from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
|
||||
from tinygrad.engine.schedule import BUF_LIMIT
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
|
||||
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
|
||||
def helper_realized_ast(r:Union[Tensor, List[Tensor]]) -> Tuple[UOp, List[Buffer]]:
|
||||
if isinstance(r, Tensor): r = [r]
|
||||
s = create_schedule([x.lazydata for x in r])
|
||||
s = Tensor.schedule(*r)
|
||||
run_schedule(s[:-1]) # run all kernels except the last one
|
||||
# now all input LazyBuffers buffers in s[-1] should be realized
|
||||
# create fresh buffers for the output buffer
|
||||
@@ -30,7 +30,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
r = a.matmul(b, acc_dtype=dtype_out)
|
||||
sched = create_schedule([r.lazydata])
|
||||
sched = r.schedule()
|
||||
realized_ast = sched[-1].ast
|
||||
run_schedule(sched)
|
||||
out = r.numpy()
|
||||
@@ -48,7 +48,7 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
|
||||
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_opt:int=0, ensure_triggered:bool=True):
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
r = a.matmul(b, acc_dtype=dtype_out)
|
||||
sched = create_schedule([r.lazydata])
|
||||
sched = r.schedule()
|
||||
realized_ast = sched[-1].ast
|
||||
k = Kernel(realized_ast)
|
||||
k.apply_tensor_cores(1, axis=axis, tc_opt=tc_opt)
|
||||
@@ -67,7 +67,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(4), Tensor.randn(4)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
c = ((a.shrink(((0, 2),)) - a.shrink(((2, 4),))) - (b.shrink(((0, 2),)) - b.shrink(((2, 4),))))
|
||||
lowered = list(lower_schedule(create_schedule([c.lazydata])))
|
||||
lowered = list(lower_schedule(c.schedule()))
|
||||
for ei in lowered: ei.run()
|
||||
rawbufs = lowered[-1].bufs
|
||||
assert len(rawbufs) == 3 and set(rawbufs[1:]) == {a.lazydata.base.realized, b.lazydata.base.realized}
|
||||
@@ -924,7 +924,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
# these are of size 3 to avoid float4 coalesce
|
||||
r = a[:-1] + a[1:]
|
||||
|
||||
k = Kernel(create_schedule([r.lazydata])[-1].ast)
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_loads = len([uop for uop in k.uops if uop.op is Ops.LOAD])
|
||||
@@ -955,7 +955,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = a.expand([2]) + b.expand([2])
|
||||
|
||||
k = Kernel(create_schedule([r.lazydata])[-1].ast)
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.op in GroupOp.ALU])
|
||||
@@ -966,7 +966,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
x, w = Tensor.randn((1,1,3)).realize(), Tensor.randn((1,1,2)).realize()
|
||||
r = Tensor.conv2d(x,w,padding=1).relu()
|
||||
|
||||
k = Kernel(create_schedule([r.lazydata])[-1].ast)
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -983,7 +983,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
def test_upcast_with_locals(self):
|
||||
x, y = Tensor.rand(1,128), Tensor.rand(128, 128)
|
||||
r = (x@y).relu()
|
||||
k = Kernel(create_schedule([r.lazydata])[-1].ast)
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
|
||||
@@ -1000,7 +1000,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a, b = Tensor.randn(1).realize(), Tensor.randn(1).realize()
|
||||
r = Tensor.stack(a, b)
|
||||
|
||||
k = Kernel(create_schedule([r.lazydata])[-1].ast)
|
||||
k = Kernel(r.schedule()[-1].ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
num_ops = len([uop for uop in k.uops if uop.op in GroupOp.ALU])
|
||||
@@ -1011,14 +1011,14 @@ class TestLinearizer(unittest.TestCase):
|
||||
(dtypes.bool, dtypes.int), (dtypes.int16, dtypes.int), (dtypes.float16, dtypes.float), (dtypes.bfloat16, dtypes.float)):
|
||||
if is_dtype_supported(tensor_dtype) and is_dtype_supported(acc_dtype):
|
||||
a = Tensor([1, 2, 3], dtype=tensor_dtype).sum()
|
||||
k = Kernel(create_schedule([a.lazydata])[-1].ast)
|
||||
k = Kernel(a.schedule()[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC]
|
||||
assert local[0].dtype == acc_dtype
|
||||
|
||||
def test_arg_acc_dtype(self):
|
||||
def helper_arg_acc_dtype(c: Tensor, expected_dtype:DType):
|
||||
k = Kernel(create_schedule([c.lazydata])[-1].ast)
|
||||
k = Kernel(c.schedule()[-1].ast)
|
||||
k.linearize()
|
||||
local = [uop for uop in k.uops if uop.op is Ops.DEFINE_ACC]
|
||||
assert local[0].dtype == expected_dtype
|
||||
@@ -1225,7 +1225,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
def test_div_collapse(self):
|
||||
def helper(t, msg, max_ops=0):
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK]
|
||||
sched = [si for si in t.schedule() if si.ast.op is Ops.SINK]
|
||||
assert len(sched) == 1
|
||||
|
||||
lin = Kernel(sched[0].ast)
|
||||
@@ -1246,7 +1246,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
|
||||
def test_sum_collapse(self):
|
||||
t = Tensor([2]).reshape(1, 1).expand(256, 256).sum()
|
||||
sched = [si for si in create_schedule([t.lazydata]) if si.ast.op is Ops.SINK]
|
||||
sched = [si for si in t.schedule() if si.ast.op is Ops.SINK]
|
||||
assert len(sched) == 1
|
||||
lin = Kernel(sched[0].ast)
|
||||
assert not any(u.op is Ops.RANGE for u in lin.linearize().uops), "found loop in sum collapse"
|
||||
@@ -1262,7 +1262,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
a = Tensor.ones(4, 4).contiguous().realize()
|
||||
b = a.shrink(((1, 2), None)).pad(((1, 2), None))
|
||||
a.assign(b.where(2, a))
|
||||
sched = create_schedule([a.lazydata])
|
||||
sched = a.schedule()
|
||||
assert len(sched) == 1
|
||||
sched_copy = sched[:]
|
||||
run_schedule(sched)
|
||||
@@ -1424,7 +1424,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
k.linearize()
|
||||
@@ -1437,7 +1437,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(2, 8).realize()
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(0, 4) # float4 dimension
|
||||
k.shift_to(0, 2, insert_before=k.shape_len-1)
|
||||
@@ -1455,7 +1455,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(2, size).realize()
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(0, 4)
|
||||
k.shift_to(0, shift, insert_before=k.shape_len-1)
|
||||
@@ -1479,7 +1479,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(9).realize().shrink(((1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.hand_coded_optimizations() # implicit trigger float4 dim
|
||||
k.linearize()
|
||||
@@ -1492,7 +1492,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(2, 9).realize().shrink(((0, 2), (1, 9),))
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
|
||||
k.upcast()
|
||||
@@ -1510,7 +1510,7 @@ class TestFloat4(unittest.TestCase):
|
||||
b = Tensor.rand(2, size).realize().shrink(((0, 2), (1, size),))
|
||||
c = a + b
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(len(k.full_unupcasted_shape)-1, 4) # manual trigger float4 dim
|
||||
k.upcast()
|
||||
@@ -1535,7 +1535,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# only the first and last conv dot products are aligned in a, and b is never aligned, so no
|
||||
# float4 should be emitted (the reduce axis of size 4 is the float4 axis here)
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.upcast()
|
||||
k.linearize()
|
||||
@@ -1551,7 +1551,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# don't.
|
||||
# UPDATE: now we do this fusion
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.upcast()
|
||||
k.upcast()
|
||||
@@ -1567,7 +1567,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(0, 4, top=True) # top axes are float4 axes
|
||||
k.upcast()
|
||||
@@ -1583,7 +1583,7 @@ class TestFloat4(unittest.TestCase):
|
||||
# we will upcast the top axis of sz 4. they should not be coalesced into float4,
|
||||
# since the top axis is not contiguous.
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(0, 4) # float4 axis
|
||||
k.upcast()
|
||||
@@ -1598,7 +1598,7 @@ class TestFloat4(unittest.TestCase):
|
||||
|
||||
# should float4 b but not a
|
||||
|
||||
s = create_schedule([c.lazydata])[0]
|
||||
s = c.schedule()[0]
|
||||
k = Kernel(s.ast)
|
||||
k.shift_to(0, 4) # float4 axis
|
||||
k.upcast()
|
||||
@@ -1692,7 +1692,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
layer_1 = Tensor.cat(*[Tensor.rand(5) for _ in range(4)])
|
||||
layer_2 = Tensor.cat(layer_1.unsqueeze(0), Tensor.rand(6, 20))
|
||||
|
||||
s = create_schedule([layer_2.lazydata])[-1]
|
||||
s = layer_2.schedule()[-1]
|
||||
k = Kernel(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
assert len(k.bufs) == 6 # make sure all ops are done in one kernel
|
||||
@@ -1705,7 +1705,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
def test_masked_upcast_wino(self):
|
||||
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
|
||||
|
||||
s = create_schedule([monster.lazydata])[-1]
|
||||
s = monster.schedule()[-1]
|
||||
k = Kernel(s.ast)
|
||||
k.hand_coded_optimizations()
|
||||
assert len(k.bufs) == 37 # make sure all ops are done in one kernel
|
||||
@@ -1719,7 +1719,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
out.mean().backward()
|
||||
|
||||
upcasts = []
|
||||
wino_schedule = create_schedule([out.lazydata])
|
||||
wino_schedule = out.schedule()
|
||||
# collect upcasts of tile transform kernels
|
||||
for i, si in enumerate(wino_schedule):
|
||||
k = Kernel(si.ast)
|
||||
@@ -1732,7 +1732,7 @@ class TestHandCodedOpts(unittest.TestCase):
|
||||
# this test case's inputs are too small, so one of the 4-stacks became a local, which is fine i guess
|
||||
assert upcasts.count((6, 6)) == 2 #and upcasts.count((4, 4)) == 1
|
||||
|
||||
backward_schedule = create_schedule([x.grad.lazydata, w.grad.lazydata])
|
||||
backward_schedule = Tensor.schedule(x.grad, w.grad)
|
||||
for si in backward_schedule:
|
||||
k = Kernel(si.ast)
|
||||
k.hand_coded_optimizations()
|
||||
@@ -2060,7 +2060,7 @@ class TestKernelOpts(unittest.TestCase):
|
||||
def test_padto_sum_ok(self):
|
||||
N = 18 * 18
|
||||
# NOTE: this setup prevents 17 * 17 contiguous merged into one dimension
|
||||
a = Tensor.rand(N, N).shrink(((0, 17), (0, 17))) * 100
|
||||
a = Tensor.rand(N, N).realize().shrink(((0, 17), (0, 17))) * 100
|
||||
b = (Tensor.rand(N, N) < 0.5).realize().shrink(((0, 17), (0, 17)))
|
||||
|
||||
helper_linearizer_opt(a.sum(0), [
|
||||
|
||||
@@ -4,7 +4,6 @@ from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import CI, getenv, prod, Context
|
||||
from tinygrad.nn.state import get_parameters, get_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule, BufferCopy, CompiledRunner, run_schedule
|
||||
from tinygrad.multi import all_reduce, MultiLazyBuffer
|
||||
import numpy as np
|
||||
@@ -69,7 +68,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
X = Tensor.ones(256).contiguous().realize()
|
||||
X.shard_(devices_2, 0)
|
||||
out = (X + X)
|
||||
sched = create_schedule(out.lazydata.lbs)
|
||||
sched = out.schedule()
|
||||
names = []
|
||||
for si, ei in zip(sched[:], lower_schedule(sched)):
|
||||
if isinstance(ei.prg, CompiledRunner): names.append(ei.prg.p.name)
|
||||
@@ -492,7 +491,7 @@ class TestMultiTensor(unittest.TestCase):
|
||||
for p in get_parameters(bn): p.shard_(devices_4).realize()
|
||||
|
||||
out = bn(t)
|
||||
scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
|
||||
scheds = [sched for sched in out.schedule() if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY]
|
||||
assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device"
|
||||
asts = [sched.ast for sched in scheds]
|
||||
assert len(asts)
|
||||
@@ -723,7 +722,7 @@ class TestHandleData(unittest.TestCase):
|
||||
device = (d0, d1, d2, d3)
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
not_covered = t.to(d5)
|
||||
sched = create_schedule([not_covered.lazydata])
|
||||
sched = not_covered.schedule()
|
||||
assert len(sched) == 1
|
||||
# setup again because create_schedule has side effect
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
@@ -733,7 +732,7 @@ class TestHandleData(unittest.TestCase):
|
||||
for d in device:
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
covered = t.to(d)
|
||||
sched = create_schedule([covered.lazydata])
|
||||
sched = covered.schedule()
|
||||
assert len(sched) == 0
|
||||
# setup again because create_schedule has side effect
|
||||
t = Tensor([1, 2, 3, 4]).shard(device).realize()
|
||||
@@ -1001,9 +1000,9 @@ class TestBatchNorm(unittest.TestCase):
|
||||
p.to_(devices)
|
||||
|
||||
synced_out = synced_bn(x)
|
||||
synced_si = list(create_schedule(synced_out.lazydata.lbs))
|
||||
synced_si = list(synced_out.schedule())
|
||||
unsynced_out = unsynced_bn(x)
|
||||
unsynced_si = list(create_schedule(unsynced_out.lazydata.lbs))
|
||||
unsynced_si = list(unsynced_out.schedule())
|
||||
|
||||
# TODO: test synced / unsynced batchnorm cross device kernel and copies
|
||||
assert synced_si
|
||||
|
||||
@@ -8,7 +8,6 @@ from tinygrad.helpers import CI, Context
|
||||
from tinygrad.nn import Conv1d, ConvTranspose1d, Conv2d, ConvTranspose2d, Linear, Embedding
|
||||
from tinygrad.nn import BatchNorm, LayerNorm, LayerNorm2d, GroupNorm, InstanceNorm, RMSNorm, LSTMCell
|
||||
from tinygrad.nn.state import load_state_dict
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.device import is_dtype_supported
|
||||
|
||||
@@ -517,7 +516,7 @@ class TestNN(unittest.TestCase):
|
||||
a = Tensor([[1, 5, 9, 11],
|
||||
[12, 19, 8, 1]])
|
||||
result = layer(a)
|
||||
schedule = create_schedule([result.lazydata])
|
||||
schedule = result.schedule()
|
||||
self.assertEqual(3, len([item for item in schedule if item.ast.op is Ops.SINK]), "first run realizes arange, weight, and embedding")
|
||||
run_schedule(schedule)
|
||||
|
||||
@@ -525,7 +524,7 @@ class TestNN(unittest.TestCase):
|
||||
[4, 5, 6],
|
||||
[7, 8, 9]])
|
||||
result = layer(b)
|
||||
schedule = create_schedule([result.lazydata])
|
||||
schedule = result.schedule()
|
||||
self.assertEqual(1, len([item for item in schedule if item.ast.op is Ops.SINK]), "second run realizes embedding only")
|
||||
run_schedule(schedule)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import unittest, pickle, types
|
||||
import numpy as np
|
||||
from tinygrad import Tensor, TinyJit, Variable, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import GlobalCounters
|
||||
from tinygrad.ops import PatternMatcher, UPat, UOp
|
||||
|
||||
@@ -99,7 +98,7 @@ class TestPickle(unittest.TestCase):
|
||||
def test_pickle_schedule(self):
|
||||
a = Tensor([1,2])
|
||||
out = a + 2
|
||||
sched = create_schedule([out.lazydata])
|
||||
sched = out.schedule()
|
||||
pk = pickle.dumps(sched)
|
||||
sched_pk = pickle.loads(pk)
|
||||
self.assertEqual(sched_pk[-1].ast, sched[-1].ast)
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad import Device, Tensor, dtypes, TinyJit
|
||||
from tinygrad.helpers import CI, getenv, Context
|
||||
from tinygrad.device import Buffer, BufferSpec, Compiled, ProfileRangeEvent, ProfileDeviceEvent, ProfileGraphEvent
|
||||
from tinygrad.runtime.support.hcq import HCQCompiled
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import get_runner
|
||||
|
||||
MOCKGPU = getenv("MOCKGPU")
|
||||
@@ -34,7 +33,7 @@ class TestProfiler(unittest.TestCase):
|
||||
|
||||
TestProfiler.a = Tensor([0.,1.], device=Device.DEFAULT).realize()
|
||||
TestProfiler.b = self.a + 1
|
||||
si = create_schedule([self.b.lazydata])[-1]
|
||||
si = self.b.schedule()[-1]
|
||||
|
||||
TestProfiler.runner = get_runner(TestProfiler.d0.device, si.ast)
|
||||
TestProfiler.b.lazydata.buffer.allocate()
|
||||
|
||||
@@ -14,20 +14,21 @@ from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
|
||||
from tinygrad.codegen.kernel import Kernel, verify_ast
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleContext, ScheduleItem, create_schedule, view_right, view_left, remove_movement_ops, to_uop
|
||||
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule
|
||||
from extra.models.llama import precompute_freqs_cis
|
||||
|
||||
class KernelCountException(Exception): pass
|
||||
def check_schedule(t:Union[Tensor, List[Tensor], UOp], allowed:int, to_prerealize:Optional[List[Tensor]]=None, filter_sink=True):
|
||||
if isinstance(t, Tensor): outs = t.lazydata.lbs
|
||||
elif isinstance(t, List): outs = flatten([r.lazydata.lbs for r in t])
|
||||
else: outs = [t]
|
||||
if to_prerealize:
|
||||
for pre in to_prerealize: pre.schedule()
|
||||
sched = create_schedule(outs)
|
||||
if isinstance(t, Tensor): sched = t.schedule()
|
||||
elif isinstance(t, List) and isinstance(t[0], Tensor): sched = Tensor.schedule(*t)
|
||||
else:
|
||||
assert isinstance(t, UOp), f"can't schedule {t}"
|
||||
sched, _ = create_schedule_with_vars([t])
|
||||
if filter_sink: sched = [s for s in sched if s.ast.op is Ops.SINK]
|
||||
if len(sched) != allowed:
|
||||
print(f"SCHEDULE ISSUE, expecting {allowed} got {len(sched)}")
|
||||
@@ -54,7 +55,7 @@ def _test_conv2d(allowed:int, dtype:DType=dtypes.float, **kwargs):
|
||||
w = Tensor.uniform(16, CIN, 3, 3, requires_grad=True).realize()
|
||||
ret = Tensor.conv2d(img, w).relu().mean().backward()
|
||||
dtypes.default_float = old_default_float
|
||||
with Context(**kwargs): s = create_schedule([ret.lazydata, img.grad.lazydata, w.grad.lazydata])
|
||||
with Context(**kwargs): s = Tensor.schedule(ret, img.grad, w.grad)
|
||||
run_schedule(s.copy())
|
||||
cnt = len([si for si in s if si.ast.op is Ops.SINK])
|
||||
assert cnt == allowed, f"expected {allowed} kernels, got {cnt}"
|
||||
@@ -1393,11 +1394,11 @@ class TestSchedule(unittest.TestCase):
|
||||
|
||||
def test_const_schedule(self):
|
||||
constv = Tensor.empty(2, 2).lazydata.const_like(10)
|
||||
self.assertEqual(len(create_schedule([constv])), 0)
|
||||
check_schedule(constv, 0)
|
||||
|
||||
def test_const_schedule_contig(self):
|
||||
constv = Tensor.empty(2, 2).lazydata.const_like(10).contiguous()
|
||||
self.assertEqual(len(create_schedule([constv])), 1)
|
||||
check_schedule(constv, 1)
|
||||
|
||||
@unittest.skipIf(Device.DEFAULT != "GPU", "image only supported on GPU")
|
||||
def test_image_matmul(self):
|
||||
@@ -2047,15 +2048,6 @@ class TestBigGraph(unittest.TestCase):
|
||||
sink = tensor_rewrite(a)
|
||||
assert UPat.cvar(dtype=dtypes.int).match(sink, {})
|
||||
|
||||
# failure: the scheduler must not change image to its base dtype before const folding
|
||||
@unittest.skipIf(Device.DEFAULT not in ("QCOM", "GPU"), "only images on GPU")
|
||||
@unittest.expectedFailure
|
||||
def test_float_to_image_cast_stays(self):
|
||||
a = Tensor.empty(4).cast(dtypes.imagef((1,1,4)))
|
||||
sink = to_uop(a.lazydata, ScheduleContext(), {})
|
||||
sink = graph_rewrite(sink, symbolic)
|
||||
assert UPat(Ops.VIEW, src=(UPat(Ops.BUFFER, dtypes.imagef((1, 1, 4))), UPat(Ops.CAST, src=(UPat.var(dtype=dtypes.float),)))).match(sink, {})
|
||||
|
||||
tensor_const_pm = PatternMatcher([
|
||||
(UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),)),)), lambda: True),
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE), UPat(Ops.BIND, src=(UPat(Ops.DEFINE_VAR), UPat(Ops.CONST))))), lambda: True),
|
||||
|
||||
@@ -4,7 +4,6 @@ from test.helpers import ast_const
|
||||
from tinygrad.codegen.kernel import Opt, OptOps
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.ops import UOp, Ops
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.tensor import Tensor
|
||||
@@ -16,7 +15,8 @@ from tinygrad.shape.view import View
|
||||
|
||||
class TestTimeLinearizer(unittest.TestCase):
|
||||
def test_reasonable_time(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
si = (a+1).schedule()[0]
|
||||
out = Buffer(Device.DEFAULT, si.outputs[0].size, si.outputs[0].dtype).allocate()
|
||||
memops = {x.src[0].arg:x.src[-1].arg.real_size() for x in si.ast.toposort if x.op is Ops.LOAD}
|
||||
rawbufs = [out] + [Buffer(Device.DEFAULT, memops[i], x.dtype).allocate() for i,x in enumerate(si.inputs, start=len(si.outputs))]
|
||||
@@ -24,7 +24,8 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
assert tm > 0 and tm != float('inf')
|
||||
|
||||
def test_bufs_from_lin(self):
|
||||
si = [i for i in create_schedule([Tensor([1,2,3,4]).add(1).lazydata]) if i.ast.op is Ops.SINK][0]
|
||||
a = Tensor([1,2,3,4]).realize()
|
||||
si = (a+1).schedule()[0]
|
||||
rawbufs = bufs_from_lin(lin:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(lin.membufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
@@ -34,7 +35,7 @@ class TestTimeLinearizer(unittest.TestCase):
|
||||
def test_bufs_from_lin_alt(self):
|
||||
a = Tensor.randn(4, 4).realize()
|
||||
b = a+a[0]
|
||||
si = [si for si in b.schedule() if si.ast.op is Ops.SINK][0]
|
||||
si = b.schedule()[0]
|
||||
rawbufs = bufs_from_lin(k:=Kernel(si.ast))
|
||||
assert len(rawbufs) == len(k.membufs) == 2
|
||||
assert all(r is not None for r in rawbufs)
|
||||
|
||||
@@ -3,7 +3,6 @@ import numpy as np
|
||||
import torch
|
||||
import unittest, copy, mmap, random, math, array
|
||||
from tinygrad import Tensor, Device, dtypes
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.helpers import getenv, temp, _METADATA, mv_address
|
||||
from extra.gradcheck import numerical_jacobian, jacobian, gradcheck
|
||||
from hypothesis import given, settings, strategies as strat
|
||||
@@ -725,7 +724,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
W = Tensor.rand(3, 3, requires_grad=True)
|
||||
out = x.matmul(W)
|
||||
self.assertEqual(out.lazydata.metadata.name, "matmul")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "matmul")
|
||||
|
||||
@@ -733,7 +732,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
x = Tensor.rand(3, requires_grad=True)
|
||||
out = x.relu()
|
||||
self.assertEqual(out.lazydata.metadata.name, "relu")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 1)
|
||||
self.assertEqual(si.metadata[0].name, "relu")
|
||||
|
||||
@@ -744,7 +743,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertEqual(out.lazydata.metadata.name, "__mul__")
|
||||
self.assertEqual(out.lazydata.src[0].metadata.name, "relu")
|
||||
self.assertEqual(out.lazydata.src[1].metadata.name, "sigmoid")
|
||||
si = create_schedule([out.lazydata])[-1]
|
||||
si = out.schedule()[-1]
|
||||
self.assertEqual(len(si.metadata), 3)
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"relu", "sigmoid", "__mul__"})
|
||||
|
||||
@@ -758,7 +757,7 @@ class TestTensorMetadata(unittest.TestCase):
|
||||
self.assertTrue(x.grad.lazydata.metadata.backward)
|
||||
self.assertEqual(y.grad.lazydata.metadata.name, "sigmoid")
|
||||
self.assertTrue(y.grad.lazydata.metadata.backward)
|
||||
si = create_schedule([out.lazydata, x.grad.lazydata, y.grad.lazydata])[-1]
|
||||
si = Tensor.schedule(out, x.grad, y.grad)[-1]
|
||||
self.assertEqual(len(si.metadata), 3, f"failed with {si.metadata}")
|
||||
self.assertEqual(set(m.name for m in si.metadata), {"sigmoid", "sigmoid", "relu"})
|
||||
bw = [m for m in si.metadata if m.backward]
|
||||
|
||||
@@ -9,7 +9,7 @@ from tinygrad.dtype import dtypes, DType
|
||||
from tinygrad.device import Buffer, Device
|
||||
from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401
|
||||
from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.schedule import create_schedule, to_si
|
||||
from tinygrad.engine.schedule import to_si
|
||||
from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
|
||||
@@ -237,12 +237,12 @@ class TestExecALU(TestUOps):
|
||||
class TestConstantFolding(unittest.TestCase):
|
||||
def test_cast_const(self):
|
||||
t = Tensor(1, dtype=dtypes.float).cast(dtypes.int)
|
||||
si = create_schedule([t.lazydata])
|
||||
si = t.schedule()
|
||||
assert len(si) == 0
|
||||
|
||||
def test_bitcast_const(self):
|
||||
t = Tensor(1, dtype=dtypes.float).bitcast(dtypes.int)
|
||||
si = create_schedule([t.lazydata])
|
||||
si = t.schedule()
|
||||
assert len(si) == 1
|
||||
ji = lower_schedule_item(si[-1])
|
||||
assert any(uop.op is Ops.BITCAST for uop in ji.prg.p.uops), f"{[uop.op for uop in ji.prg.p.uops]} does not contain bitcast"
|
||||
@@ -592,5 +592,29 @@ class TestShapeSpec(unittest.TestCase):
|
||||
r = Tensor.empty(4, 4).sum(axis=1)
|
||||
self.assertEqual(r.lazydata.st, ShapeTracker.from_shape((4,)))
|
||||
|
||||
class TestUOpChildren(unittest.TestCase):
|
||||
def test_children_exist(self):
|
||||
a = UOp.variable("weird_name_234", 0, 10)
|
||||
b = a*a
|
||||
self.assertEqual(len(a.children), 1)
|
||||
self.assertIs(list(a.children)[0](), b)
|
||||
|
||||
def test_children_cleaned_up(self):
|
||||
a = UOp.variable("weird_name_235", 0, 10)
|
||||
b = a*a
|
||||
self.assertEqual(len(a.children), 1)
|
||||
del b
|
||||
self.assertEqual(len(a.children), 0)
|
||||
|
||||
def test_children_cleaned_up_two(self):
|
||||
a = UOp.variable("weird_name_236", 0, 10)
|
||||
b = a*a
|
||||
c = a*2
|
||||
self.assertEqual(len(a.children), 2)
|
||||
del b
|
||||
self.assertEqual(len(a.children), 1)
|
||||
del c
|
||||
self.assertEqual(len(a.children), 0)
|
||||
|
||||
if __name__ == '__main__':
|
||||
unittest.main(verbosity=2)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import unittest
|
||||
from tinygrad import Tensor
|
||||
from tinygrad.helpers import getenv, GlobalCounters
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
from tinygrad.engine.realize import lower_schedule_item, ProgramSpec
|
||||
from tinygrad.renderer import Estimates
|
||||
from tinygrad.codegen.linearize import linearize_uop
|
||||
@@ -16,7 +15,7 @@ def flops_mem(uops, ignore_indexing=False):
|
||||
# **************** new FlopCounter ****************
|
||||
|
||||
def get_stats(x:Tensor):
|
||||
si = create_schedule([x.lazydata])[-1]
|
||||
si = x.schedule()[-1]
|
||||
ei = lower_schedule_item(si)
|
||||
return ei.prg.estimates.ops, ei.prg.estimates.mem
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from tinygrad import Tensor, GlobalCounters, dtypes
|
||||
from tinygrad.ops import Ops
|
||||
from tinygrad.helpers import Timing, CI, Profiling, WINO, DEBUG, getenv
|
||||
from tinygrad.codegen.kernel import Kernel
|
||||
from tinygrad.engine.schedule import create_schedule
|
||||
|
||||
class TestWinograd(unittest.TestCase):
|
||||
def setUp(self):
|
||||
@@ -20,7 +19,7 @@ class TestWinograd(unittest.TestCase):
|
||||
out = Tensor.conv2d(x, w)
|
||||
|
||||
with Timing("scheduling: "):
|
||||
sched = create_schedule([out.lazydata])
|
||||
sched = out.schedule()
|
||||
|
||||
for i,s in enumerate(sched):
|
||||
if s.ast.op is not Ops.SINK: continue
|
||||
|
||||
@@ -4,6 +4,7 @@ from tinygrad.ops import UOp, symbolic, graph_rewrite_map, _substitute
|
||||
from test.unit.test_tensor_uop_representation import is_pattern, realized_pattern, is_pattern_uop
|
||||
|
||||
class TestTensorMutates(unittest.TestCase):
|
||||
@unittest.expectedFailure
|
||||
def test_mutate_add(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
@@ -214,6 +215,5 @@ class TestRewriteMap(unittest.TestCase):
|
||||
self.assertEqual(node_map[zero_node], zero_node)
|
||||
self.assertEqual(node_map[one_node], one_node)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -7,6 +7,40 @@ const_pattern = UPat(Ops.CONST, src=(UPat(Ops.VIEW, src=(UPat(Ops.DEVICE),),)))
|
||||
def is_pattern_uop(u:UOp, pat:UPat): assert pat.match(u, {}), f"{u}\nis not\n{pat}"
|
||||
def is_pattern(ten:Tensor, pat:UPat): is_pattern_uop(ten.lazydata, pat)
|
||||
|
||||
class TestTensorMutates(unittest.TestCase):
|
||||
# this fails because uops are mutating
|
||||
@unittest.expectedFailure
|
||||
def test_mutate_add(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
ret = a+b
|
||||
pa = a.lazydata
|
||||
pb = b.lazydata
|
||||
pr = ret.lazydata
|
||||
ret.schedule()
|
||||
self.assertIsNot(pa, a.lazydata)
|
||||
self.assertIsNot(pb, b.lazydata)
|
||||
self.assertIsNot(pr, ret.lazydata)
|
||||
for t in [a,b,ret]: is_pattern(t, realized_pattern)
|
||||
|
||||
def test_reshape_is_same_parent(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
c = a+b
|
||||
d = (a+b).reshape(3,1)
|
||||
d.realize()
|
||||
is_pattern_uop(d.lazydata.base, realized_pattern)
|
||||
is_pattern_uop(c.lazydata.base, realized_pattern)
|
||||
|
||||
def test_reshape_is_same_child(self):
|
||||
a = Tensor([1,2,3])
|
||||
b = Tensor([4,5,6])
|
||||
c = a+b
|
||||
d = (a+b).reshape(3,1)
|
||||
c.realize()
|
||||
is_pattern_uop(c.lazydata.base, realized_pattern)
|
||||
is_pattern_uop(d.lazydata.base, realized_pattern)
|
||||
|
||||
class TestTensorUopRepresentation(unittest.TestCase):
|
||||
def test_realized(self):
|
||||
a = Tensor([1.,2,3]).realize()
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from __future__ import annotations
|
||||
import itertools, functools
|
||||
import itertools, functools, math
|
||||
from dataclasses import dataclass
|
||||
from collections import defaultdict
|
||||
from typing import Optional, cast, Final, Callable, Sequence
|
||||
@@ -298,19 +298,15 @@ class Kernel:
|
||||
# can only fuse reduces with the same tc options
|
||||
assert all_same(tensor_core_opts)
|
||||
if tensor_core_opts[0] is None: continue
|
||||
# tensor core -- unroll the reduce dim, upcast input and local the correct thread pattern
|
||||
self.tensor_core_opts = tc_opts = tensor_core_opts[0]
|
||||
|
||||
# attempt to pad the tensor axes that require it
|
||||
try:
|
||||
for axis, dim in tc_opts.axis_pads: self.apply_opt(Opt(OptOps.PADTO, axis, dim), append_opt=False) # PADTO might fail
|
||||
except KernelOptError: continue
|
||||
for tc_dim, amt in tc.reduce_axes: self.apply_opt(Opt(OptOps.UNROLL,tc_opts.axes[2]-self.first_reduce,amt), append_opt=False)
|
||||
for opt in tc.opts_seq:
|
||||
if opt == "UP":
|
||||
for tc_dim, amt in tc.early_upcast_axes: self.apply_opt(Opt(OptOps.UPCAST,tc_opts.axes[tc_dim],amt), append_opt=False)
|
||||
elif opt == "LC":
|
||||
for tc_dim, amt in tc.threads: self.apply_opt(Opt(OptOps.LOCAL,tc_opts.axes[tc_dim],amt), append_opt=False)
|
||||
# tensor core -- unroll the reduce dim (K), upcast and local the inner and outer dims (N, M)
|
||||
for dim, amt in tc.get_reduce_axes(): self.apply_opt(Opt(OptOps.UNROLL, tc_opts.axes[2]-self.first_reduce, amt), append_opt=False)
|
||||
for opt in tc.opts: self.apply_opt(Opt({"u":OptOps.UPCAST, "l":OptOps.LOCAL}[opt[0]], tc_opts.axes[int(opt[1])], 2), append_opt=False)
|
||||
self.tensor_core = tc
|
||||
self.use_tensor_cores = use_tensor_cores # TC=2 will do the shape ops without the WMMA
|
||||
return True
|
||||
@@ -405,7 +401,7 @@ class Kernel:
|
||||
self.upcast()
|
||||
elif opt.op is OptOps.UPCAST: # yellow
|
||||
check(axis < self.first_reduce, "upcast is for non-reduce")
|
||||
check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.threads)), "can't upcast TC locals")
|
||||
check(not (self.tensor_core and self.global_dims <= axis < self.global_dims+len(self.tensor_core.get_local_axes())), "can't upcast TC locals")
|
||||
check(amt <= 16, "don't upcast more than 16")
|
||||
self.shift_to(axis, amt, insert_before=None)
|
||||
self.upcast()
|
||||
@@ -596,48 +592,43 @@ class Kernel:
|
||||
grouped_axes = reduced_axes(self.first_reduce, self.first_reduce + self.group_for_reduces)
|
||||
|
||||
if (tc := self.tensor_core) and (self.use_tensor_cores == 1 or self.use_tensor_cores == 3):
|
||||
def fix_st(st: ShapeTracker, wd_pattern, tcd_pattern):
|
||||
st = ShapeTracker.from_shape(st.shape) # st needs to be contiguous
|
||||
wd, warp_dims = self.global_dims, tuple(sz for _, sz in tc.threads)
|
||||
tcd, tcd_dims = self.first_upcast, tuple(sz for _, sz in tc.reduce_axes + tc.early_upcast_axes)
|
||||
|
||||
assert st.shape[wd:wd+len(warp_dims)] == warp_dims, f"warp dims wrong: {st.shape[wd:wd+len(warp_dims)]=} != {warp_dims=}"
|
||||
assert st.shape[tcd:tcd+len(tcd_dims)] == tcd_dims, f"tcd dims wrong: {st.shape[tcd:tcd+len(tcd_dims)]=} != {tcd_dims=}"
|
||||
assert tc.expanded_shape is not None
|
||||
|
||||
new_shape = st.shape[:tcd] + tc.expanded_shape + st.shape[tcd+len(tcd_dims):] # expand the tcd
|
||||
permaxis = list(range(wd)) + [y + (wd if x == 0 else tcd) for x,y in wd_pattern] + list(range(wd+len(warp_dims),tcd)) + \
|
||||
[y + (wd if x == 0 else tcd) for x,y in tcd_pattern] + list(range(tcd+len(tc.expanded_shape),len(new_shape)))
|
||||
return st.reshape(new_shape).permute(tuple(permaxis)).reshape(st.shape).simplify()
|
||||
wd, tcd = self.global_dims, self.first_upcast
|
||||
def get_upcast_axes(buf): # upcast along non-zero dimensions of (tc_reduce + tc_upcast)
|
||||
upcast_axes = int(math.log2(tc.elements_per_thread[buf]))
|
||||
return tuple((tcd + len(tc.get_reduce_axes()) + len(tc.get_upcast_axes()) - (i+1), 2) for i in range(upcast_axes))
|
||||
def get_tc_swizzle_st(shape, local_perm, upcast_perm):
|
||||
offset = (tcd - (wd + len(local_perm)))
|
||||
permaxis = list(range(wd)) \
|
||||
+ [wd + x + (offset if x >= len(local_perm) else 0) for x in local_perm] + list(range(wd + len(local_perm), tcd)) \
|
||||
+ [wd + x + (offset if x >= len(local_perm) else 0) for x in upcast_perm] + list(range(tcd + len(upcast_perm), len(shape)))
|
||||
return ShapeTracker.from_shape(shape).permute(tuple(permaxis))
|
||||
|
||||
srcs = list((ret.src[0] if ret.src[0].op is not Ops.CAST else ret.src[0].src[0]).src)
|
||||
for i, tc_pattern in enumerate([tc.st1_pattern, tc.st2_pattern]):
|
||||
if tc_pattern: srcs[i] = srcs[i].view(fix_st(srcs[i].st_arg if srcs[i].op is Ops.LOAD else srcs[i].src[0].st_arg, *tc_pattern))
|
||||
for i, (src, swizzle) in enumerate(zip(srcs, tc.swizzle)):
|
||||
if swizzle: srcs[i] = src.view(get_tc_swizzle_st((src if src.op is Ops.LOAD else src.src[0]).st_arg.shape, *swizzle))
|
||||
|
||||
if self.use_tensor_cores == 3: # for TC=3, emulate the warp addressing with locals
|
||||
local_shape = tuple(1 if i >= self.first_reduce and i < self.first_upcast else s for i, s in enumerate(self.full_shape))
|
||||
st = store_st = ShapeTracker.from_shape(local_shape)
|
||||
local_buffer = UOp(Ops.DEFINE_LOCAL, tc.dtype_in.ptr(size=st.real_size(), local=True), (), (f"temp{i + 1}", st.real_size()))
|
||||
if tc_pattern: store_st = fix_st(store_st, *tc_pattern)
|
||||
if swizzle: store_st = get_tc_swizzle_st(store_st.shape, *swizzle)
|
||||
local_store = UOp.store(local_buffer, store_st.to_uop(), srcs[i])
|
||||
srcs[i] = UOp(Ops.LOAD, tc.dtype_in, (local_buffer, st.to_uop(), local_store))
|
||||
|
||||
tc_reduce_axes = tuple(self.first_upcast + ax for ax, _ in tc.reduce_axes)
|
||||
tc_reduce_axes = tuple(tcd + ax for ax, _ in tc.get_reduce_axes())
|
||||
if self.use_tensor_cores == 1: # real WMMA, use CONTRACT/UNROLL to get the vectorization right
|
||||
upcast_axes = tuple(tuple((self.first_upcast + ax, sz) for ax, sz in up) for up in tc.upcast_axes)
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, prod(sz for _, sz in tc.threads), upcast_axes, tc_reduce_axes)
|
||||
wmma_sz = [prod(x[1] for x in l) for l in upcast_axes]
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(wmma_sz[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(wmma_sz[0]), src=(srcs[0],), arg=upcast_axes[0]),
|
||||
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(wmma_sz[1]), src=(srcs[1],), arg=upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(wmma_sz[2]), 0.0)), arg=wmma_arg)
|
||||
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=upcast_axes[2])
|
||||
tc_upcast_axes = (get_upcast_axes(0), get_upcast_axes(1), get_upcast_axes(2))
|
||||
wmma_arg = (str(tc), tc.dims, tc.dtype_in, tc.dtype_out, self.opts.device, tc.threads, tc_upcast_axes, tc_reduce_axes)
|
||||
wmma = UOp(Ops.WMMA, dtype=tc.dtype_out.vec(tc.elements_per_thread[2]), src=(
|
||||
UOp(Ops.CONTRACT, dtype=srcs[0].dtype.vec(tc.elements_per_thread[0]), src=(srcs[0],), arg=tc_upcast_axes[0]),
|
||||
UOp(Ops.CONTRACT, dtype=srcs[1].dtype.vec(tc.elements_per_thread[1]), src=(srcs[1],), arg=tc_upcast_axes[1]),
|
||||
UOp.const(tc.dtype_out.vec(tc.elements_per_thread[2]), 0.0)), arg=wmma_arg)
|
||||
tc_uop = UOp(Ops.UNROLL, tc.dtype_out, (wmma,), arg=tc_upcast_axes[2])
|
||||
|
||||
else: # for TC=3 MUL/SUM instead of WMMA
|
||||
tc_uop = UOp(Ops.REDUCE_AXIS, tc.dtype_out, ((srcs[0] * srcs[1]).cast(tc.dtype_out),), (Ops.ADD, tc_reduce_axes))
|
||||
|
||||
new_reduce_axes = tuple(i for i in axes if i not in tc_reduce_axes)
|
||||
return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_reduce_axes)) if new_reduce_axes else tc_uop
|
||||
return ret.replace(src=(tc_uop,), arg=(Ops.ADD, new_axes)) if (new_axes := tuple(i for i in axes if i not in tc_reduce_axes)) else tc_uop
|
||||
|
||||
ret = ret.replace(arg = (op.arg[0], axes))
|
||||
if self.group_for_reduces and grouped_axes:
|
||||
@@ -694,8 +685,9 @@ class Kernel:
|
||||
# the living definition of intermediate UOps
|
||||
|
||||
def _assert_valid_uop(uop:UOp, st:ShapeTracker, sts:dict[UOp, ShapeTracker]) -> None:
|
||||
if not uop.has_st or uop in sts: return
|
||||
if uop in sts: return
|
||||
# restore globals from the two stage reduce
|
||||
# this is because this LOAD has an implicit movement op
|
||||
if uop.op is Ops.LOAD and uop.src[0].op is Ops.DEFINE_LOCAL:
|
||||
_assert_valid_uop(local_reduce:=uop.src[2].src[2], uop.st_arg, sts)
|
||||
sts[uop] = sts[local_reduce]
|
||||
|
||||
@@ -32,8 +32,8 @@ tensor_uop_spec = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="mv", src=(UPat.var("x"),)), lambda mv,x:
|
||||
# naturally correct
|
||||
(isinstance(mv.arg, tuple) and mv.dtype == x.dtype) or
|
||||
# TODO: "make things that can't be images not images" can override the source dtype
|
||||
# is there a clean way to update its _mop children?
|
||||
# "make things that can't be images not images" can change the buffer dtype
|
||||
# this is fine as long as it's a realized buffer and base dtypes match.
|
||||
((isinstance(mv.dtype, ImageDType) or isinstance(x.dtype, ImageDType)) and x.dtype.base == mv.dtype.base and x.is_realized)),
|
||||
|
||||
# Tensor variable bindings
|
||||
@@ -124,20 +124,21 @@ class ScheduleContext:
|
||||
contiguous: dict[UOp, UOp] = field(default_factory=dict) # this maps roots to places they are made contiguous
|
||||
children: defaultdict[UOp, dict[UOp, None]] = field(default_factory=lambda: defaultdict(dict))
|
||||
|
||||
# TODO: delete this once CONST has a VIEW source
|
||||
# currently tensor uop is VIEW(DEVICE, CONST)
|
||||
# TODO: delete this once BIND has a VIEW source
|
||||
def is_constant(u:UOp): return u.op is Ops.CONST or (u.op is Ops.VIEW and len(u.src) == 2 and u.src[1].op is Ops.BIND)
|
||||
|
||||
def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
# wrap tensor uops around a VIEW(BUFFER, <uop>)
|
||||
# this BUFFER preserves a link back to the uop on the tensor after the scheduler rewrites it.
|
||||
def add_buffers(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
if (r:=cache.get(buf)) is not None: return r
|
||||
if buf.op is Ops.SINK: return UOp.sink(*[to_uop(x, ctx, cache) for x in buf.src])
|
||||
if buf.op is Ops.SINK: return UOp.sink(*[add_buffers(x, ctx, cache) for x in buf.src])
|
||||
# shapeless op is passthrough
|
||||
# realized is passthrough
|
||||
# constants are passthrough
|
||||
if buf.st is None or buf.base.is_realized or is_constant(buf.base): return buf
|
||||
# view is passthrough
|
||||
if buf is not buf.base:
|
||||
cache[buf] = ret = to_uop(buf.base, ctx, cache).view(buf.st)
|
||||
cache[buf] = ret = add_buffers(buf.base, ctx, cache).view(buf.st)
|
||||
return ret
|
||||
# make things that can't be images not images
|
||||
dtype = buf.dtype
|
||||
@@ -146,8 +147,9 @@ def to_uop(buf:UOp, ctx:ScheduleContext, cache:dict[UOp, UOp]) -> UOp:
|
||||
dtype = buf.dtype.base
|
||||
# meta ops and assign already have a target buffer, otherwise we create a new one
|
||||
buf_uop = buf.buf_uop if buf.op in {Ops.ASSIGN, Ops.VIEW} else UOp.new_buffer(buf.device, buf.size, dtype)
|
||||
if buf.op is Ops.VIEW: op = buf.src[1].replace(src=tuple(to_uop(x, ctx, cache) for x in buf.src[1].src))
|
||||
else: op = buf.replace(dtype=dtype.base, src=tuple(to_uop(x, ctx, cache) for x in buf.src))
|
||||
# TODO: we need to rethink meta ops having buffers at creation time
|
||||
if buf.op is Ops.VIEW: op = buf.src[1].replace(src=tuple(add_buffers(x, ctx, cache) for x in buf.src[1].src))
|
||||
else: op = buf.replace(dtype=dtype.base, src=tuple(add_buffers(x, ctx, cache) for x in buf.src))
|
||||
# track the underlying tensor uop for this op
|
||||
ctx.tensor_uops[buf_uop] = [buf]
|
||||
# (early) bufferize
|
||||
@@ -185,15 +187,14 @@ def elementwise_view_right(root:UOp) -> UOp|None:
|
||||
# push the swizzle from src to root
|
||||
output_swizzle = swizzles[0]
|
||||
new_input_st = ShapeTracker.from_shape(output_swizzle.base.shape)
|
||||
ret = root.replace(src=tuple(x if not x.has_st else x.src[0] if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
||||
# update the ASSIGN offset to match the new shape
|
||||
if ret.op is Ops.ASSIGN and ret.arg is not None: ret = ret.replace(arg=ret.arg+new_input_st,)
|
||||
ret = root.replace(src=tuple(x if x.st is None else x.base if x in swizzles else apply_swizzle(x.view(new_input_st)) for x in root.src))
|
||||
# NOTE: swizzle resolves once we hit STORE
|
||||
return ret if ret.op is Ops.STORE else ret.view(ShapeTracker.from_shape(output_swizzle.shape))
|
||||
|
||||
def merge_double_reduce(root:UOp, first_reduce:UOp) -> UOp:
|
||||
assert root.arg[0] == first_reduce.arg[0], "can't merge reduceops with different alu"
|
||||
assert not any(x.op is Ops.REDUCE_AXIS for x in first_reduce.src[0].toposort), "can't merge more than two reduceops at a time"
|
||||
return first_reduce.src[0].r(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg)
|
||||
return first_reduce.replace(arg=(first_reduce.arg[0], root.axis_arg+first_reduce.axis_arg))
|
||||
|
||||
# push VIEW to stores
|
||||
view_right = merge_views+PatternMatcher([
|
||||
@@ -233,7 +234,7 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
|
||||
|
||||
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
|
||||
ctx.bufs.append(x)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.arg[1]), (), len(ctx.bufs)-1)
|
||||
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
|
||||
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
|
||||
|
||||
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
|
||||
@@ -394,11 +395,9 @@ def group_realizes(ctx:ScheduleContext) -> list[list[UOp]]:
|
||||
|
||||
# **** Schedule creation and BFS toposort
|
||||
|
||||
# ** ops in the big graph can either be pre-realized or scheduled (fused/realized)
|
||||
|
||||
class UPatScheduled(UPat):
|
||||
def __init__(self, *args, **kwargs): super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"),
|
||||
UPat(*args, **{"name":"to_store",**kwargs})))
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(Ops.VIEW, name="base", src=(UPat(Ops.BUFFER, name="b"), UPat(*args, **{"name":"to_store",**kwargs})))
|
||||
|
||||
# ** this is schedule level const folding
|
||||
|
||||
@@ -412,7 +411,7 @@ def simplify_reduceop(reduce:UOp, x:UOp) -> UOp|None:
|
||||
case Ops.MUL: ret **= prshape
|
||||
case Ops.MAX: pass # NOTE: Ops.MAX is passthrough
|
||||
case _: return None
|
||||
return UOp.const(reduce.dtype, ret)
|
||||
return reduce.const_like(ret)
|
||||
|
||||
def found_contiguous(ctx:ScheduleContext, contig:UOp, base:UOp, b:UOp):
|
||||
if contig.src[0].op is Ops.VIEW and len(contig.src[0].src):
|
||||
@@ -433,7 +432,7 @@ ops_folding = symbolic_simple+PatternMatcher([
|
||||
(UPat(Ops.DETACH, name="detach"), lambda detach: detach.src[0]),
|
||||
# reduce of size 0 is the identity element
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.var("x"),)),
|
||||
lambda reduce,x:UOp.const(reduce.dtype, identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
lambda reduce,x: reduce.const_like(identity_element(reduce.arg[0], reduce.dtype)) if x.size == 0 and reduce.size != 0 else None),
|
||||
# reduce of const is collapsed (TODO: make this a generic rule for stride0)
|
||||
(UPat(Ops.REDUCE_AXIS, name="reduce", src=(UPat.cvar("x"),)), simplify_reduceop),
|
||||
# CONST doesn't need COPY
|
||||
@@ -496,14 +495,14 @@ def fold_img_cast(ctx:ScheduleContext, xb:UOp, view:UOp, b:UOp, to_cast:UOp, **k
|
||||
del ctx.realizes[b]
|
||||
return to_cast.view(unwrap(view.st))
|
||||
|
||||
def init_big_graph(ctx:ScheduleContext, sink:UOp) -> UOp|None:
|
||||
def sink_outputs(ctx:ScheduleContext, sink:UOp) -> UOp|None:
|
||||
new_src = tuple(x.base for x in sink.src if x.base.realized is None and not is_constant(x.base))
|
||||
for x in new_src: realize(ctx, x.buf_uop, x)
|
||||
return None if new_src == sink.src else UOp(Ops.NOOP) if len(new_src) == 0 else UOp.sink(*new_src)
|
||||
|
||||
do_realize = PatternMatcher([
|
||||
# always realize sinked ops
|
||||
(UPat(Ops.SINK, name="sink"), init_big_graph),
|
||||
(UPat(Ops.SINK, name="sink"), sink_outputs),
|
||||
# always realize meta ops
|
||||
(UPatScheduled({Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta}), realize),
|
||||
# realize before expand or unsafe pad ops
|
||||
@@ -521,7 +520,6 @@ def unbind_variable(ctx:ScheduleContext, bind:UOp, st:UOp):
|
||||
return UOp.const(bind.dtype, bind).valid(unwrap(st.st))
|
||||
|
||||
def load_realized(ctx:ScheduleContext, b:UOp, st:UOp):
|
||||
assert st.size == b.size and unwrap(st.st).contiguous, f"ShapeTracker of realized {b} BUFFER must match the BUFFER size {st}"
|
||||
# NOTE: if we're assigning to the BUFFER too, PRELOAD tells toposort to place this load before the ASSIGN
|
||||
return UOp(Ops.PRELOAD if b in ctx.assigns else Ops.LOAD, b.dtype.base, (b, unwrap(st.st).to_uop()))
|
||||
|
||||
@@ -548,6 +546,7 @@ def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
|
||||
for x in op.src:
|
||||
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
|
||||
# BUFFER_VIEW overrides the underlying buffer
|
||||
# TODO: this should be a shrink on the buffer
|
||||
if op.op is Ops.BUFFER_VIEW:
|
||||
buffers[buf_uop] = (x:=op.src[0]).base.buffer.view(view.size, view.dtype, unwrap(x.st).views[0].offset*x.dtype.itemsize)
|
||||
buf_uop.buffer.ref(1)
|
||||
@@ -556,14 +555,13 @@ create_ctx = PatternMatcher([(UPat(Ops.VIEW, name="view", src=(UPat(Ops.BUFFER,
|
||||
# **** movement ops
|
||||
|
||||
remove_movement_ops = PatternMatcher([
|
||||
(UPat(GroupOp.Movement, name="x"), lambda x: x.base.view(unwrap(x.st))),
|
||||
# NOTE: movement ops are always applied to base
|
||||
(UPat(GroupOp.Movement, name="mov", src=(UPat.any(UPat.var("x").view(), UPat.var("x")))), lambda x,mov: x.view(unwrap(mov.st))),
|
||||
# some masked views can collapse to 0, VIEW(x) -> CONST(VIEW)
|
||||
(UPat(Ops.VIEW, name="view"),
|
||||
lambda view: view.const_like(0) if (vm:=view.st.views[-1].mask) is not None and any((x[1]-x[0]) == 0 for x in vm) else None),
|
||||
# merge one src (unrealized) views
|
||||
# NOTE: we can't merge realized buffer views here, because the buffer is realized before the view
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat.var("x"),), name="v1")), name="v2"),
|
||||
lambda x,v1,v2: v1.replace(arg=v1.arg+v2.arg) if x.op is not Ops.BUFFER else None),
|
||||
# merge one src views.
|
||||
(UPat(Ops.VIEW, src=(UPat(Ops.VIEW, src=(UPat(),), name="v1")), name="v2"), lambda v1,v2: v1.replace(arg=v1.arg+v2.arg)),
|
||||
# merge unmasked const views
|
||||
(UPat(Ops.VIEW, name="view", src=(UPat(Ops.CONST, name="const", src=(UPat(Ops.VIEW, name="st"),) ),)),
|
||||
lambda st,const,view: const.replace(src=(st.replace(arg=st.st+view.st),)) if all(v.mask is None for v in (st.st+view.st).views) else None),
|
||||
@@ -573,7 +571,7 @@ remove_movement_ops = PatternMatcher([
|
||||
def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) -> tuple[list[ScheduleItem], dict[Variable, int]]:
|
||||
if not skip_check: type_verify(list(UOp.sink(*outs).toposort), extra_spec=tensor_uop_spec)
|
||||
# to_uop is removing (many) of the movement ops
|
||||
sink = to_uop(UOp.sink(*outs), ctx:=ScheduleContext(), cache={})
|
||||
sink = add_buffers(UOp.sink(*outs), ctx:=ScheduleContext(), cache={})
|
||||
# const folding and fusion
|
||||
sink = graph_rewrite(sink, remove_movement_ops+ops_folding+do_realize, ctx)
|
||||
sink = graph_rewrite(sink, merge_bufs, ctx)
|
||||
@@ -613,8 +611,3 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
|
||||
if len(schedule) != (groups:=len(prescheduled)): raise RuntimeError(f"cycle detected in graph, grouped {groups} but only scheduled {len(schedule)}")
|
||||
if DEBUG >= 1 and len(schedule) >= 10: print(f"scheduled {len(schedule)} kernels")
|
||||
return schedule, ctx.var_vals
|
||||
|
||||
def create_schedule(outs:list[UOp]) -> list[ScheduleItem]:
|
||||
schedule, var_vals = create_schedule_with_vars(outs)
|
||||
assert len(var_vals) == 0
|
||||
return schedule
|
||||
|
||||
@@ -108,7 +108,7 @@ def get_kernel_actions(lin:Kernel, include_0=True) -> dict[int, Kernel]:
|
||||
lin2 = lin.copy()
|
||||
try:
|
||||
lin2.apply_opt(a)
|
||||
up, lcl, tc_up = 1, 1, prod(tc.dims)//prod([x[1] for x in tc.threads]) if (tc:=lin2.tensor_core) else 1
|
||||
up, lcl, tc_up = 1, 1, prod(tc.dims)//tc.threads if (tc:=lin2.tensor_core) else 1
|
||||
for s,c in zip(lin2.full_shape, lin2.colors()):
|
||||
if c in {"magenta", "yellow"}: up *= s
|
||||
elif c in {"cyan", "green", "white"}: lcl *= s
|
||||
|
||||
@@ -220,7 +220,8 @@ class UOpMetaClass(type):
|
||||
ucache:dict[tuple, weakref.ReferenceType[UOp]] = {}
|
||||
def __call__(cls, op:Ops, dtype:DType=dtypes.void, src:tuple[UOp,...]=tuple(), arg:Any=None, _buffer=None):
|
||||
if (wret:=UOpMetaClass.ucache.get(key:=(op, dtype, src, arg), None)) is not None and (ret:=wret()) is not None: return ret
|
||||
UOpMetaClass.ucache[key] = weakref.ref(created:=super().__call__(*key))
|
||||
UOpMetaClass.ucache[key] = ref = weakref.ref(created:=super().__call__(*key))
|
||||
for s in src: s.children.add(ref)
|
||||
# NOTE: this will soon be set by Tensor once we remove function.py
|
||||
if (metadata:=_METADATA.get()) is not None: all_metadata[created] = metadata
|
||||
return created
|
||||
@@ -230,8 +231,6 @@ buffers:weakref.WeakKeyDictionary[UOp, Buffer] = weakref.WeakKeyDictionary() # t
|
||||
all_metadata:weakref.WeakKeyDictionary[UOp, Metadata] = weakref.WeakKeyDictionary()
|
||||
forced_realize:weakref.WeakSet[UOp] = weakref.WeakSet()
|
||||
|
||||
becomes_map: weakref.WeakKeyDictionary[UOp, UOp] = weakref.WeakKeyDictionary()
|
||||
|
||||
# NOTE: this should be frozen, but frozen is slower
|
||||
@dataclass(eq=False, slots=True)
|
||||
class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
@@ -239,9 +238,11 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
dtype:DType = dtypes.void
|
||||
src:tuple[UOp, ...] = tuple()
|
||||
arg:Any = None
|
||||
children:set[weakref.ref[UOp]] = field(default_factory=set)
|
||||
def __del__(self):
|
||||
if self.op is Ops.BUFFER and (buffer:=buffers.get(self)) is not None: buffer.ref(-1)
|
||||
if (k:=(self.op, self.dtype, self.src, self.arg)) in UOpMetaClass.ucache:
|
||||
if (ref:=UOpMetaClass.ucache.get(k:=(self.op, self.dtype, self.src, self.arg))) is not None:
|
||||
for s in self.src: s.children.discard(ref)
|
||||
del UOpMetaClass.ucache[k]
|
||||
def __reduce__(self):
|
||||
args = [self.op, self.dtype, self.src, self.arg]
|
||||
@@ -275,8 +276,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# *** uop shape stuff ***
|
||||
|
||||
@property
|
||||
def has_st(self) -> bool: return self.op not in {Ops.DEFINE_LOCAL, Ops.DEFINE_GLOBAL, Ops.BUFFER, Ops.CONST, Ops.DEFINE_VAR}
|
||||
@functools.cached_property
|
||||
def st(self) -> ShapeTracker|None:
|
||||
# these ops define a ShapeTracker from the arg
|
||||
@@ -297,7 +296,9 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
return ShapeTracker.from_shape(shape)
|
||||
@functools.cached_property
|
||||
def full_shape(self) -> tuple[sint, ...]:
|
||||
return self.shape if self.op is Ops.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
|
||||
if self.op is Ops.VIEW: return self.shape
|
||||
# TODO: this should check if st is None, it cannot because local reduce has implicit movement ops
|
||||
return tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.op not in {Ops.DEFINE_GLOBAL,Ops.DEFINE_LOCAL,Ops.DEFINE_VAR,Ops.CONST}]))
|
||||
@property
|
||||
def shape(self) -> tuple[sint, ...]: return unwrap(self.st).shape
|
||||
@property
|
||||
@@ -474,9 +475,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
# CAUTION: MUTABILITY!
|
||||
def become(self, u:UOp):
|
||||
becomes_map[self] = u
|
||||
#del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
#self.op, self.dtype, self.src, self.arg = u.op, u.dtype, u.src, u.arg
|
||||
del UOpMetaClass.ucache[(self.op, self.dtype, self.src, self.arg)]
|
||||
self.op, self.dtype, self.src, self.arg = u.op, u.dtype, u.src, u.arg
|
||||
|
||||
# *** uop movement ops ***
|
||||
|
||||
@@ -1317,7 +1317,7 @@ merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda
|
||||
view_left = merge_views+PatternMatcher([
|
||||
# VIEW before elementwise ops
|
||||
(UPat({*GroupOp.ALU, Ops.CAST, Ops.BITCAST, Ops.ASSIGN}, name="e").view(name="v"),
|
||||
lambda e,v: e.replace(src=tuple(s if not s.has_st else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
|
||||
lambda e,v: e.replace(src=tuple(s if s.st is None else s.view(v.st) if s is s.base else s.base.view(s.st+v.st) for s in e.src))),
|
||||
# early merge VIEW buffer ops
|
||||
(UPat(GroupOp.Buffer, name="b").view(name="v"), lambda b,v: b.replace(src=tuple((s.st+v.st).to_uop() if s.op is Ops.VIEW else s for s in b.src))),
|
||||
])
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
from typing import Optional, Callable
|
||||
import functools
|
||||
import functools, math
|
||||
from dataclasses import dataclass, field, replace
|
||||
from tinygrad.helpers import to_function_name, dedup, prod
|
||||
from tinygrad.ops import Ops, UOp, sym_infer, sint, Variable, ssimplify, GroupOp, PatternMatcher
|
||||
@@ -9,19 +9,25 @@ from tinygrad.dtype import DType
|
||||
@dataclass(frozen=True)
|
||||
class TensorCore: # D = A * B + C, A is (M x K), B is (K x N), C and D are (M x N)
|
||||
dims: tuple[int,int,int] # N, M, K
|
||||
threads: int # number of threads that construct the warp
|
||||
elements_per_thread: tuple[int, int, int] # elements per-thread to load/store from A/B/C
|
||||
dtype_in: DType # dtype for A and B
|
||||
dtype_out: DType # dtype for C and D
|
||||
threads: list[tuple[int,int]] # list of (TC dim,amt) that construct the warp thread structure
|
||||
reduce_axes: list[tuple[int,int]] # list of (TC dim,amt) that constructs the shape of the reduce dim
|
||||
@property
|
||||
def early_upcast_axes(self) -> list[tuple[int,int]]: # list of (TC dim,amt) that upcasts the threads remainders of dims [0,1]
|
||||
return [(d,self.dims[d]//sz) for d,sz in [(dim,prod(sz for d,sz in self.threads if d==dim)) for dim in range(2)] if self.dims[d]>sz]
|
||||
upcast_axes: tuple[list[tuple[int,int]], list[tuple[int,int]], list[tuple[int,int]]] # list of (TC dim,amt) that upcast A, B and C
|
||||
st1_pattern: Optional[tuple[tuple[tuple[int,int], ...], tuple[tuple[int,int], ...]]] = None # pattern to fix shapetracker for A
|
||||
st2_pattern: Optional[tuple[tuple[tuple[int,int], ...], tuple[tuple[int,int], ...]]] = None # pattern to fix shapetracker for B
|
||||
expanded_shape: Optional[tuple[int, ...]] = None
|
||||
opts_seq: tuple[str,str] = ("UP","LC") # upcast input, local the thread pattern
|
||||
opts: tuple[str, ...] # ordered tuple of "ux" or "lx" specifying kernel opts to perform. "ux" upcasts dim x and "lx" localizes dim x
|
||||
swizzle: tuple[Optional[tuple[tuple[int, ...], tuple[int, ...]]], Optional[tuple[tuple[int, ...], tuple[int, ...]]]] = (None, None)
|
||||
def get_reduce_axes(self): return [(i, 2) for i in range(int(math.log2(self.dims[2])))]
|
||||
def get_upcast_axes(self): return [opt for opt in self.opts if opt[0] == "u"]
|
||||
def get_local_axes(self): return [opt for opt in self.opts if opt[0] == "l"]
|
||||
def __str__(self): return "_".join(["WMMA"] + list(map(str, self.dims)) + [self.dtype_in.name, self.dtype_out.name])
|
||||
def __post_init__(self):
|
||||
local_axes, upcast_axes, reduce_axes = len(self.get_local_axes()), len(self.get_upcast_axes()), len(self.get_reduce_axes())
|
||||
assert self.dims[0] * self.dims[1] == 2**(local_axes + upcast_axes), (
|
||||
f"N({self.dims[0]}) x M({self.dims[1]}) != local({2**local_axes}) x upcast({2**upcast_axes}) with opts({self.opts})")
|
||||
assert 2**local_axes == self.threads, f"{self.threads} threads construct the warp but found {2**local_axes} in {self.opts}"
|
||||
assert 2**upcast_axes == self.elements_per_thread[2], (
|
||||
f"{self.elements_per_thread[2]} elements from C are processed per thread but found {2**upcast_axes} in {self.opts}")
|
||||
assert all(len(perm[0]) == local_axes and len(perm[1]) == reduce_axes + upcast_axes for perm in self.swizzle if perm), (
|
||||
f"swizzle perm should be of len (({local_axes})({reduce_axes + upcast_axes}))")
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Estimates:
|
||||
|
||||
@@ -175,8 +175,9 @@ class ClangRenderer(CStyleLanguage):
|
||||
CStyleLanguage.extra_matcher
|
||||
|
||||
if AMX:
|
||||
tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt)
|
||||
for dt, sz in [(dt, 64//dt.itemsize) for dt in [dtypes.float]]]
|
||||
tensor_cores = [TensorCore(dims=(sz,sz,1), threads=1, elements_per_thread=(sz,sz,sz*sz), dtype_in=dt, dtype_out=dt,
|
||||
swizzle=(None, ((),(4,5,6,7,0,1,2,3))), opts=("u0","u0","u0","u0","u1","u1","u1","u1"))
|
||||
for dt,sz in [(dt, 64 // dt.itemsize) for dt in [dtypes.float]]]
|
||||
|
||||
def render_vector_prefix(self, dt:DType) -> str:
|
||||
return f"typedef {self.render_dtype(dt.scalar())} {self.render_dtype(dt)} __attribute__((aligned({(sz:=dt.itemsize)}),vector_size({sz})));"
|
||||
@@ -226,12 +227,12 @@ class OpenCLRenderer(CStyleLanguage):
|
||||
|
||||
class IntelRenderer(OpenCLRenderer):
|
||||
device, suffix, kernel_prefix = "GPU", "INTEL", "__attribute__((intel_reqd_sub_group_size(8)))\n" + "__kernel "
|
||||
tensor_cores = [TensorCore(dims=(8,8,16),threads=[(0,8)],dtype_in=di,dtype_out=do,reduce_axes=[(0,16)],upcast_axes=([(0,16)],[(0,16)],[(1,8)]),
|
||||
st1_pattern=(((1,0),),((1,2),(1,1),(0,0))),expanded_shape=(8,2,8)) for di,do in [(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)]]
|
||||
tensor_cores = [TensorCore(dims=(8,8,16), threads=8, elements_per_thread=(16,16,8), dtype_in=dtypes.half, dtype_out=dtypes.float,
|
||||
opts=("l0","l0","l0","u1","u1","u1"), swizzle=(((4,5,6),(0,1,2,3,7,8,9)), ((0,1,2),(7,8,9,3,4,5,6))))]
|
||||
|
||||
string_rewrite = PatternMatcher([
|
||||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x[0]]})"),
|
||||
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x[0]]})"),
|
||||
(UPat(Ops.CAST, dtype=dtypes.bfloat16, src=(UPat.var('x', dtype=dtypes.float))), lambda ctx,x: f"intel_convert_bfloat16_as_ushort({ctx[x]})"),
|
||||
(UPat(Ops.CAST, dtype=dtypes.float, src=(UPat.var('x', dtype=dtypes.bfloat16))), lambda ctx,x: f"intel_convert_as_bfloat16_float({ctx[x]})"),
|
||||
]) + OpenCLRenderer.string_rewrite
|
||||
|
||||
def render_kernel(self, function_name, kernel, bufs, uops, prefix=None) -> str:
|
||||
@@ -245,10 +246,9 @@ class IntelRenderer(OpenCLRenderer):
|
||||
class MetalRenderer(CStyleLanguage):
|
||||
device = "METAL"
|
||||
shared_max = 32768
|
||||
tensor_cores = [TensorCore(dims=(8,8,8),threads=[(0,2),(1,4),(0,2),(1,2)],expanded_shape=(2,2,2,2),upcast_axes=([(1,2)],[(1,2)],[(1,2)]),
|
||||
st1_pattern=(((1,1),(0,1),(1,0),(0,3)),((0,0),(0,2),(1,3),(1,2))),st2_pattern=(((0,0),(1,1),(1,2),(0,2),(1,0)),((0,1),(0,3),(1,3))),
|
||||
dtype_in=di,dtype_out=do,reduce_axes=[(0,8)]) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),(dtypes.half,dtypes.half),
|
||||
(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
|
||||
tensor_cores = [TensorCore(dims=(8,8,8), threads=32, elements_per_thread=(2,2,2), dtype_in=di, dtype_out=do, opts=("u0","l0","l1","l1","l0","l1"),
|
||||
swizzle=(((6,1,2,7,4),(8,0,3,5)), ((0,5,6,3,7),(1,2,4,8)))) for di,do in [(dtypes.float,dtypes.float),(dtypes.half,dtypes.float),
|
||||
(dtypes.half,dtypes.half),(dtypes.bfloat16,dtypes.float),(dtypes.bfloat16,dtypes.bfloat16)]]
|
||||
def __init__(self): self.tensor_cores = MetalRenderer.tensor_cores if hasattr(os, 'uname') and os.uname().machine == "arm64" else []
|
||||
|
||||
# language options
|
||||
@@ -295,10 +295,9 @@ class CUDARenderer(CStyleLanguage):
|
||||
local_max = (1024, 1024, 64)
|
||||
shared_max = 49152
|
||||
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
|
||||
tensor_cores = [TensorCore(dims=(8,16,16), threads=[(0,2),(0,2),(1,2),(1,2),(1,2)], dtype_in=di, dtype_out=do, expanded_shape=(2,2,2,2,2,2),
|
||||
st1_pattern=(((1,1),(1,0),(0,2),(0,3),(0,4)),((1,3),(1,5),(1,2),(0,0),(0,1),(1,4))),
|
||||
st2_pattern=(((1,1),(1,0),(1,4),(0,0),(0,1)),((0,4),(0,2),(1,5),(0,3),(1,3),(1,2))), reduce_axes=[(0,8),(1,2)],
|
||||
upcast_axes=([(0,8)],[(2,2),(3,2)],[(3,2),(2,2)])) for di, do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
|
||||
tensor_cores = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do,
|
||||
opts=("u0","l0","l0","l1","l1","l1","u1"), swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8))))
|
||||
for di,do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
|
||||
def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
|
||||
def __reduce__(self): return self.__class__, (self.arch,)
|
||||
|
||||
@@ -365,9 +364,9 @@ class AMDRenderer(CStyleLanguage):
|
||||
device = "AMD"
|
||||
shared_max = 65536
|
||||
# https://gpuopen.com/learn/wmma_on_rdna3/
|
||||
tensor_cores = [TensorCore(dims=(16,16,16), threads=[(0,8),(0,2),(1,2)], dtype_in=di, dtype_out=do, reduce_axes=[(0,16)], opts_seq=("LC","UP"),
|
||||
upcast_axes = ([(0,16)],[(0,16)],[(1,8)]), st1_pattern=(((1,2),(0,2),(1,1),(0,1)),((1,0),(0,0))), expanded_shape=(16,2,4))
|
||||
for (di, do) in [(dtypes.half, dtypes.float), (dtypes.half, dtypes.half)]]
|
||||
tensor_cores = [TensorCore(dims=(16,16,16), threads=32, elements_per_thread=(16,16,8), dtype_in=di, dtype_out=do,
|
||||
opts=("l0","l0","l0","l0","l1","u1","u1","u1"), swizzle=(((4,9,10,11,0),(1,2,3,5,6,7,8)), ((0,1,2,3,4),(9,10,11,5,6,7,8))))
|
||||
for di,do in [(dtypes.half,dtypes.float),(dtypes.half,dtypes.half)]]
|
||||
|
||||
# language options
|
||||
ockl = [(f"__ockl_get_{name}", "unsigned int", "size_t", "const") for name in ["local_id", "group_id", "local_size"]]
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py
|
||||
from __future__ import annotations
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib, weakref
|
||||
import time, math, itertools, functools, struct, sys, inspect, pathlib, string, dataclasses, hashlib
|
||||
from contextlib import ContextDecorator
|
||||
from typing import List, Tuple, Callable, Optional, ClassVar, Type, Union, Sequence, cast, get_args, Literal, TYPE_CHECKING, SupportsIndex
|
||||
from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, least_upper_float, least_upper_dtype, sum_acc_dtype, to_dtype, truncate
|
||||
@@ -8,20 +8,12 @@ from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_u
|
||||
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN, unwrap
|
||||
from tinygrad.multi import MultiLazyBuffer
|
||||
from tinygrad.gradient import compute_gradient
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element, becomes_map, graph_rewrite_map, _substitute
|
||||
from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait, identity_element
|
||||
from tinygrad.device import Device, Buffer, BufferSpec
|
||||
from tinygrad.engine.realize import run_schedule
|
||||
from tinygrad.engine.memory import memory_planner
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
|
||||
# *** Tensors are containers for UOps ***
|
||||
|
||||
# NOTE: this has to be a list as UOps can map to more than one Tensor
|
||||
tensor_map: dict[UOp, list[weakref.ref[Tensor]]] = {}
|
||||
def update_tensor_map(t:Tensor):
|
||||
# TODO: multi
|
||||
tensor_map.setdefault(t.lazydata, []).append(weakref.ref(t))
|
||||
|
||||
# **** start with two base classes, Tensor and Function ****
|
||||
|
||||
class Function:
|
||||
@@ -41,7 +33,6 @@ class Function:
|
||||
ret = Tensor.__new__(Tensor)
|
||||
ret.lazydata, ret.requires_grad, ret.grad = ctx.forward(*[t.lazydata for t in x], **kwargs), ctx.requires_grad, None
|
||||
ret._ctx = ctx if ctx.requires_grad and not Tensor.no_grad else None # used by autograd engine
|
||||
update_tensor_map(ret)
|
||||
return ret
|
||||
|
||||
import tinygrad.function as F
|
||||
@@ -179,7 +170,6 @@ class Tensor(SimpleMathTrait):
|
||||
else:
|
||||
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
|
||||
self.lazydata = data
|
||||
update_tensor_map(self)
|
||||
|
||||
def requires_grad_(self, requires_grad=True) -> Tensor:
|
||||
self.requires_grad = requires_grad
|
||||
@@ -226,24 +216,7 @@ class Tensor(SimpleMathTrait):
|
||||
|
||||
NOTE: A Tensor can only be scheduled once.
|
||||
"""
|
||||
scheduled_uops = flatten([x.lazydata.lbs for x in (self,)+lst])
|
||||
schedule, var_vals = create_schedule_with_vars(scheduled_uops)
|
||||
# TODO: becomes_map should be returned from create_schedule_with_vars
|
||||
|
||||
# NOTE: we put tensor_map in here instead of scheduled_uops. we could be more selective here
|
||||
# all Tensors in a different graph will just rewrite to themselves
|
||||
rewrite_map = graph_rewrite_map(UOp.sink(*tensor_map), _substitute, becomes_map, bottom_up=True)
|
||||
becomes_map.clear()
|
||||
|
||||
# apply becomes_map
|
||||
# TODO: gc tensor_map
|
||||
for k,v in rewrite_map.items():
|
||||
if k is v: continue
|
||||
if (tt:=tensor_map.get(k)) is not None:
|
||||
for t in tt:
|
||||
if (rt:=t()) is not None:
|
||||
rt.lazydata = v
|
||||
update_tensor_map(rt)
|
||||
schedule, var_vals = create_schedule_with_vars(flatten([x.lazydata.lbs for x in (self,)+lst]))
|
||||
return memory_planner(schedule), var_vals
|
||||
|
||||
def schedule(self, *lst:Tensor) -> list[ScheduleItem]:
|
||||
@@ -262,10 +235,9 @@ class Tensor(SimpleMathTrait):
|
||||
Replaces the data of this tensor with the data of another tensor. Only the shape of the tensors must match.
|
||||
"""
|
||||
# used for replacing a Tensor with a new version of it (potentially with a different device and dtype)
|
||||
assert not x.requires_grad and getattr(self, '_ctx', None) is None
|
||||
assert getattr(self, '_ctx', None) is None
|
||||
assert self.shape == x.shape, f"replace shape mismatch {self.shape} != {x.shape}"
|
||||
self.lazydata = x.lazydata
|
||||
update_tensor_map(self)
|
||||
return self
|
||||
|
||||
def assign(self, x) -> Tensor:
|
||||
@@ -285,7 +257,6 @@ class Tensor(SimpleMathTrait):
|
||||
assert not x.requires_grad # self requires_grad is okay?
|
||||
if not self.lazydata.is_realized: return self.replace(x)
|
||||
self.lazydata = self.lazydata.assign(x.lazydata)
|
||||
update_tensor_map(self)
|
||||
return self
|
||||
|
||||
def detach(self) -> Tensor:
|
||||
@@ -384,8 +355,7 @@ class Tensor(SimpleMathTrait):
|
||||
real = self.to(device)
|
||||
# TODO: is this assign?
|
||||
if self.grad is not None and real.grad is not None: self.grad.lazydata = real.grad.lazydata
|
||||
self.lazydata = real.lazydata
|
||||
update_tensor_map(self)
|
||||
return self.replace(real)
|
||||
|
||||
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
|
||||
"""
|
||||
@@ -413,9 +383,7 @@ class Tensor(SimpleMathTrait):
|
||||
"""
|
||||
Shards the tensor across the given devices in place.
|
||||
"""
|
||||
self.lazydata = self.shard(devices, axis, splits).lazydata
|
||||
update_tensor_map(self)
|
||||
return self
|
||||
return self.replace(self.shard(devices, axis, splits))
|
||||
|
||||
@staticmethod
|
||||
def from_uop(y:UOp, **kwargs) -> Tensor:
|
||||
|
||||
Reference in New Issue
Block a user