Merge branch 'master' into retinanet_mlperf

Francis Lata
2025-01-15 02:45:21 -08:00
35 changed files with 304 additions and 264 deletions

View File

@@ -91,7 +91,7 @@ jobs:
- name: Run GPT2 w HALF
run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- name: Train MNIST
run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
- name: Run 10 CIFAR training steps
@@ -202,12 +202,12 @@ jobs:
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
- name: Run LLaMA-3 8B BEAM
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Run LLaMA-2 70B
run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
# - name: Run LLaMA-3 8B on 6 GPUs
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
# - name: Run LLaMA-2 70B
# run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
run: time NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
- name: Run GPT2
@@ -217,7 +217,7 @@ jobs:
- name: Run GPT2 w HALF
run: NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (NVIDIA)
@@ -372,7 +372,7 @@ jobs:
#- name: Fuzz Padded Tensor Core GEMM
# run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
- name: Remove amdgpu
run: sudo rmmod amdgpu
run: sleep 5 && sudo rmmod amdgpu # sleep a bit to let the driver unload the prev pid.
- name: Run Stable Diffusion
run: AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
- name: Run SDXL
@@ -389,14 +389,14 @@ jobs:
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
- name: Run LLaMA-3 8B BEAM
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
- name: Run LLaMA-3 8B on 4 GPUs
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
- name: Run LLaMA-3 8B on 6 GPUs
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Run LLaMA-3 8B on 4 GPUs with BEAM
run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
# - name: Run LLaMA-3 8B on 6 GPUs
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
- name: Restore amdgpu
run: sudo modprobe amdgpu
- name: Run LLaMA-2 70B
run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
# - name: Run LLaMA-2 70B
# run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
- name: Run Mixtral 8x7B
run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
- name: Run GPT2
@@ -406,7 +406,7 @@ jobs:
- name: Run GPT2 w HALF
run: AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
- name: Run GPT2 w HALF/BEAM
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
- uses: actions/upload-artifact@v4
with:
name: Speed (AMD)

View File

@@ -166,6 +166,7 @@ jobs:
- name: Test emulated CUDA tensor cores
run: |
DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
DEBUG=2 EMULATE_CUDA_SM75=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
PYTHONPATH="." DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
- name: Test emulated INTEL OpenCL tensor cores
run: DEBUG=2 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
@@ -296,7 +297,7 @@ jobs:
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot model kernel count and gate usage
run: |
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
- if: ${{ matrix.task == 'optimage' }}
name: Test openpilot alt model correctness (float32)
run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx

View File

@@ -2,7 +2,7 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -2,7 +2,7 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -3,7 +3,7 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_green"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -2,7 +2,7 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -2,7 +2,7 @@
export PYTHONPATH="."
export MODEL="bert"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -3,7 +3,7 @@
export PYTHONPATH="."
export MODEL="bert"
export SUBMISSION_PLATFORM="tinybox_red"
export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36
export BEAM=3
export IGNORE_JIT_FIRST_BEAM=1

View File

@@ -9,6 +9,7 @@ from onnx2torch import convert
from extra.onnx import get_run_onnx
from tinygrad.helpers import OSX, DEBUG, fetch
from tinygrad import Tensor, Device
from tinygrad.device import CompileError
MODELS = {
"resnet50": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",
@@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
del inputs, tinygrad_model, tinygrad_jitted_model
except RuntimeError as e:
except CompileError as e:
# TODO: we don't run the dm model on METAL for now
if Device.DEFAULT == "METAL":
assert "buffer count limit" in str(e)
assert "no 'buffer' resource location available" in str(e)
return
else: raise e
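For context, the benchmark now catches tinygrad's CompileError instead of a generic RuntimeError. A minimal sketch of the pattern; the helper below is a hypothetical stand-in for the jitted model call, and the real code additionally gates the assert on Device.DEFAULT == "METAL":

from tinygrad.device import CompileError

def compile_dm_model():
    # hypothetical stand-in for building/jitting the dm model on METAL
    raise CompileError("no 'buffer' resource location available")

try:
    compile_dm_model()
except CompileError as e:
    # the benchmark skips this model instead of failing when METAL hits its buffer-count limit
    assert "no 'buffer' resource location available" in str(e)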

View File

@@ -2,7 +2,7 @@
# compare kernels created by HEAD against master
from collections import defaultdict
import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
from typing import Callable, List, Set, Tuple, Union, cast
from typing import Callable, List, Tuple, Union, cast
from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
from tinygrad.engine.schedule import ScheduleContext, schedule_uop
from tinygrad.codegen.kernel import Kernel, Opt
@@ -30,9 +30,9 @@ class ProcessReplayWarning(Warning): pass
# *** recreators
def recreate_sched(ast:UOp, assigns:Set[UOp]) -> UOp:
def recreate_sched(ast:UOp) -> UOp:
# NOTE: process replay isn't meant to actually schedule anything
return schedule_uop(ast, ScheduleContext(assigns=assigns, tensor_uops=defaultdict(list))).ast
return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast
def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
k = Kernel(ast, opts=opts)
for opt in applied_opts: k.apply_opt(opt)

View File

@@ -33,7 +33,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB - {mem_used/1e9:.2} GB used"
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels, it used {kernels_used}"
if all_jitted:
assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501

View File

@@ -66,7 +66,8 @@ class TestArange(unittest.TestCase):
return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])
class TestIndexing(unittest.TestCase):
@unittest.expectedFailure
# update: passing after CAST_BEFORE_VIEW=1 deletion
# @unittest.expectedFailure
def test_arange_2_reduce(self):
needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
needle[1337] = 1

View File

@@ -132,11 +132,12 @@ class TestMovedConstFolding(unittest.TestCase):
def test_cast_padded(self):
# NOTE: this is folded due to CAST_BEFORE_VIEW
# update: CAST_BEFORE_VIEW=1 is no longer supported
if is_dtype_supported(dtypes.int16):
_check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
_check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
if is_dtype_supported(dtypes.uint16):
_check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
_check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
# not folded
if is_dtype_supported(dtypes.int64):

View File

@@ -781,7 +781,8 @@ class TestAutoCastType(unittest.TestCase):
if DEBUG >= 2:
print(f"testing {default_dtype=}, {dtype=}")
a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
b = (a * 5).sum()
# NOTE: this is broken without default_dtype because of CAST_BEFORE_VIEW
b = (a * 5).sum(acc_dtype=default_dtype)
b.backward() # if there is dtype mismatch, lazy should assert
assert a.grad.dtype == a.dtype
np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])

View File

@@ -120,7 +120,7 @@ class TestImageDType(unittest.TestCase):
loss = x.image_dot(w1).image_dot(w2).float().max()
loss.backward()
sched = unwrap(w1.grad).schedule()
self.assertEqual(len(sched), 10)
self.assertEqual(len(sched), 9)
for s,ei in zip(sched, lower_schedule(sched[:])):
ei.run()
if s.outputs[0].dtype == dtypes.float:

View File

@@ -12,7 +12,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.shape.view import View
# from tinygrad.ops import Variable
from tinygrad.tensor import Tensor, _to_np_dtype
from tinygrad.engine.schedule import BUF_LIMIT
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
from tinygrad.dtype import DType, dtypes
@@ -1701,7 +1700,7 @@ class TestHandCodedOpts(unittest.TestCase):
# float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
assert k.upcasted == 1 and k.full_shape[-1] == 7
@unittest.skipIf((buf_max:=BUF_LIMIT.get(Device.DEFAULT)) is not None and buf_max <= 37, "this test uses too many bufs")
@unittest.skipIf(Device.DEFAULT == "METAL", "METAL can only run kernels with up to 32 buffers")
def test_masked_upcast_wino(self):
monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])

View File

@@ -43,6 +43,12 @@ class TestMultiTensor(unittest.TestCase):
assert lb.shape == (256,)
(X + X).realize()
def test_gradient(self):
X = Tensor.ones(256).contiguous().realize()
X.to_(devices_2)
grad = X.sum().gradient(X)[0]
grad.realize()
def test_shard(self):
X = Tensor.ones(256).contiguous().realize()
X.shard_(devices_2, 0)
@@ -75,7 +81,7 @@ class TestMultiTensor(unittest.TestCase):
ei.run()
assert names[-2] == names[-1], "function was relinearized"
@unittest.skip("this doesn't fold because from_sharded calls contiguous on all lbs")
@unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
def test_sharded_memory(self):
# Buffer may be stuck in track_cross_buffer
for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
@@ -157,6 +163,7 @@ class TestMultiTensor(unittest.TestCase):
strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
N = N * len(devices)
X = Tensor.rand(N*N).reshape(N, N).mul(sign)
n = X.numpy()
X.shard_(devices, shard_axis)
@@ -438,6 +445,7 @@ class TestMultiTensor(unittest.TestCase):
assert isinstance(jf.jit_cache[4].prg, BufferCopy)
assert isinstance(jf.jit_cache[5].prg, graph_d1)
@unittest.skip("no longer supports uneven shard")
def test_uneven_shard(self):
for N in range(1, 6):
X = Tensor.rand(4, 1, 257).contiguous().realize()
@@ -450,6 +458,7 @@ class TestMultiTensor(unittest.TestCase):
np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))
@unittest.skip("no longer supports uneven shard")
def test_uneven_multiple_zeros(self):
for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
for N in (1, 2, 3, 4):
@@ -458,29 +467,28 @@ class TestMultiTensor(unittest.TestCase):
X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
np.testing.assert_equal(X.numpy(), data)
@unittest.skip("no longer supports uneven shard")
def test_uneven_shard_with_empty(self):
N = 4
X = Tensor.rand(16, 1, 17).contiguous().realize()
X = Tensor.rand(16, 1, 3).contiguous().realize()
np_x = X.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
# test empty shard
np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).numpy(), np_x)
np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
# test reshape with empty shard
np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).reshape(8, 1, 34).numpy(), np_x.reshape(8, 1, 34))
# test elementwise with empty shard
np.testing.assert_equal((X.shard(devices, 0, (2, 2, 12, 0)) + X.shard(devices, 0, (0, 0, 1, 15))).numpy(), np_x + np_x)
np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))
@unittest.skip("no longer supports uneven shard")
def test_multiple_uneven_shard(self):
N = 4
X = Tensor.rand(4, 1, 257).contiguous().realize()
Y = Tensor.rand(4, 1, 257).contiguous().realize()
np_x, np_y = X.numpy(), Y.numpy()
devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
X.shard_(devices, 2, (2, 38, 47, 170))
Y.shard_(devices, 2, (34, 53, 51, 119))
X.shard_(devices, 2)
Y.shard_(devices, 2)
np.testing.assert_equal(X.numpy(), np_x)
np.testing.assert_equal(Y.numpy(), np_y)
np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
@@ -534,6 +542,7 @@ class TestMultiTensor(unittest.TestCase):
with self.assertRaises((AssertionError, ValueError)):
t0.reshape((26*15,7))
@unittest.skip("no longer supports uneven shard")
def test_reshape_on_axis_uneven(self):
def reshape_helper(t0, t, t_axis):
np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
@@ -605,8 +614,9 @@ class TestMultiTensor(unittest.TestCase):
self.assertEqual(t.dtype, t2.dtype)
self.assertEqual(t.lazydata.axis, t2.lazydata.axis)
@unittest.skip("no longer supports uneven shard")
def test_rand_like_uneven_shard(self):
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
t2 = Tensor.rand_like(t)
self.assertEqual(t.shape, t2.shape)
self.assertEqual(t.device, t2.device)
@@ -655,9 +665,10 @@ class TestMultiTensor(unittest.TestCase):
assert set(unique) == {0, 2}, unique
assert 200 < counts[0] < 312, counts[0]
@unittest.skip("no longer supports uneven shard")
def test_dropout_on_uneven_shard_axis(self):
with Tensor.train():
X = Tensor.ones(256).shard(devices_3, axis=0, splits=(100, 50, 106))
X = Tensor.ones(256).shard(devices_3, axis=0)
output = X.dropout(0.5).numpy()
unique, counts = np.unique(output, return_counts=True)
assert set(unique) == {0, 2}, unique
@@ -689,6 +700,7 @@ class TestMultiTensor(unittest.TestCase):
assert ast.src[2].src[0].op is Ops.LOAD
assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3
@unittest.skip("TODO: this requires forced_realize to be deleted.")
def test_shard_memory(self):
devices = (d0, d1, d2, d3)
t = Tensor.zeros(16, 16).contiguous()
@@ -816,6 +828,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)
@unittest.skip("no longer supports uneven shard")
def test_uneven(self):
t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
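For reference, explicit uneven splits are gone from this file; the supported form is an even shard along one axis, roughly (using the same helpers the remaining tests use):

from tinygrad import Tensor, Device

devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(2))
X = Tensor.ones(256).contiguous().realize()
X.shard_(devices, 0)    # the sharded axis must now divide evenly across devices
(X + X).realize()       # elementwise ops run per-shard, no gather needed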

View File

@@ -551,7 +551,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model(self):
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -572,7 +572,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_dict(self):
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
state_dict = {
@@ -589,7 +589,7 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_same_axis(self):
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -610,7 +610,8 @@ class TestNN(unittest.TestCase):
@unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
def test_load_state_dict_sharded_model_dict_different_axis(self):
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
devices5 = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4", f"{Device.DEFAULT}:5")
layer = Conv2d(3, 5, kernel_size=3)
layer.weight.shard_(devices, 3)
@@ -619,14 +620,14 @@ class TestNN(unittest.TestCase):
# different shard axis
state_dict = {
'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None),
'bias': Tensor.randn(5).shard(devices, 0),
'bias': Tensor.randn(5).shard(devices5, 0),
}
load_state_dict(layer, state_dict)
# NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
self.assertEqual(layer.weight.device, devices)
self.assertEqual(layer.weight.lazydata.axis, None)
self.assertEqual(layer.bias.device, devices)
self.assertEqual(layer.bias.device, devices5)
self.assertEqual(layer.bias.lazydata.axis, 0)
np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())

View File

@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic_simple, merge_views
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
from tinygrad.codegen.kernel import verify_ast
from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
from extra.models.llama import precompute_freqs_cis
@@ -585,6 +585,15 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(out, 2))
np.testing.assert_allclose(out.numpy(), np.ones((64,64)))
def test_example_matmul_same(self):
x = Tensor.eye(64, requires_grad=True)
z = x.matmul(x).sum()
z.backward()
out = x.grad.contiguous()
run_schedule(check_schedule(out, 2))
# NOTE: the gradient flows twice
np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
def test_contiguous_add(self):
x = Tensor.empty(32)
y = Tensor.empty(32)
@@ -1363,8 +1372,9 @@ class TestSchedule(unittest.TestCase):
@unittest.expectedFailure
def test_conv2d_fused_half(self): _test_conv2d(5, dtype=dtypes.half)
@unittest.skip("splitting kernels exceeding device buffer count is not yet supported")
def _test_buf_cnt(self, cnt:int, allowed:int):
if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
#if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
s = alu.schedule()
assert len(s) == allowed
@@ -1435,6 +1445,7 @@ class TestSchedule(unittest.TestCase):
def test_late_fusion_post_expand(self):
self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_view(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1445,6 +1456,7 @@ class TestSchedule(unittest.TestCase):
self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])
# NOTE: we might want to reconsider pushing this cast before the shrink
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_after_shrink(self):
a = Tensor.arange(4).reshape(1, 4)
casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
@@ -1454,6 +1466,7 @@ class TestSchedule(unittest.TestCase):
self.assertEqual(realized_view.lazydata.base.realized.size, 2)
self.assertListEqual(realized_view.tolist(), [[0, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_const_view(self):
a = Tensor.ones((4, 4), dtype=dtypes.float32)
casted_view = a.cast(dtypes.int32)
@@ -1463,6 +1476,7 @@ class TestSchedule(unittest.TestCase):
run_schedule(check_schedule(realized_const_view, 1))
self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])
@unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
def test_cast_padded_const(self):
a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
casted_view = a.cast(dtypes.float32)
@@ -1565,7 +1579,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous(self):
@@ -1573,7 +1587,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_child(self):
@@ -1581,7 +1595,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = Tensor.arange(10)+1
out = (x + a[2]).sum()
self.check_schedule(out, 1)
self.check_schedule(out, 2)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_index_contiguous_child(self):
@@ -1589,7 +1603,7 @@ class TestIndexing(unittest.TestCase):
x = Tensor.randn(5, 2).realize()
a = (Tensor.arange(10)+1).contiguous()
out = (x + a[2]).sum()
self.check_schedule(out, 2)
self.check_schedule(out, 3)
np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)
def test_arange_childless_base(self):

View File

@@ -464,7 +464,7 @@ class TestTinygrad(unittest.TestCase):
def test_repr_with_grad(self):
a = Tensor([1], requires_grad=True)
b = Tensor([1])
c = (a + b).mean().backward()
c = (a + b).sum().backward()
print(a)
print(c)
@@ -646,11 +646,20 @@ class TestZeroShapeTensor(unittest.TestCase):
def test_clone(self):
a = Tensor.rand(16, 16).realize()
self.assertIsNot(a.lazydata, a.clone().lazydata)
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
a = Tensor.rand(16, 16).mul(5.0).add(5.0)
self.assertIsNot(a.lazydata, a.clone().lazydata)
np.testing.assert_allclose(a.numpy(), a.clone().numpy())
def test_clone_with_shrink(self):
a = Tensor.empty(16, 16)
self.assertIsNot(a.lazydata, a.clone().lazydata)
b = a.shrink(((2, 10), None))
self.assertIsNot(b.lazydata, b.clone().lazydata)
def test_clone_with_grad(self):
a = Tensor.rand(16, 16, requires_grad=True)
a.mul(5.0).add(5.0).mean().backward()

View File

@@ -3,9 +3,9 @@ import numpy as np
import unittest
from tinygrad import Tensor, Device, dtypes
from tinygrad.engine.realize import run_schedule
from tinygrad.ops import Ops, UOp
from tinygrad.ops import Ops, UOp, UPat
class TestLazyBuffer(unittest.TestCase):
class TestTensorUOp(unittest.TestCase):
def test_fromcpu_shape_tracker(self):
def helper(a: np.ndarray):
print(a.shape, a.strides, a.flags.c_contiguous)
@@ -68,7 +68,7 @@ class TestLazyBuffer(unittest.TestCase):
assert lb.const_like(1).const_arg == 1.0
assert type(lb.const_like(1).const_arg) is float
def test_forced_realized_alu(self):
def test_contiguous_alu(self):
a = Tensor.randn(2, 2).realize()
b = Tensor.randn(2, 2).realize()
add = (a+b).contiguous()
@@ -84,13 +84,14 @@ class TestLazyBuffer(unittest.TestCase):
sched = empty.schedule()
self.assertEqual(len(sched), 0)
reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.REDUCE_AXIS)))))
class TestReduceOp(unittest.TestCase):
def test_no_split_reduce_kernel(self):
a = Tensor.rand(4, 4).realize()
a = a.sum()
sched = a.schedule()
assert len(sched) == 1
self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
assert reduce_kernel.match(sched[0].ast, {})
def test_split_reduce_kernel_dim0(self):
a = Tensor.rand(256, 255).realize()
@@ -98,7 +99,7 @@ class TestReduceOp(unittest.TestCase):
sched = a.schedule()
assert len(sched) == 2
for s in sched:
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
assert reduce_kernel.match(s.ast, {})
def test_split_reduce_kernel_dim1(self):
a = Tensor.rand(255, 256).realize()
@@ -106,7 +107,7 @@ class TestReduceOp(unittest.TestCase):
sched = a.schedule()
assert len(sched) == 2
for s in sched:
self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
assert reduce_kernel.match(s.ast, {})
if __name__ == "__main__":
unittest.main()

View File

@@ -81,6 +81,8 @@ class TestTiny(unittest.TestCase):
# *** a model ***
# TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
@unittest.skipIf(IMAGE>0, "failing because of make things that can't be images not images")
def test_mnist_model(self):
layers = [
nn.Conv2d(1, 32, 5), Tensor.relu,

View File

@@ -93,6 +93,12 @@ class TestTensorGradient(unittest.TestCase):
dx = z.gradient(x, gradient=dz)[0]
self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])
def test_cast_before_view(self):
x = Tensor([1.0, 1, 1, 1])
x_reshaped = x.reshape(2,2)
x_casted = x_reshaped.cast(dtypes.float16)
x_casted.mean().gradient(x_reshaped)
class TestRealizeMeansRealize(unittest.TestCase):
def test_randn_realizes(self):
x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
@@ -104,5 +110,11 @@ class TestRealizeMeansRealize(unittest.TestCase):
print(x.lazydata)
self.assertEqual(x.lazydata.op, Ops.VIEW)
# NOTE: even though it doesn't realize, this seems fine
def test_uniform_gradient(self):
x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
y = x * 2
y.sum().gradient(x)[0].realize()
if __name__ == '__main__':
unittest.main()

View File

@@ -1,39 +1,6 @@
import unittest
from tinygrad import dtypes, Tensor
from tinygrad import dtypes
from tinygrad.ops import UOp, symbolic, graph_rewrite_map, _substitute
from test.unit.test_tensor_uop_representation import is_pattern, realized_pattern, is_pattern_uop
class TestTensorMutates(unittest.TestCase):
def test_mutate_add(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
ret = a+b
pa = a.lazydata
pb = b.lazydata
pr = ret.lazydata
ret.schedule()
self.assertIsNot(pa, a.lazydata)
self.assertIsNot(pb, b.lazydata)
self.assertIsNot(pr, ret.lazydata)
for t in [a,b,ret]: is_pattern(t, realized_pattern)
def test_reshape_is_same_parent(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = a+b
d = (a+b).reshape(3,1)
d.realize()
is_pattern_uop(d.lazydata.base, realized_pattern)
is_pattern_uop(c.lazydata.base, realized_pattern)
def test_reshape_is_same_child(self):
a = Tensor([1,2,3])
b = Tensor([4,5,6])
c = a+b
d = (a+b).reshape(3,1)
c.realize()
is_pattern_uop(c.lazydata.base, realized_pattern)
is_pattern_uop(d.lazydata.base, realized_pattern)
class TestRewriteMap(unittest.TestCase):
def test_substitute(self):

View File

@@ -19,7 +19,10 @@ class TestTensorMutates(unittest.TestCase):
self.assertIsNot(pa, a.lazydata)
self.assertIsNot(pb, b.lazydata)
self.assertIsNot(pr, ret.lazydata)
for t in [a,b,ret]: is_pattern(t, realized_pattern)
# NOTE: this becomes a VIEW(VIEW(BUFFER)) because UOp.view no longer instantly folds contiguous VIEW of the same shape
# this is fine because realized exists on the base.
# TODO: we can make this always be a VIEW(BUFFER) once BUFFER has a ShapeTracker of shape=(N,)
for t in [a,b,ret]: is_pattern_uop(t.lazydata.base, realized_pattern)
def test_reshape_is_same_parent(self):
a = Tensor([1,2,3])
@@ -43,14 +46,14 @@ class TestTensorUopRepresentation(unittest.TestCase):
def test_realized(self):
a = Tensor([1.,2,3]).realize()
print(a.lazydata)
is_pattern(a, realized_pattern)
is_pattern_uop(a.lazydata.base, realized_pattern)
def test_add_realized(self):
a = Tensor([1.,2,3]).realize()
b = Tensor([4.,5,6]).realize()
c = a+b
print(c.lazydata)
is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))
def test_const_pattern(self):
a = Tensor(1)
@@ -107,7 +110,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
a = Tensor([1.,2,3]).realize()
c = a.to("TEST") # NOTE: this isn't checked
print(c.lazydata)
is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
# TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops
is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))
if __name__ == '__main__':
unittest.main()
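For orientation on the pattern changes in this file, a realized Tensor's lazydata is now a VIEW wrapping the realized BUFFER rather than the bare BUFFER (the NOTE above spells out why); a small sketch under that assumption:

from tinygrad import Tensor
from tinygrad.ops import Ops

a = Tensor([1., 2, 3]).realize()
print(a.lazydata.op)        # Ops.VIEW: the contiguous same-shape VIEW is no longer folded away
print(a.lazydata.base.op)   # Ops.BUFFER, which is what realized_pattern matches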

View File

@@ -13,8 +13,6 @@ from tinygrad.device import Buffer
# creation can recurse a lot
sys.setrecursionlimit(10000)
BUF_LIMIT = {"METAL":32}
# **** big graph spec
tensor_uop_spec = PatternMatcher([
@@ -185,14 +183,9 @@ view_right = merge_views+PatternMatcher([
@dataclass(frozen=True)
class ScheduleItemContext:
ops_metadata: dict[UOp, Metadata]
assigns: set[UOp]
var_vals: dict[Variable, int]
sinked: dict[UOp, UOp]
sts: set[ShapeTracker] = field(default_factory=set)
bufs: list[UOp] = field(default_factory=list)
metadata: set[Metadata] = field(default_factory=set)
assign_adj: dict[UOp, list[UOp]] = field(default_factory=dict)
def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -204,53 +197,49 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
ctx.bufs.append(x)
return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
(adj_loads:=ctx.assign_adj.setdefault(b, [])).append(x)
if not all_same([x.op for x in adj_loads]): raise RuntimeError(f"Detected cycle when fusing {adj_loads}. Can only fuse PRELOAD or LOAD of {b}")
return x.replace(op=Ops.LOAD)
check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])
to_si = PatternMatcher([
# BUFFER -> DEFINE_GLOBAL
(UPat(Ops.BUFFER, name="x"), _append_buf),
# simplify and unbind the final VIEWs
(UPat(Ops.VIEW, name="x"), _append_st_vars),
# don't need SINK on COPY or BUFFER_VIEW
(UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
# don't need contiguous or assign anymore
(UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
(UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
# PRELOAD becomes LOAD
(UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
])
add_metadata = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: None if (m:=ctx.ops_metadata.get(x)) is None else ctx.metadata.add(m)),])
add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
if b in ctx.assigns else None)])
# late folding for multi output kernels
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
# LOAD(BUFFER) -> the STORE value if we're doing the STORE in the same kernel
multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])
def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
# create the ast context
si_ctx = ScheduleItemContext(ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
# do movement ops
sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
# convert to AST
sink = graph_rewrite(graph_rewrite(sink, to_si+check_preload if len(si_ctx.assigns) != 0 else to_si, si_ctx), append_bufs, si_ctx)
# assert buffer count limit
if (limit:=BUF_LIMIT.get(device:=si_ctx.bufs[0].device)) is not None and len(si_ctx.bufs) >= limit:
if DEBUG >= 3: print(sink)
raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")
# we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
for ubuf,ops in si_ctx.assign_adj.items():
if si_ctx.sinked.get(ubuf) is not None and not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# remove movement ops + substitute LOAD of fused STORE with just the value
sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
# remove extra uops from SINK + substitute BUFFER with DEFINE_GLOBAL
ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
# deal with ASSIGN
assign_preloads: list[UOp] = []
if len(ctx.assigns) != 0:
for x in list(sink.toposort)[::-1]:
# we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
# PRELOAD tells the toposort this kernel should run before ASSIGN
if x.op is Ops.PRELOAD:
assign_preloads.append(x.buf_uop)
# if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
# if it has a single view and it's equal when you shrink a contig, it's fine
if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+colored(" - a += a.T\n", "red")+colored(" + a += a.T.contiguous()", "green"))
# capture process replay
if CAPTURE_PROCESS_REPLAY:
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))
PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
if CAPTURE_PROCESS_REPLAY:
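The contiguity check kept above is the one users hit with in-place updates; a rough sketch of the case it guards, mirroring the hint text in the error message:

from tinygrad import Tensor

a = Tensor.rand(4, 4).contiguous().realize()
# a += a.T             # would fail at realize: "self operand of augmented assign must be contiguous"
a += a.T.contiguous()  # the fix suggested by the error message
a.realize()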
@@ -511,7 +500,7 @@ break_sched = PatternMatcher([
def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
ctx.allbufs[buf_uop] = view
if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
for x in op.src:
for x in op.base.src:
if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
# BUFFER_VIEW overrides the underlying buffer
# TODO: this should be a shrink on the buffer
@@ -551,12 +540,13 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
# preschedule realize groups
prescheduled: list[ScheduleItem] = []
for store_uops in store_groups:
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# do BFS
if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
# can only schedule once
for buf_uop in store_uops:
for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
# add kernel children
schedule_targets = {out:si for si in prescheduled for out in si.outputs}
graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
@@ -571,6 +561,8 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
for x in scheduled_parents:
graph[x].append(si)
in_degree[si] += 1
# do BFS
queue = deque(si for si in prescheduled if in_degree[si] == 0)
schedule: list[ScheduleItem] = []
while queue:

View File

@@ -1,4 +1,4 @@
from typing import cast
from typing import cast, Iterator
import math, functools
from tinygrad.dtype import dtypes, sum_acc_dtype
from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
@@ -39,13 +39,15 @@ pm_gradient = PatternMatcher([
# there's no gradient for...is this ASSIGN?
(UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
# also no gradient for bitcast
(UPat(Ops.BITCAST), lambda ctx: (None,)),
])
# copied from tensor.py, get relevant toposort of gradients
def _deepwalk(root:UOp, targets:list[UOp]):
def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
@functools.lru_cache(None)
def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
def _walk(node:UOp, visited:set[UOp]):
def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
visited.add(node)
if node.op is Ops.DETACH: return
if is_in_target_path(node):
@@ -54,7 +56,7 @@ def _deepwalk(root:UOp, targets:list[UOp]):
yield node
return list(_walk(root, set()))
def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
grads = {root: root_grad}
for t0 in reversed(_deepwalk(root, targets)):
if t0 not in grads: continue
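For orientation, compute_gradient backs the Tensor.gradient API exercised by the tests elsewhere in this diff; a small usage sketch (the expected values follow from d(x*x)/dx = 2x, they are not taken from this diff):

from tinygrad import Tensor

x = Tensor([1.0, 2.0, 3.0], requires_grad=True)
z = (x * x).sum()
dx = z.gradient(x)[0]   # compute_gradient walks the UOp graph from z back to the target x
print(dx.tolist())      # [2.0, 4.0, 6.0]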

View File

@@ -14,11 +14,10 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]
factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
acc = 0
chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]
# scatter-reduce
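A quick worked example of the new chunk construction above, assuming numel=10, factor=1 and three devices (values are illustrative):

import itertools

chunk_sizes = [4, 3, 3]   # (base+1)*factor for the first `left` devices, base*factor for the rest
offsets = itertools.accumulate(chunk_sizes, initial=0)   # 0, 4, 7, 10
chunks = list(itertools.pairwise(offsets))               # [(0, 4), (4, 7), (7, 10)]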
@@ -38,7 +37,7 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]
def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]
class MultiLazyBuffer(MathTrait):
@@ -46,9 +45,6 @@ class MultiLazyBuffer(MathTrait):
assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
if axis is not None:
splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
self.bounds = tuple(zip(splits, splits[1:]))
@property
def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
@@ -59,20 +55,16 @@ class MultiLazyBuffer(MathTrait):
@property
def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]
@property
def bounds(self):
if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"
@staticmethod
def from_sharded(lb:UOp, devices:tuple[str, ...], axis:int|None, bounds:tuple[tuple[int, int], ...]|None):
assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
lbs = [lb] * len(devices)
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
# NOTE: this contiguous is making it impossible for the scheduler to do late const folding
return MultiLazyBuffer([lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
def copy_to_device(self, device:str) -> UOp:
if self.axis is None:
# if we already have a copy on the device, return that
return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
# if we already have a copy on the device, return that
if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
# copy lbs to device, pad to final shape, and sum
llbs:list[UOp] = []
for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
@@ -84,13 +76,14 @@ class MultiLazyBuffer(MathTrait):
# passthroughs
@property
def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
def cast(self, dtype:DType, bitcast:bool=False):
return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real)
@property
def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort}
# elementwise is simple
def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
@@ -106,12 +99,10 @@ class MultiLazyBuffer(MathTrait):
assert any(new_real), "output contains no real lb"
for mlb in msrcs:
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
elif mlb.axis is None and axis is not None:
assert bounds is not None
srcs.append(to_sharded(mlb.lbs, axis, bounds))
else:
assert axis is not None and bounds is not None
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
# NOTE: const dtype should match real
new_dtype = next(iter(new_real_lbs.values())).dtype

View File

@@ -361,17 +361,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
def cast(self, dtype:DType, bitcast=False):
if bitcast: return self.bitcast(dtype)
if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
# NOTE: we have to apply the movementops here, we can't use VIEW (yet)
# TODO: move this to the scheduler
ret = self.base.cast(dtype, bitcast)
op_arg = []
mop = self
while mop is not self.base:
op_arg.append((mop.op, mop.arg))
mop = mop.src[0]
for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
return ret
return UOp(Ops.CAST, dtype, (self,))
def bitcast(self, dtype:DType):
if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
@@ -440,7 +429,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
# *** from LazyBuffer ***
@staticmethod
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
from tinygrad.shape.shapetracker import ShapeTracker
# Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
if op is Ops.CONST:
@@ -477,13 +466,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
@property
def base(self) -> UOp:
if self.op in GroupOp.Movement: return self.src[0].base
return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp:
if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
# instant folding rules
if new_st.contiguous and self.base.shape == new_st.shape: return self.base
return ret
return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
def _mop(self, op:Ops, arg):
ret = UOp(op, self.dtype, (self,), arg)
@@ -1304,7 +1288,10 @@ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]
# *** uop swizzling ***
merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
merge_views = PatternMatcher([
(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)),
(UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None),
])
# push VIEW to loads
view_left = merge_views+PatternMatcher([

View File

@@ -291,17 +291,23 @@ class MetalRenderer(CStyleLanguage):
return super().render_kernel(function_name, kernel, bufs, uops, prefix)
_nms = "xyzwabcdefghijkl"
cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8
class CUDARenderer(CStyleLanguage):
device = "CUDA"
global_max = (2147483647, 65535, 65535)
local_max = (1024, 1024, 64)
shared_max = 49152
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
tensor_cores = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do,
opts=("u0","l0","l0","l1","l1","l1","u1"), swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8))))
for di,do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
# https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]
tc_sm80 = tc_81616 + tc_8168_f16
tc_sm75 = tc_8168_f16
def __init__(self, arch:str):
self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
def __reduce__(self): return self.__class__, (self.arch,)
# language options
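A small usage sketch of the new arch-based selection (assuming the usual tinygrad.renderer.cstyle import path):

from tinygrad.renderer.cstyle import CUDARenderer

CUDARenderer("sm_89").tensor_cores   # tc_sm80: m16n8k16 half/bf16 plus m16n8k8 half
CUDARenderer("sm_75").tensor_cores   # tc_sm75: only the m16n8k8 half shape
CUDARenderer("sm_70").tensor_cores   # []: tensor cores disabled below sm_75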

View File

@@ -124,11 +124,12 @@ class PTXRenderer(Renderer):
device = "CUDA"
suffix = "PTX"
global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
tc_sm80 = [tc for tc in CUDARenderer.tc_sm80 if tc.dtype_in == dtypes.half]
code_for_op = asm_for_op
extra_matcher = ptx_matcher
def __init__(self, arch:str, device="CUDA"):
self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
self.device, self.arch = device, arch
self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else []
def __reduce__(self): return self.__class__, (self.arch, self.device)
# language options

View File

@@ -138,31 +138,38 @@ class PythonProgram:
ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
elif arg[4] == "AMD":
# A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
def a_elem(x, i, j, goff):
assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
return x[i][goff+j]
def a_elem(x, k, row, goff):
assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
return x[k][goff+row]
# B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
def b_elem(x, i, j, goff): return a_elem(x, j, i, goff) # pylint: disable=arguments-out-of-order
def b_elem(x, col, k, goff): return a_elem(x, k, col, goff) # pylint: disable=arguments-out-of-order
def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif arg[4] == "CUDA":
# A (8 elements on 32 threads)
def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
# B (4 elements on 32 threads)
def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
# (i, j), C, D (4 elements on 32 threads)
def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
# (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)
if arg[1] == (8,16,16):
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
elif arg[1] == (8,16,8):
def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
elif arg[4] == "INTEL":
# A (16 elements on 8 threads)
def a_elem(x, i, j, goff): return x[i%2+j*2][goff+i//2]
def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
# B (16 elements on 8 threads)
def b_elem(x, i, j, goff): return x[j][goff+i]
def b_elem(x, col, k, goff): return x[k][goff+col]
# C, D (8 elements on 8 threads)
def c_map(lane, elem): return (lane, elem)
ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
elif arg[4] == "CLANG":
def elem(x, i, j, _): return x[i+j][0]
def elem(x, col, row, _): return x[col+row][0] # k is always 0
def c_map(_, elem): return (elem%16, elem//16)
ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
else: raise NotImplementedError(f"unimplemented tensor core {arg}")
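As a sanity check of the shared CUDA c_map, a small standalone snippet (only the lambda body is taken from the code above; the asserts just illustrate the 16x8 accumulator layout, everything else is an assumption):

# standalone check of the CUDA accumulator mapping: each of the 32 lanes holds
# 4 C/D elements; c_map returns the (col, row) a given (lane, elem) pair lands on
def c_map(lane, elem): return (elem % 2 + (lane % 4) * 2, lane // 4 + (elem // 2) * 8)

coords = {(lane, elem): c_map(lane, elem) for lane in range(32) for elem in range(4)}
assert len(set(coords.values())) == 32 * 4                 # every (col, row) is unique
assert {c for c, _ in coords.values()} == set(range(8))    # 8 columns (N=8)
assert {r for _, r in coords.values()} == set(range(16))   # 16 rows (M=16)
print(coords[(0, 0)], coords[(0, 1)], coords[(31, 3)])     # (0, 0) (1, 0) (7, 15)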
@@ -179,7 +186,8 @@ class PythonRenderer(Renderer):
def __init__(self):
if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores

View File

@@ -14,6 +14,7 @@ class TLSFAllocator:
def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
self.lv1_entries:list[int] = [0] * len(self.storage)
# self.blocks is more like a linked list, where each entry is a contiguous block.
self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free
@@ -25,12 +26,14 @@ class TLSFAllocator:
def _insert_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].append(start)
self.lv1_entries[self.lv1(size)] += 1
self.blocks[start] = (size, start + size, prev, True)
return self
def _remove_block(self, start:int, size:int, prev:int|None=None):
if prev is None: prev = self.blocks[start][2]
self.storage[self.lv1(size)][self.lv2(size)].remove(start)
self.lv1_entries[self.lv1(size)] -= 1
self.blocks[start] = (size, start + size, prev, False)
return self
@@ -67,6 +70,7 @@ class TLSFAllocator:
# Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
for l1 in range(self.lv1(size), len(self.storage)):
if self.lv1_entries[l1] == 0: continue
for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
if len(self.storage[l1][l2]) > 0:
nsize = self.blocks[self.storage[l1][l2][0]][0]
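The new lv1_entries counter lets the search skip empty first-level buckets without touching their second-level lists; a minimal standalone sketch of that lookup, with simplified bucketing and no coalescing (hypothetical TinyTLSF class, not the real allocator):

import collections

# minimal sketch of a two-level (TLSF-style) free-list lookup with a per-level
# counter so empty first-level buckets can be skipped in O(1)
class TinyTLSF:
    def __init__(self, max_size: int, lv2_cnt: int = 16):
        self.l2_bits = lv2_cnt.bit_length()
        self.storage = [collections.defaultdict(list) for _ in range(max_size.bit_length() + 1)]
        self.lv1_entries = [0] * len(self.storage)
    def lv1(self, size: int) -> int: return size.bit_length()  # power-of-two bucket
    def lv2(self, size: int) -> int: return (size >> max(0, self.lv1(size) - self.l2_bits)) & ((1 << self.l2_bits) - 1)
    def insert(self, start: int, size: int):
        self.storage[self.lv1(size)][self.lv2(size)].append(start)
        self.lv1_entries[self.lv1(size)] += 1
    def find(self, size: int):
        # start at the requested bucket and walk up; skip first levels with no entries
        for l1 in range(self.lv1(size), len(self.storage)):
            if self.lv1_entries[l1] == 0: continue
            for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, 1 << self.l2_bits):
                if self.storage[l1][l2]: return self.storage[l1][l2][0]
        return None

alloc = TinyTLSF(1 << 20)
alloc.insert(0x1000, 4096)
assert alloc.find(100) == 0x1000    # small request falls through to the 4 KiB block
assert alloc.find(1 << 19) is None  # nothing big enough is free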

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA
@@ -255,6 +255,15 @@ class AMDev:
self.pcidev, self.devfmt = pcidev, devfmt
self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar
os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions
# Avoid O_CREAT because we don't want to re-create/replace an existing file (it triggers extra permission checks) when opening as a non-owner.
if os.path.exists(lock_name:=temp(f"am_{self.devfmt}.lock")): self.lock_fd = os.open(lock_name, os.O_RDWR)
else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
except OSError: raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.")
self._run_discovery()
self._build_regs()
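The lock added above is the usual one-process-per-device guard; a self-contained sketch of the same open-then-flock pattern (hypothetical lock path, independent of the AM driver):

import fcntl, os, tempfile

# sketch of an exclusive per-device lock: open (or create) a well-known lock file,
# then take a non-blocking flock; a second holder gets an OSError and bails out
def acquire_device_lock(devfmt: str) -> int:
    lock_name = os.path.join(tempfile.gettempdir(), f"am_{devfmt}.lock")  # hypothetical path
    # skip O_CREAT when the file already exists so non-owners don't hit extra permission checks
    if os.path.exists(lock_name): fd = os.open(lock_name, os.O_RDWR)
    else: fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)
    try: fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError:
        os.close(fd)
        raise RuntimeError(f"device {devfmt} is already in use")
    return fd  # keep the fd open for the lifetime of the process; closing it releases the lock

fd = acquire_device_lock("0000:01:00.0")
os.close(fd)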
@@ -284,14 +293,17 @@ class AMDev:
self.sdma:AM_SDMA = AM_SDMA(self)
if self.partial_boot and (self.reg("regCP_MEC_RS64_CNTL").read() & gc_11_0_0.CP_MEC_RS64_CNTL__MEC_HALT_MASK == 0):
print(f"am {self.devfmt}: MEC is active. Someone might be using the GPU? Issue a full reset.")
if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
self.partial_boot = False
if not self.partial_boot:
if self.psp.is_sos_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
try: # do not interrupt the boot process
signal.signal(signal.SIGINT, signal.SIG_IGN)
if self.psp.is_sos_alive(): self.smu.mode1_reset()
for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
ip.init()
if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
finally: signal.signal(signal.SIGINT, signal.default_int_handler)
# Booting done
self.is_booting = False
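The try/finally around bring-up is the standard mask-SIGINT-across-a-critical-section idiom; a small standalone sketch (hypothetical critical section, not the boot code):

import signal, time

# sketch: ignore Ctrl-C while a non-interruptible step runs, then restore the handler
def uninterruptible(step):
    prev = signal.signal(signal.SIGINT, signal.SIG_IGN)
    try: step()
    finally: signal.signal(signal.SIGINT, prev)

uninterruptible(lambda: time.sleep(0.01))  # stand-in for the boot sequence

Restoring the previously installed handler is a small variation; the diff above restores signal.default_int_handler explicitly, which assumes no custom handler was set before boot.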

View File

@@ -41,9 +41,9 @@ class Function:
import tinygrad.function as F
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None, src:tuple[UOp, ...]=()):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg, src)
return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg, src) for d in device], None)
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None)
def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
import numpy as np
@@ -179,7 +179,7 @@ class Tensor(SimpleMathTrait):
# data might be on a different device
if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
# if device is a tuple, we should have/construct a MultiLazyBuffer
elif isinstance(data, UOp): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata
else:
assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
self.lazydata = data
@@ -394,33 +394,33 @@ class Tensor(SimpleMathTrait):
if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
return self.replace(real)
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
"""
Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
Shards the tensor across the given devices. Optionally specify which axis to shard on.
```python exec="true" source="above" session="tensor" result="python"
t = Tensor.empty(2, 3)
print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
t = Tensor.empty(2, 4)
print(t.shard((t.device, t.device), axis=1).lazydata)
```
"""
assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
if axis is not None:
devices = tuple(Device.canonicalize(x) for x in devices)
if axis is None: lbs = [self.lazydata] * len(devices)
else:
axis = self._resolve_dim(axis)
if splits is None:
if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
sz = ceildiv(total, len(devices))
splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
sz = ceildiv(self.shape[axis], len(devices))
sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
# NOTE: this contiguous is making it impossible for the scheduler to do late const folding
mlb = MultiLazyBuffer([lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
return Tensor(mlb, device=devices, requires_grad=self.requires_grad)
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None):
def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
"""
Shards the tensor across the given devices in place.
"""
return self.replace(self.shard(devices, axis, splits))
return self.replace(self.shard(devices, axis))
@staticmethod
def from_uop(y:UOp, **kwargs) -> Tensor:
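To see what the even split in the new shard path produces, a worked example of the ceildiv-based sizes (pure Python, no tinygrad):

# worked example of the even split shard() now uses: ceildiv-sized chunks, trailing devices may get less
def ceildiv(a, b): return -(-a // b)

def shard_sizes(total: int, n_devices: int) -> list[int]:
    sz = ceildiv(total, n_devices)
    return [max(0, min(sz, total - sz * i)) for i in range(n_devices)]

assert shard_sizes(4, 2) == [2, 2]        # the docstring example: axis of 4 over 2 devices
assert shard_sizes(10, 3) == [4, 4, 2]    # uneven shapes still cover the whole axis
assert shard_sizes(2, 4) == [1, 1, 0, 0]  # more devices than elements leaves empty shards
assert sum(shard_sizes(7, 3)) == 7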
@@ -915,23 +915,29 @@ class Tensor(SimpleMathTrait):
print(dy.tolist()) # dz/dy
```
"""
assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
target_uops: list[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
grads = compute_gradient(self.lazydata, self.lazydata.const_like(1) if gradient is None else cast(UOp, gradient.lazydata), target_uops)
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
ret.append(Tensor(y, device=x.device))
return ret
if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
rets = []
for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
target_uops = [x.lazydata.lbs[i] for x in targets]
grads = compute_gradient(uop, grad, set(target_uops))
ret = []
for x in target_uops:
if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
ret.append(y)
rets.append(ret)
# create returned Tensors
if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
device=t.device) for t,u in zip(targets, zip(*rets))]
def _deepwalk(self):
def _walk(node, visited):
def _deepwalk(self) -> list[Tensor]:
def _walk(node:Tensor, visited:set[Tensor]):
visited.add(node)
# if tensor is not leaf, reset grad
if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
if ctx:
for i in node._ctx.parents:
for i in cast(Function, node._ctx).parents:
if i not in visited: yield from _walk(i, visited)
yield node
return list(_walk(self, set()))
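_deepwalk is a plain post-order DFS over the autograd graph; a minimal standalone sketch with a stand-in Node class (hypothetical, not the Tensor/Function machinery):

# standalone sketch of the post-order walk backward() relies on: parents are yielded
# before the node itself, so reversing the result gives a valid backprop order
class Node:
    def __init__(self, name, parents=()): self.name, self.parents = name, parents

def deepwalk(root: Node) -> list[Node]:
    def _walk(node, visited):
        visited.add(node)
        for p in node.parents:
            if p not in visited: yield from _walk(p, visited)
        yield node
    return list(_walk(root, set()))

a, b = Node("a"), Node("b")
c = Node("c", (a, b))
d = Node("d", (c, a))
order = [n.name for n in deepwalk(d)]
assert order.index("a") < order.index("c") < order.index("d") and order[-1] == "d"
print(order)  # ['a', 'b', 'c', 'd']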
@@ -954,18 +960,22 @@ class Tensor(SimpleMathTrait):
# this is "implicit gradient creation"
gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
toposort_uop = self.lazydata.toposort
assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
self.grad = gradient
for t0 in reversed(toposorted):
if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
grads = t0._ctx.backward(t0.grad.lazydata)
ctx = cast(Function, t0._ctx)
token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
grads = ctx.backward(t0.grad.lazydata)
_METADATA.reset(token)
grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
for t, g in zip(t0._ctx.parents, grads):
for g in ([grads] if len(ctx.parents) == 1 else grads)]
for t, g in zip(ctx.parents, grads):
if g is not None and t.requires_grad:
assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
t.grad = g if t.grad is None else (t.grad + g)
if not retain_graph: del t0._ctx
return self
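The accumulation rule at the end (t.grad = g if t.grad is None else t.grad + g) is what makes fan-out work; a toy scalar sketch of the same reverse pass, independent of tinygrad (names and local_grads are illustrative only):

# toy reverse pass over the deepwalk order: seed the output with 1.0 (implicit
# gradient creation), then accumulate each parent's grad across all of its consumers
def toy_backward(order, parents, local_grads):
    grad = {order[-1]: 1.0}
    for node in reversed(order):
        if node not in grad: continue
        for p, dp in zip(parents.get(node, ()), local_grads.get(node, ())):
            grad[p] = grad.get(p, 0.0) + grad[node] * dp  # accumulate over fan-out
    return grad

# y = x*x, i.e. x is used twice: dy/dx should be x + x = 2x = 6 at x = 3
order = ["x", "m", "y"]
parents = {"m": ("x", "x"), "y": ("m",)}
local_grads = {"m": (3.0, 3.0), "y": (1.0,)}  # d(m)/dx = x for each use, d(y)/dm = 1
assert toy_backward(order, parents, local_grads)["x"] == 6.0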