mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)

Merge branch 'master' into retinanet_mlperf
.github/workflows/benchmark.yml (32 changes)
@@ -91,7 +91,7 @@ jobs:
       - name: Run GPT2 w HALF
         run: HALF=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
       - name: Run GPT2 w HALF/BEAM
-        run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+        run: HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3.11 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
       - name: Train MNIST
         run: time PYTHONPATH=. TARGET_EVAL_ACC_PCT=96.0 python3.11 examples/beautiful_mnist.py | tee beautiful_mnist.txt
       - name: Run 10 CIFAR training steps
@@ -202,12 +202,12 @@ jobs:
       # run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
       - name: Run LLaMA-3 8B BEAM
         run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
-      - name: Run LLaMA-3 8B on 4 GPUs
-        run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
-      - name: Run LLaMA-3 8B on 6 GPUs
-        run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
-      - name: Run LLaMA-2 70B
-        run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+      - name: Run LLaMA-3 8B on 4 GPUs with BEAM
+        run: NV=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+      # - name: Run LLaMA-3 8B on 6 GPUs
+      #   run: NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+      # - name: Run LLaMA-2 70B
+      #   run: NV=1 CAPTURE_PROCESS_REPLAY=0 MAX_CONTEXT=256 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
       - name: Run Mixtral 8x7B
         run: time NV=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
       - name: Run GPT2
@@ -217,7 +217,7 @@ jobs:
       - name: Run GPT2 w HALF
         run: NV=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
       - name: Run GPT2 w HALF/BEAM
-        run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+        run: NV=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
       - uses: actions/upload-artifact@v4
         with:
           name: Speed (NVIDIA)
@@ -372,7 +372,7 @@ jobs:
       #- name: Fuzz Padded Tensor Core GEMM
      #  run: HSA=1 M_START=12 M_STOP=20 M_STEP=1 N_START=12 N_STOP=20 N_STEP=1 K_START=28 K_STOP=36 K_STEP=1 HALF=1 TC_OPT=2 DEBUG=2 python3 ./extra/gemm/fuzz_matmul.py
       - name: Remove amdgpu
-        run: sudo rmmod amdgpu
+        run: sleep 5 && sudo rmmod amdgpu # sleep a bit to let the driver unload the prev pid.
       - name: Run Stable Diffusion
         run: AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
       - name: Run SDXL
@@ -389,14 +389,14 @@ jobs:
       # run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 1 --size 7B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_six_gpu.txt
       - name: Run LLaMA-3 8B BEAM
         run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/llama3.py --size 8B --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_beam.txt
-      - name: Run LLaMA-3 8B on 4 GPUs
-        run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
-      - name: Run LLaMA-3 8B on 6 GPUs
-        run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
+      - name: Run LLaMA-3 8B on 4 GPUs with BEAM
+        run: AMD=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 4 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_four_gpu.txt
+      # - name: Run LLaMA-3 8B on 6 GPUs
+      #   run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama3.py --size 8B --shard 6 --model weights/LLaMA-3/8B-SF-DPO/ --benchmark --temperature 0 | tee llama3_six_gpu.txt
       - name: Restore amdgpu
         run: sudo modprobe amdgpu
-      - name: Run LLaMA-2 70B
-        run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
+      # - name: Run LLaMA-2 70B
+      #   run: AMD=1 CAPTURE_PROCESS_REPLAY=0 python3 examples/llama.py --gen 2 --size 70B --shard 6 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_2_70B.txt
       - name: Run Mixtral 8x7B
         run: time AMD=1 python3 examples/mixtral.py --temperature 0 --count 10 --timing | tee mixtral.txt
       - name: Run GPT2
@@ -406,7 +406,7 @@ jobs:
       - name: Run GPT2 w HALF
         run: AMD=1 HALF=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half.txt
       - name: Run GPT2 w HALF/BEAM
-        run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 CAST_BEFORE_VIEW=0 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
+        run: AMD=1 HALF=1 JITBEAM=2 IGNORE_BEAM_CACHE=1 python3 examples/gpt2.py --count 10 --temperature 0 --timing | tee gpt2_half_beam.txt
       - uses: actions/upload-artifact@v4
         with:
           name: Speed (AMD)
.github/workflows/test.yml (3 changes)
@@ -166,6 +166,7 @@ jobs:
       - name: Test emulated CUDA tensor cores
         run: |
           DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
+          DEBUG=2 EMULATE_CUDA_SM75=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_ops.py TestOps.test_gemm_fp16
           PYTHONPATH="." DEBUG=2 EMULATE_CUDA=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores TestLinearizer.test_tensor_cores_padded
       - name: Test emulated INTEL OpenCL tensor cores
         run: DEBUG=2 EMULATE_INTEL=1 FORWARD_ONLY=1 PYTHON=1 HALF=1 N=64 python3 ./extra/gemm/simple_matmul.py
@@ -296,7 +297,7 @@ jobs:
       - if: ${{ matrix.task == 'optimage' }}
         name: Test openpilot model kernel count and gate usage
         run: |
-          PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2105 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
+          PYTHONPATH="." ALLOWED_KERNEL_COUNT=208 ALLOWED_READ_IMAGE=2104 ALLOWED_GATED_READ_IMAGE=29 FLOAT16=0 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/v0.9.4/selfdrive/modeld/models/supercombo.onnx
       - if: ${{ matrix.task == 'optimage' }}
         name: Test openpilot alt model correctness (float32)
         run: PYTHONPATH="." FLOAT16=0 DEBUGCL=1 GPU=1 IMAGE=2 python examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/3799fe46b3a629e491d4b8498b8ae83e4c88c304/selfdrive/modeld/models/supercombo.onnx
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -3,7 +3,7 @@
 export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_green"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
@@ -3,7 +3,7 @@
 export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_red"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=36
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=36

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
test/external/external_model_benchmark.py (5 changes)
@@ -9,6 +9,7 @@ from onnx2torch import convert
 from extra.onnx import get_run_onnx
 from tinygrad.helpers import OSX, DEBUG, fetch
 from tinygrad import Tensor, Device
+from tinygrad.device import CompileError

 MODELS = {
   "resnet50": "https://github.com/onnx/models/raw/main/validated/vision/classification/resnet/model/resnet50-caffe2-v1-9.onnx",
@@ -72,10 +73,10 @@ def benchmark_model(m, devices, validate_outs=False):
         for _ in range(3): {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}
         benchmark(m, f"tinygrad_{device.lower()}_jit", lambda: {k:v.numpy() for k,v in tinygrad_jitted_model(**inputs).items()}) # noqa: F821
         del inputs, tinygrad_model, tinygrad_jitted_model
-    except RuntimeError as e:
+    except CompileError as e:
       # TODO: we don't run the dm model on METAL for now
       if Device.DEFAULT == "METAL":
-        assert "buffer count limit" in str(e)
+        assert "no 'buffer' resource location available" in str(e)
         return
       else: raise e
@@ -2,7 +2,7 @@
 # compare kernels created by HEAD against master
 from collections import defaultdict
 import os, multiprocessing, logging, pickle, sqlite3, difflib, functools, warnings
-from typing import Callable, List, Set, Tuple, Union, cast
+from typing import Callable, List, Tuple, Union, cast
 from tinygrad.helpers import VERSION, Context, ContextVar, colored, db_connection, getenv, tqdm
 from tinygrad.engine.schedule import ScheduleContext, schedule_uop
 from tinygrad.codegen.kernel import Kernel, Opt
@@ -30,9 +30,9 @@ class ProcessReplayWarning(Warning): pass

 # *** recreators

-def recreate_sched(ast:UOp, assigns:Set[UOp]) -> UOp:
+def recreate_sched(ast:UOp) -> UOp:
   # NOTE: process replay isn't meant to actually schedule anything
-  return schedule_uop(ast, ScheduleContext(assigns=assigns, tensor_uops=defaultdict(list))).ast
+  return schedule_uop(ast, ScheduleContext(tensor_uops=defaultdict(list))).ast

 def recreate_kernel(ast:UOp, opts:Renderer, applied_opts:List[Opt], name:str) -> str:
   k = Kernel(ast, opts=opts)
   for opt in applied_opts: k.apply_opt(opt)
@@ -33,7 +33,7 @@ def helper_test(nm, gen, model, max_memory_allowed, max_kernels_allowed, all_jit
   kernels_used = len(model.jit_cache) if hasattr(model, "jit_cache") else None
   print(f"{nm}: used {mem_used/1e9:.2f} GB and {kernels_used} kernels in {min(tms)/1e6:.2f} ms")
   assert mem_used/1e9 < max_memory_allowed, f"{nm} used more than {max_memory_allowed:.2f} GB - {mem_used/1e9:.2} GB used"
-  assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels"
+  assert not kernels_used or kernels_used <= max_kernels_allowed, f"{nm} used more than {max_kernels_allowed} kernels, it used {kernels_used}"
   if all_jitted:
     assert kernels_used > 0 and kernels_used == GlobalCounters.kernel_count or (kernels_used <= GlobalCounters.kernel_count and getattr(Device[Device.DEFAULT], "graph", None)), f"only {kernels_used} out of {GlobalCounters.kernel_count} were jitted" # noqa: E501
@@ -66,7 +66,8 @@ class TestArange(unittest.TestCase):
     return self.test_all_opts([Opt(OptOps.UPCAST, 0, 4), Opt(OptOps.UNROLL, 0, 4)], [Opt(op=OptOps.GROUP, axis=0, amt=0)])

 class TestIndexing(unittest.TestCase):
-  @unittest.expectedFailure
+  # update: passing after CAST_BEFORE_VIEW=1 deletion
+  # @unittest.expectedFailure
   def test_arange_2_reduce(self):
     needle = Tensor.zeros(16384, dtype=dtypes.int).contiguous()
     needle[1337] = 1
@@ -132,11 +132,12 @@ class TestMovedConstFolding(unittest.TestCase):

   def test_cast_padded(self):
     # NOTE: this is folded due to CAST_BEFORE_VIEW
+    # update: CAST_BEFORE_VIEW=1 is no longer supported
     if is_dtype_supported(dtypes.int16):
-      _check_ast_count(0, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
+      _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16))
       np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).cast(dtypes.int16).numpy(), [0, 1, 1, 1, 1, 0])
     if is_dtype_supported(dtypes.uint16):
-      _check_ast_count(0, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
+      _check_ast_count(1, Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16))
       np.testing.assert_equal(Tensor.full(4, fill_value=-1).pad(((1, 1),)).cast(dtypes.uint16).numpy(), [0, 65535, 65535, 65535, 65535, 0])
     # not folded
     if is_dtype_supported(dtypes.int64):
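Note: the expected values in this test can be sanity-checked outside tinygrad. A minimal numpy check (not part of the diff) of the uint16 case: the pad writes zeros before the cast, and the cast then wraps -1 to 65535. With the CAST_BEFORE_VIEW fold removed, the cast is a real kernel, so the expected AST count goes from 0 to 1 while the numeric result stays the same.

    import numpy as np
    a = np.full(4, -1, dtype=np.int32)        # [-1, -1, -1, -1]
    padded = np.pad(a, (1, 1))                # [0, -1, -1, -1, -1, 0]
    # casting after the pad keeps the zero padding and wraps -1 to 65535
    print(padded.astype(np.uint16))           # [0, 65535, 65535, 65535, 65535, 0]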
@@ -781,7 +781,8 @@ class TestAutoCastType(unittest.TestCase):
       if DEBUG >= 2:
         print(f"testing {default_dtype=}, {dtype=}")
       a = Tensor([1, 2, 3], dtype=dtype, requires_grad=True)
-      b = (a * 5).sum()
+      # NOTE: this is broken without default_dtype because of CAST_BEFORE_VIEW
+      b = (a * 5).sum(acc_dtype=default_dtype)
       b.backward() # if there is dtype mismatch, lazy should assert
       assert a.grad.dtype == a.dtype
       np.testing.assert_allclose(a.grad.numpy(), [5, 5, 5])
@@ -120,7 +120,7 @@ class TestImageDType(unittest.TestCase):
     loss = x.image_dot(w1).image_dot(w2).float().max()
     loss.backward()
     sched = unwrap(w1.grad).schedule()
-    self.assertEqual(len(sched), 10)
+    self.assertEqual(len(sched), 9)
     for s,ei in zip(sched, lower_schedule(sched[:])):
       ei.run()
       if s.outputs[0].dtype == dtypes.float:
@@ -12,7 +12,6 @@ from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.shape.view import View
|
||||
# from tinygrad.ops import Variable
|
||||
from tinygrad.tensor import Tensor, _to_np_dtype
|
||||
from tinygrad.engine.schedule import BUF_LIMIT
|
||||
from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
|
||||
from tinygrad.helpers import prod, Context, getenv, CI, flatten, dedup, AMX
|
||||
from tinygrad.dtype import DType, dtypes
|
||||
@@ -1701,7 +1700,7 @@ class TestHandCodedOpts(unittest.TestCase):
     # float4/other hcopt shouldn't upcast last axis, since we already have 7 upcast, and the last axis is not very contiguous
     assert k.upcasted == 1 and k.full_shape[-1] == 7

-  @unittest.skipIf((buf_max:=BUF_LIMIT.get(Device.DEFAULT)) is not None and buf_max <= 37, "this test uses too many bufs")
+  @unittest.skipIf(Device.DEFAULT == "METAL", "METAL can only run kernels with up to 32 buffers")
   def test_masked_upcast_wino(self):
     monster = Tensor.stack(*[Tensor.stack(*[Tensor.rand(16) for _ in range(6)]) for _ in range(6)])
@@ -43,6 +43,12 @@ class TestMultiTensor(unittest.TestCase):
     assert lb.shape == (256,)
     (X + X).realize()

+  def test_gradient(self):
+    X = Tensor.ones(256).contiguous().realize()
+    X.to_(devices_2)
+    grad = X.sum().gradient(X)[0]
+    grad.realize()
+
   def test_shard(self):
     X = Tensor.ones(256).contiguous().realize()
     X.shard_(devices_2, 0)
@@ -75,7 +81,7 @@ class TestMultiTensor(unittest.TestCase):
       ei.run()
     assert names[-2] == names[-1], "function was relinearized"

-  @unittest.skip("this doesn't fold because from_sharded calls contiguous on all lbs")
+  @unittest.skip("this doesn't fold because shard_ calls contiguous on all lbs")
   def test_sharded_memory(self):
     # Buffer may be stuck in track_cross_buffer
     for x in (d0, d1, d2, d3, d4): Device[x].synchronize()
@@ -157,6 +163,7 @@ class TestMultiTensor(unittest.TestCase):
          strat.sampled_from((Ops.ADD, Ops.MUL, Ops.MAX)),
          strat.sampled_from((None, 0, 1)), strat.sampled_from((None, 0, 1)), strat.sampled_from((1, 0, -1)))
   def test_simple_reduce(self, N, devices, rop, shard_axis, reduce_axis, sign):
     N = N * len(devices)
     X = Tensor.rand(N*N).reshape(N, N).mul(sign)
     n = X.numpy()
     X.shard_(devices, shard_axis)
@@ -438,6 +445,7 @@ class TestMultiTensor(unittest.TestCase):
     assert isinstance(jf.jit_cache[4].prg, BufferCopy)
     assert isinstance(jf.jit_cache[5].prg, graph_d1)

+  @unittest.skip("no longer supports uneven shard")
   def test_uneven_shard(self):
     for N in range(1, 6):
       X = Tensor.rand(4, 1, 257).contiguous().realize()
@@ -450,6 +458,7 @@ class TestMultiTensor(unittest.TestCase):
     np.testing.assert_equal(X.expand((4, 4, 257)).numpy(), np.tile(n, (1, 4, 1)))
     np.testing.assert_equal(X.permute((0, 2, 1)).numpy(), np.transpose(n, (0, 2, 1)))

+  @unittest.skip("no longer supports uneven shard")
   def test_uneven_multiple_zeros(self):
     for data in ([1, 2, 3, 4], [1, 2, 3], [1, 2], [1], []):
       for N in (1, 2, 3, 4):
@@ -458,29 +467,28 @@ class TestMultiTensor(unittest.TestCase):
         X = ((Tensor(data).shard(devices, axis=0) + 1).realize() - 1).realize()
         np.testing.assert_equal(X.numpy(), data)

+  @unittest.skip("no longer supports uneven shard")
   def test_uneven_shard_with_empty(self):
     N = 4
-    X = Tensor.rand(16, 1, 17).contiguous().realize()
+    X = Tensor.rand(16, 1, 3).contiguous().realize()
     np_x = X.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))

     # test empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).numpy(), np_x)
-
-    # test reshape with empty shard
-    np.testing.assert_equal(X.shard(devices, 0, (2, 2, 12, 0)).reshape(8, 1, 34).numpy(), np_x.reshape(8, 1, 34))
-
-    # test elementwise with empty shard
-    np.testing.assert_equal((X.shard(devices, 0, (2, 2, 12, 0)) + X.shard(devices, 0, (0, 0, 1, 15))).numpy(), np_x + np_x)
+    np.testing.assert_equal(X.shard(devices, 0).numpy(), np_x)
+    np.testing.assert_equal(X.shard(devices, 0).reshape(8, 1, 6).numpy(), np_x.reshape(8, 1, 6))

+  @unittest.skip("no longer supports uneven shard")
   def test_multiple_uneven_shard(self):
     N = 4
     X = Tensor.rand(4, 1, 257).contiguous().realize()
     Y = Tensor.rand(4, 1, 257).contiguous().realize()
     np_x, np_y = X.numpy(), Y.numpy()
     devices = tuple(f"{Device.DEFAULT}:{i}" for i in range(N))
-    X.shard_(devices, 2, (2, 38, 47, 170))
-    Y.shard_(devices, 2, (34, 53, 51, 119))
+    X.shard_(devices, 2)
+    Y.shard_(devices, 2)
     np.testing.assert_equal(X.numpy(), np_x)
     np.testing.assert_equal(Y.numpy(), np_y)
     np.testing.assert_equal((X + Y).numpy(), np_x + np_y)
@@ -534,6 +542,7 @@ class TestMultiTensor(unittest.TestCase):
     with self.assertRaises((AssertionError, ValueError)):
       t0.reshape((26*15,7))

+  @unittest.skip("no longer supports uneven shard")
   def test_reshape_on_axis_uneven(self):
     def reshape_helper(t0, t, t_axis):
       np.testing.assert_allclose(t0.reshape(t.shape).numpy(), t.numpy())
@@ -605,8 +614,9 @@ class TestMultiTensor(unittest.TestCase):
     self.assertEqual(t.dtype, t2.dtype)
     self.assertEqual(t.lazydata.axis, t2.lazydata.axis)

+  @unittest.skip("no longer supports uneven shard")
   def test_rand_like_uneven_shard(self):
-    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1, splits=(14, 7, 21))
+    t = Tensor.empty((4, 42, 15)).shard(devices_3, axis=1)
     t2 = Tensor.rand_like(t)
     self.assertEqual(t.shape, t2.shape)
     self.assertEqual(t.device, t2.device)
@@ -655,9 +665,10 @@ class TestMultiTensor(unittest.TestCase):
     assert set(unique) == {0, 2}, unique
     assert 200 < counts[0] < 312, counts[0]

+  @unittest.skip("no longer supports uneven shard")
   def test_dropout_on_uneven_shard_axis(self):
     with Tensor.train():
-      X = Tensor.ones(256).shard(devices_3, axis=0, splits=(100, 50, 106))
+      X = Tensor.ones(256).shard(devices_3, axis=0)
       output = X.dropout(0.5).numpy()
       unique, counts = np.unique(output, return_counts=True)
       assert set(unique) == {0, 2}, unique
@@ -689,6 +700,7 @@ class TestMultiTensor(unittest.TestCase):
     assert ast.src[2].src[0].op is Ops.LOAD
     assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3

+  @unittest.skip("TODO: this requires forced_realize to be deleted.")
   def test_shard_memory(self):
     devices = (d0, d1, d2, d3)
     t = Tensor.zeros(16, 16).contiguous()
@@ -816,6 +828,7 @@ class TestShrinkMultiTensorShardedAxis(unittest.TestCase):
     np.testing.assert_allclose(a.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), b.reshape((2, 1, 8)).expand((2, 5, 8)).numpy(), rtol=1e-7, atol=1e-3)
     np.testing.assert_allclose(a.flip(-1).numpy(), b.flip(-1).numpy(), rtol=1e-7, atol=1e-3)

+  @unittest.skip("no longer supports uneven shard")
   def test_uneven(self):
     t = Tensor.arange(24).reshape(3, 8).contiguous().realize()
     t.shard_([f"{Device.DEFAULT}:{i}" for i in range(2)], axis=0)
@@ -551,7 +551,7 @@ class TestNN(unittest.TestCase):

   @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
   def test_load_state_dict_sharded_model(self):
-    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")

     layer = Conv2d(3, 5, kernel_size=3)
     layer.weight.shard_(devices, 3)
@@ -572,7 +572,7 @@ class TestNN(unittest.TestCase):

   @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
   def test_load_state_dict_sharded_dict(self):
-    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")

     layer = Conv2d(3, 5, kernel_size=3)
     state_dict = {
@@ -589,7 +589,7 @@ class TestNN(unittest.TestCase):

   @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
   def test_load_state_dict_sharded_model_dict_same_axis(self):
-    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")

     layer = Conv2d(3, 5, kernel_size=3)
     layer.weight.shard_(devices, 3)
@@ -610,7 +610,8 @@ class TestNN(unittest.TestCase):

   @unittest.skipIf(CI and Device.DEFAULT in {"GPU", "CUDA", "METAL"}, "no GPU CI")
   def test_load_state_dict_sharded_model_dict_different_axis(self):
-    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2")
+    devices = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3")
+    devices5 = (f"{Device.DEFAULT}:1", f"{Device.DEFAULT}:2", f"{Device.DEFAULT}:3", f"{Device.DEFAULT}:4", f"{Device.DEFAULT}:5")

     layer = Conv2d(3, 5, kernel_size=3)
     layer.weight.shard_(devices, 3)
@@ -619,14 +620,14 @@ class TestNN(unittest.TestCase):
     # different shard axis
     state_dict = {
       'weight': Tensor.randn(5, 3, 3, 3).shard(devices, None),
-      'bias': Tensor.randn(5).shard(devices, 0),
+      'bias': Tensor.randn(5).shard(devices5, 0),
     }
     load_state_dict(layer, state_dict)

     # NOTE: model and state_dict shard differently, use the state_dict sharding # TODO: revisit this?
     self.assertEqual(layer.weight.device, devices)
     self.assertEqual(layer.weight.lazydata.axis, None)
-    self.assertEqual(layer.bias.device, devices)
+    self.assertEqual(layer.bias.device, devices5)
     self.assertEqual(layer.bias.lazydata.axis, 0)
     np.testing.assert_allclose(layer.weight.numpy(), state_dict['weight'].numpy())
     np.testing.assert_allclose(layer.bias.numpy(), state_dict['bias'].numpy())
@@ -16,7 +16,7 @@ from tinygrad.shape.view import View
 from tinygrad.ops import PatternMatcher, UOp, Ops, UPat, graph_rewrite, track_rewrites, view_supported_devices, symbolic_simple, merge_views
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, getenv, SPLIT_REDUCEOP, unwrap, prod, Context
 from tinygrad.codegen.kernel import verify_ast
-from tinygrad.engine.schedule import BUF_LIMIT, ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
+from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars, view_right, view_left, remove_movement_ops
 from tinygrad.engine.realize import CompiledRunner, run_schedule, lower_schedule
 from extra.models.llama import precompute_freqs_cis
@@ -585,6 +585,15 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(out, 2))
     np.testing.assert_allclose(out.numpy(), np.ones((64,64)))

+  def test_example_matmul_same(self):
+    x = Tensor.eye(64, requires_grad=True)
+    z = x.matmul(x).sum()
+    z.backward()
+    out = x.grad.contiguous()
+    run_schedule(check_schedule(out, 2))
+    # NOTE: the gradient flows twice
+    np.testing.assert_allclose(out.numpy(), 2*np.ones((64,64)))
+
   def test_contiguous_add(self):
     x = Tensor.empty(32)
     y = Tensor.empty(32)
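Note: the 2*np.ones expectation in the new test follows from the product rule: for z = sum(X @ X), X appears as both operands, so dz/dX = ones @ X.T + X.T @ ones, which is what the "gradient flows twice" comment refers to. A quick numpy cross-check (not part of the diff):

    import numpy as np
    X = np.eye(64)
    ones = np.ones((64, 64))                 # d(sum)/d(X@X) is all ones
    grad = ones @ X.T + X.T @ ones           # one term per occurrence of X
    assert (grad == 2 * np.ones((64, 64))).all()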
@@ -1363,8 +1372,9 @@ class TestSchedule(unittest.TestCase):
   @unittest.expectedFailure
   def test_conv2d_fused_half(self): _test_conv2d(5, dtype=dtypes.half)

+  @unittest.skip("splitting kernels exceeding device buffer count is not yet supported")
   def _test_buf_cnt(self, cnt:int, allowed:int):
-    if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
+    #if (m:=BUF_LIMIT.get(Device.DEFAULT)) is None or m != 32: self.skipTest(f"test needs a buf_max of 32 {Device.DEFAULT}")
     alu = functools.reduce(lambda x,y: x+y, [Tensor.ones((1, 1)).contiguous().realize() for _ in range(cnt-1)])
     s = alu.schedule()
     assert len(s) == allowed
@@ -1435,6 +1445,7 @@ class TestSchedule(unittest.TestCase):
   def test_late_fusion_post_expand(self):
     self._test_fusion([(32, 32)], lambda a:a-a.sum(1), 2)

+  @unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
   def test_cast_padded_view(self):
     a = Tensor.arange(4).reshape(1, 4)
     casted_view = a.pad(((0, 1), (0, 0))).cast(dtypes.float)
@@ -1445,6 +1456,7 @@ class TestSchedule(unittest.TestCase):
     self.assertListEqual(realized_view.tolist(), [[0.0, 1.0, 2.0, 3.0], [0.0, 0.0, 0.0, 0.0]])

   # NOTE: we might want to reconsider pushing this cast before the shrink
+  @unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
   def test_cast_after_shrink(self):
     a = Tensor.arange(4).reshape(1, 4)
     casted_view = a.shrink(((0, 1), (0, 2))).cast(dtypes.float)
@@ -1454,6 +1466,7 @@ class TestSchedule(unittest.TestCase):
     self.assertEqual(realized_view.lazydata.base.realized.size, 2)
     self.assertListEqual(realized_view.tolist(), [[0, 1]])

+  @unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
   def test_cast_const_view(self):
     a = Tensor.ones((4, 4), dtype=dtypes.float32)
     casted_view = a.cast(dtypes.int32)
@@ -1463,6 +1476,7 @@ class TestSchedule(unittest.TestCase):
     run_schedule(check_schedule(realized_const_view, 1))
     self.assertListEqual(realized_const_view.tolist(), [[1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1], [1, 1, 1, 1]])

+  @unittest.skip("CAST_BEFORE_VIEW=1 is not supported")
   def test_cast_padded_const(self):
     a = Tensor(1, dtype=dtypes.int32).reshape(1, 1).pad(((1, 1), None))
     casted_view = a.cast(dtypes.float32)
@@ -1565,7 +1579,7 @@ class TestIndexing(unittest.TestCase):
     x = Tensor.randn(5, 2).realize()
     a = Tensor.arange(10)
     out = (x + a[2]).sum()
-    self.check_schedule(out, 1)
+    self.check_schedule(out, 2)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_contiguous(self):
@@ -1573,7 +1587,7 @@ class TestIndexing(unittest.TestCase):
     x = Tensor.randn(5, 2).realize()
     a = Tensor.arange(10).contiguous()
     out = (x + a[2]).sum()
-    self.check_schedule(out, 2)
+    self.check_schedule(out, 3)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+np.arange(10)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_child(self):
@@ -1581,7 +1595,7 @@ class TestIndexing(unittest.TestCase):
     x = Tensor.randn(5, 2).realize()
     a = Tensor.arange(10)+1
     out = (x + a[2]).sum()
-    self.check_schedule(out, 1)
+    self.check_schedule(out, 2)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_index_contiguous_child(self):
@@ -1589,7 +1603,7 @@ class TestIndexing(unittest.TestCase):
     x = Tensor.randn(5, 2).realize()
     a = (Tensor.arange(10)+1).contiguous()
     out = (x + a[2]).sum()
-    self.check_schedule(out, 2)
+    self.check_schedule(out, 3)
     np.testing.assert_allclose(out.numpy(), (x.numpy()+(np.arange(10)+1)[2]).sum(), atol=1e-5, rtol=1e-6)

   def test_arange_childless_base(self):
@@ -464,7 +464,7 @@ class TestTinygrad(unittest.TestCase):
   def test_repr_with_grad(self):
     a = Tensor([1], requires_grad=True)
     b = Tensor([1])
-    c = (a + b).mean().backward()
+    c = (a + b).sum().backward()
     print(a)
     print(c)
@@ -646,11 +646,20 @@ class TestZeroShapeTensor(unittest.TestCase):

   def test_clone(self):
     a = Tensor.rand(16, 16).realize()
     self.assertIsNot(a.lazydata, a.clone().lazydata)
     np.testing.assert_allclose(a.numpy(), a.clone().numpy())

+    a = Tensor.rand(16, 16).mul(5.0).add(5.0)
+    self.assertIsNot(a.lazydata, a.clone().lazydata)
+    np.testing.assert_allclose(a.numpy(), a.clone().numpy())
+
+  def test_clone_with_shrink(self):
+    a = Tensor.empty(16, 16)
+    self.assertIsNot(a.lazydata, a.clone().lazydata)
+
+    b = a.shrink(((2, 10), None))
+    self.assertIsNot(b.lazydata, b.clone().lazydata)
+
   def test_clone_with_grad(self):
     a = Tensor.rand(16, 16, requires_grad=True)
     a.mul(5.0).add(5.0).mean().backward()
@@ -3,9 +3,9 @@ import numpy as np
 import unittest
 from tinygrad import Tensor, Device, dtypes
 from tinygrad.engine.realize import run_schedule
-from tinygrad.ops import Ops, UOp
+from tinygrad.ops import Ops, UOp, UPat

-class TestLazyBuffer(unittest.TestCase):
+class TestTensorUOp(unittest.TestCase):
   def test_fromcpu_shape_tracker(self):
     def helper(a: np.ndarray):
       print(a.shape, a.strides, a.flags.c_contiguous)
@@ -68,7 +68,7 @@ class TestLazyBuffer(unittest.TestCase):
     assert lb.const_like(1).const_arg == 1.0
     assert type(lb.const_like(1).const_arg) is float

-  def test_forced_realized_alu(self):
+  def test_contiguous_alu(self):
     a = Tensor.randn(2, 2).realize()
     b = Tensor.randn(2, 2).realize()
     add = (a+b).contiguous()
@@ -84,13 +84,14 @@ class TestLazyBuffer(unittest.TestCase):
     sched = empty.schedule()
     self.assertEqual(len(sched), 0)

+reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.REDUCE_AXIS)))))
 class TestReduceOp(unittest.TestCase):
   def test_no_split_reduce_kernel(self):
     a = Tensor.rand(4, 4).realize()
     a = a.sum()
     sched = a.schedule()
     assert len(sched) == 1
-    self.assertIs(sched[0].ast.src[0].src[2].op, Ops.REDUCE_AXIS)
+    assert reduce_kernel.match(sched[0].ast, {})

   def test_split_reduce_kernel_dim0(self):
     a = Tensor.rand(256, 255).realize()
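Note: the module-level reduce_kernel pattern replaces index-based AST assertions with a structural match. Reading the old assertion (ast.src[0].src[2].op), a kernel AST appears to be a SINK over a STORE whose third source is the stored value; the pattern requires that value to be a REDUCE_AXIS. Restated from the diff (the STORE src layout of (buffer, view, value) is inferred from the old index, not new API):

    # SINK(STORE(anything, anything, REDUCE_AXIS)) - no hard-coded src indices
    reduce_kernel = UPat(Ops.SINK, src=(UPat(Ops.STORE, src=(UPat(), UPat(), UPat(Ops.REDUCE_AXIS)))))
    assert reduce_kernel.match(sched[0].ast, {})   # truthy iff the tree has this shape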
@@ -98,7 +99,7 @@ class TestReduceOp(unittest.TestCase):
     sched = a.schedule()
     assert len(sched) == 2
     for s in sched:
-      self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
+      assert reduce_kernel.match(s.ast, {})

   def test_split_reduce_kernel_dim1(self):
     a = Tensor.rand(255, 256).realize()
@@ -106,7 +107,7 @@ class TestReduceOp(unittest.TestCase):
     sched = a.schedule()
     assert len(sched) == 2
     for s in sched:
-      self.assertIs(s.ast.src[0].src[2].op, Ops.REDUCE_AXIS)
+      assert reduce_kernel.match(s.ast, {})

 if __name__ == "__main__":
   unittest.main()
@@ -81,6 +81,8 @@ class TestTiny(unittest.TestCase):

   # *** a model ***

+  # TODO: this is failing because of how swizzling rewrites the ShapeTracker of the final STORE
+  @unittest.skipIf(IMAGE>0, "failing because of make things that can't be images not images")
   def test_mnist_model(self):
     layers = [
       nn.Conv2d(1, 32, 5), Tensor.relu,
@@ -93,6 +93,12 @@ class TestTensorGradient(unittest.TestCase):
     dx = z.gradient(x, gradient=dz)[0]
     self.assertListEqual(dx.tolist(), [2.0, 4.0, 6.0])

+  def test_cast_before_view(self):
+    x = Tensor([1.0, 1, 1, 1])
+    x_reshaped = x.reshape(2,2)
+    x_casted = x_reshaped.cast(dtypes.float16)
+    x_casted.mean().gradient(x_reshaped)
+
 class TestRealizeMeansRealize(unittest.TestCase):
   def test_randn_realizes(self):
     x = Tensor.randn(2, 3, 64, 64, requires_grad=True).realize()
@@ -104,5 +110,11 @@ class TestRealizeMeansRealize(unittest.TestCase):
     print(x.lazydata)
     self.assertEqual(x.lazydata.op, Ops.VIEW)

+  # NOTE: even though it doesn't realize, this seems fine
+  def test_uniform_gradient(self):
+    x = Tensor.uniform(16, 3, 3, 3, requires_grad=True).realize()
+    y = x * 2
+    y.sum().gradient(x)[0].realize()
+
 if __name__ == '__main__':
   unittest.main()
@@ -1,39 +1,6 @@
 import unittest
-from tinygrad import dtypes, Tensor
+from tinygrad import dtypes
 from tinygrad.ops import UOp, symbolic, graph_rewrite_map, _substitute
-from test.unit.test_tensor_uop_representation import is_pattern, realized_pattern, is_pattern_uop
-
-class TestTensorMutates(unittest.TestCase):
-  def test_mutate_add(self):
-    a = Tensor([1,2,3])
-    b = Tensor([4,5,6])
-    ret = a+b
-    pa = a.lazydata
-    pb = b.lazydata
-    pr = ret.lazydata
-    ret.schedule()
-    self.assertIsNot(pa, a.lazydata)
-    self.assertIsNot(pb, b.lazydata)
-    self.assertIsNot(pr, ret.lazydata)
-    for t in [a,b,ret]: is_pattern(t, realized_pattern)
-
-  def test_reshape_is_same_parent(self):
-    a = Tensor([1,2,3])
-    b = Tensor([4,5,6])
-    c = a+b
-    d = (a+b).reshape(3,1)
-    d.realize()
-    is_pattern_uop(d.lazydata.base, realized_pattern)
-    is_pattern_uop(c.lazydata.base, realized_pattern)
-
-  def test_reshape_is_same_child(self):
-    a = Tensor([1,2,3])
-    b = Tensor([4,5,6])
-    c = a+b
-    d = (a+b).reshape(3,1)
-    c.realize()
-    is_pattern_uop(c.lazydata.base, realized_pattern)
-    is_pattern_uop(d.lazydata.base, realized_pattern)

 class TestRewriteMap(unittest.TestCase):
   def test_substitute(self):
@@ -19,7 +19,10 @@ class TestTensorMutates(unittest.TestCase):
     self.assertIsNot(pa, a.lazydata)
     self.assertIsNot(pb, b.lazydata)
     self.assertIsNot(pr, ret.lazydata)
-    for t in [a,b,ret]: is_pattern(t, realized_pattern)
+    # NOTE: this becomes a VIEW(VIEW(BUFFER)) because UOp.view no longer instantly folds contiguous VIEW of the same shape
+    # this is fine because realized exists on the base.
+    # TODO: we can make this always be a VIEW(BUFFER) once BUFFER has a ShapeTracker of shape=(N,)
+    for t in [a,b,ret]: is_pattern_uop(t.lazydata.base, realized_pattern)

   def test_reshape_is_same_parent(self):
     a = Tensor([1,2,3])
@@ -43,14 +46,14 @@ class TestTensorUopRepresentation(unittest.TestCase):
   def test_realized(self):
     a = Tensor([1.,2,3]).realize()
     print(a.lazydata)
-    is_pattern(a, realized_pattern)
+    is_pattern_uop(a.lazydata.base, realized_pattern)

   def test_add_realized(self):
     a = Tensor([1.,2,3]).realize()
     b = Tensor([4.,5,6]).realize()
     c = a+b
     print(c.lazydata)
-    is_pattern(c, UPat(Ops.ADD, src=(realized_pattern, realized_pattern)))
+    is_pattern(c, UPat(Ops.ADD, src=(UPat(Ops.VIEW, src=(realized_pattern,)), UPat(Ops.VIEW, src=(realized_pattern,)))))

   def test_const_pattern(self):
     a = Tensor(1)
@@ -107,7 +110,8 @@ class TestTensorUopRepresentation(unittest.TestCase):
     a = Tensor([1.,2,3]).realize()
     c = a.to("TEST") # NOTE: this isn't checked
     print(c.lazydata)
-    is_pattern(c, UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)))
+    # TODO: COPY on a Tensor becomes a VIEW(COPY), this should be done in the scheduler not in ops
+    is_pattern(c, UPat(Ops.VIEW, src=(UPat(Ops.COPY, src=(UPat(Ops.DEVICE), realized_pattern,)),)))

 if __name__ == '__main__':
   unittest.main()
@@ -13,8 +13,6 @@ from tinygrad.device import Buffer
|
||||
# creation can recurse a lot
|
||||
sys.setrecursionlimit(10000)
|
||||
|
||||
BUF_LIMIT = {"METAL":32}
|
||||
|
||||
# **** big graph spec
|
||||
|
||||
tensor_uop_spec = PatternMatcher([
|
||||
@@ -185,14 +183,9 @@ view_right = merge_views+PatternMatcher([

 @dataclass(frozen=True)
 class ScheduleItemContext:
-  ops_metadata: dict[UOp, Metadata]
-  assigns: set[UOp]
   var_vals: dict[Variable, int]
-  sinked: dict[UOp, UOp]
   sts: set[ShapeTracker] = field(default_factory=set)
   bufs: list[UOp] = field(default_factory=list)
-  metadata: set[Metadata] = field(default_factory=set)
-  assign_adj: dict[UOp, list[UOp]] = field(default_factory=dict)

 def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
   if (st:=unwrap(x.st)) in ctx.sts: return None
@@ -204,53 +197,49 @@ def _append_st_vars(ctx:ScheduleItemContext, x:UOp) -> UOp|None:
 def _append_buf(ctx:ScheduleItemContext, x:UOp) -> UOp:
   ctx.bufs.append(x)
   return UOp(Ops.DEFINE_GLOBAL, x.dtype.ptr(size=x.size), (), len(ctx.bufs)-1)
-append_bufs = PatternMatcher([(UPat(Ops.BUFFER, name="x"), _append_buf)])
-
-def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp:
-  (adj_loads:=ctx.assign_adj.setdefault(b, [])).append(x)
-  if not all_same([x.op for x in adj_loads]): raise RuntimeError(f"Detected cycle when fusing {adj_loads}. Can only fuse PRELOAD or LOAD of {b}")
-  return x.replace(op=Ops.LOAD)
-check_preload = PatternMatcher([(UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload),])

 to_si = PatternMatcher([
+  # BUFFER -> DEFINE_GLOBAL
+  (UPat(Ops.BUFFER, name="x"), _append_buf),
+  # simplify and unbind the final VIEWs
   (UPat(Ops.VIEW, name="x"), _append_st_vars),
+  # don't need SINK on COPY or BUFFER_VIEW
   (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda b,x: x.replace(src=(b, *x.src))),
+  # don't need contiguous or assign anymore
   (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda x: x),
   (UPat(Ops.ASSIGN, src=(UPat(), UPat.var("x"),)), lambda x: x),
+  # PRELOAD becomes LOAD
+  (UPat(Ops.PRELOAD, name="root"), lambda root:root.replace(op=Ops.LOAD)),
 ])

-add_metadata = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: None if (m:=ctx.ops_metadata.get(x)) is None else ctx.metadata.add(m)),])
-add_assign_adjacents = PatternMatcher([(UPat.load(UPat.var("b"), UPat(), name="x"), lambda ctx,b,x: ctx.assign_adj.setdefault(b, []).append(x)
-  if b in ctx.assigns else None)])
-
-# late folding for multi output kernels
-multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.sinked.get(b)),])
+# LOAD(BUFFER) -> the STORE value if it's we're doing the STORE in the same kernel
+multioutput = PatternMatcher([(UPat.load(UPat.var("b"), UPat()), lambda ctx,b: ctx.get(b)),])

 def schedule_uop(pre:UOp, ctx:ScheduleContext) -> ScheduleItem:
-  # create the ast context
-  si_ctx = ScheduleItemContext(ctx.ops_metadata, ctx.assigns, ctx.var_vals, {x.buf_uop:x.src[2] for x in pre.src})
-  create_ctx = add_metadata if len(si_ctx.assigns) == 0 else add_metadata+add_assign_adjacents
-  sink = graph_rewrite(pre, create_ctx if len(si_ctx.sinked) == 1 else multioutput+create_ctx, si_ctx)
-  # do movement ops
-  sink = graph_rewrite(graph_rewrite(sink, view_left), view_right)
-  # convert to AST
-  sink = graph_rewrite(graph_rewrite(sink, to_si+check_preload if len(si_ctx.assigns) != 0 else to_si, si_ctx), append_bufs, si_ctx)
-  # assert buffer count limit
-  if (limit:=BUF_LIMIT.get(device:=si_ctx.bufs[0].device)) is not None and len(si_ctx.bufs) >= limit:
-    if DEBUG >= 3: print(sink)
-    raise RuntimeError(f"Kernel for {si_ctx.metadata} exceeded the {limit} buffer count limit for {device} with {len(si_ctx.bufs)} buffers.")
-  # we also allow masked views. if it has a single view and it's equal when you shrink a contig, it's fine
-  for ubuf,ops in si_ctx.assign_adj.items():
-    if si_ctx.sinked.get(ubuf) is not None and not all((s:=x.st_arg).contiguous or (len(s.views) == 1 and (m:=s.views[0].mask) is not None \
-        and ShapeTracker.from_shape(s.shape).shrink(m) == s.shrink(m)) for x in ops):
-      raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
-                         +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
+  # remove movement ops + substitute LOAD of fused STORE with just the value
+  sink = graph_rewrite(graph_rewrite(pre, multioutput+view_left, store_bufs:={x.buf_uop:x.src[2] for x in pre.src}), view_right)
+  # remove extra uops from SINK + substitue BUFFER with DEFINE_GLOBAL
+  ast = graph_rewrite(sink, to_si, si_ctx:=ScheduleItemContext(ctx.var_vals))
+  # deal with ASSIGN
+  assign_preloads: list[UOp] = []
+  if len(ctx.assigns) != 0:
+    for x in list(sink.toposort)[::-1]:
+      # we only allow a kernel to depend on either the before ASSIGN or after ASSIGN version of a BUFFER
+      if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
+      # PRELOAD tells the toposort this kernel should run before ASSIGN
+      if x.op is Ops.PRELOAD:
+        assign_preloads.append(x.buf_uop)
+        # if this kernel also assigns to the buffer, we only allow either contiguous or masked views for the LOAD
+        if x.buf_uop in store_bufs and not (st:=x.st_arg).contiguous:
+          # if it has a single view and it's equal when you shrink a contig, it's fine
+          if len(st.views) != 1 or (mask:=st.views[0].mask) is None or ShapeTracker.from_shape(st.shape).shrink(mask) != st.shrink(mask):
+            raise RuntimeError("self operand of augmented assign must be contiguous.\nhelp: consider using .contiguous():\n"
+                               +colored("   - a += a.T\n", "red")+colored("   + a += a.T.contiguous()", "green"))
   # capture process replay
   if CAPTURE_PROCESS_REPLAY:
-    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, si_ctx.assigns, ContextVar._cache, sink))
-  return ScheduleItem(sink, tuple(u.buffer for u in si_ctx.bufs if u.size != 0), tuple(si_ctx.metadata),
-                      tuple(ubuf for ubuf,ops in si_ctx.assign_adj.items() if any(x.op is Ops.PRELOAD for x in ops)))
+    with Context(PICKLE_BUFFERS=0): PROCESS_REPLAY_CAPTURE[str(pre.key)] = pickle.dumps((pre, ContextVar._cache, ast))
+  return ScheduleItem(ast, tuple(u.buffer for u in si_ctx.bufs if u.size != 0),
+                      tuple(dedup(m for x in pre.toposort if (m:=ctx.ops_metadata.get(x)) is not None)), tuple(dedup(assign_preloads)))

 PROCESS_REPLAY_CAPTURE: dict[str, bytes] = {}
 if CAPTURE_PROCESS_REPLAY:
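Note: the rewritten schedule_uop drops the check_preload PatternMatcher and the assign_adj bookkeeping in favor of a single reverse-toposort pass. The invariant that pass enforces, restated as a sketch under the assumption that sink.toposort yields uops in dependency order:

    assign_preloads: list[UOp] = []
    for x in list(sink.toposort)[::-1]:
        # a kernel may read a buffer before its ASSIGN (PRELOAD) or after it (LOAD),
        # never both: a LOAD of an already-PRELOADed buffer would force the kernel to
        # run both before and after the ASSIGN, i.e. a cycle
        if x.op is Ops.LOAD and x.buf_uop in assign_preloads: raise RuntimeError("cycle detected in graph")
        if x.op is Ops.PRELOAD: assign_preloads.append(x.buf_uop)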
@@ -511,7 +500,7 @@ break_sched = PatternMatcher([
 def append_uop(ctx:ScheduleContext, view:UOp, buf_uop:UOp) -> None:
   ctx.allbufs[buf_uop] = view
   if (op:=uval(view)).op is Ops.ASSIGN: ctx.assigns.add(buf_uop)
-  for x in op.src:
+  for x in op.base.src:
     if is_scheduled(x.base): ctx.children.setdefault(x.base.buf_uop, {})[buf_uop] = None
   # BUFFER_VIEW overrides the underlying buffer
   # TODO: this should be a shrink on the buffer
@@ -551,12 +540,13 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
   # preschedule realize groups
   prescheduled: list[ScheduleItem] = []
   for store_uops in store_groups:
-    if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) != 0:
-      prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
-      # can only schedule once
-      for buf_uop in store_uops:
-        for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))
-  # do BFS
+    if len(stores:=[ctx.realizes[u] for u in store_uops if ctx.realizes[u].op is Ops.STORE]) == 0: continue
+    prescheduled.append(schedule_uop(UOp.sink(*stores), ctx))
+    # can only schedule once
+    for buf_uop in store_uops:
+      for luop in ctx.tensor_uops[buf_uop]: ctx.becomes_map[luop] = buf_uop.view(unwrap(luop.st))

   # add kernel children
   schedule_targets = {out:si for si in prescheduled for out in si.outputs}
   graph: defaultdict[ScheduleItem, list[ScheduleItem]] = defaultdict(list)
   in_degree: defaultdict[ScheduleItem, int] = defaultdict(int)
@@ -571,6 +561,8 @@ def create_schedule_with_vars(outs:list[UOp], skip_check:bool=not __debug__) ->
     for x in scheduled_parents:
       graph[x].append(si)
       in_degree[si] += 1
+
+  # do BFS
   queue = deque(si for si in prescheduled if in_degree[si] == 0)
   schedule: list[ScheduleItem] = []
   while queue:
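Note: the queue/in_degree code around this hunk is Kahn's algorithm: a kernel becomes schedulable once every parent kernel has been emitted. A generic self-contained sketch of the same idea (names are illustrative, not tinygrad API):

    from collections import deque

    def kahn_toposort(nodes, graph, in_degree):
        # graph: node -> list of children; in_degree: node -> count of unscheduled parents
        queue = deque(n for n in nodes if in_degree[n] == 0)
        order = []
        while queue:
            n = queue.popleft()
            order.append(n)
            for child in graph[n]:
                in_degree[child] -= 1
                if in_degree[child] == 0: queue.append(child)
        return order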
@@ -1,4 +1,4 @@
-from typing import cast
+from typing import cast, Iterator
 import math, functools
 from tinygrad.dtype import dtypes, sum_acc_dtype
 from tinygrad.ops import UOp, PatternMatcher, UPat, Ops
@@ -39,13 +39,15 @@ pm_gradient = PatternMatcher([

   # there's no gradient for...is this ASSIGN?
   (UPat(Ops.VIEW, src=(UPat(Ops.BUFFER), UPat(Ops.BUFFER_VIEW))), lambda: (None, None)),
+  # also no gradient for bitcast
+  (UPat(Ops.BITCAST), lambda ctx: (None,)),
 ])

 # copied from tensor.py, get relevant toposort of gradients
-def _deepwalk(root:UOp, targets:list[UOp]):
+def _deepwalk(root:UOp, targets:set[UOp]) -> list[UOp]:
   @functools.lru_cache(None)
   def is_in_target_path(x:UOp) -> bool: return any(u in targets or is_in_target_path(u) for u in x.src)
-  def _walk(node:UOp, visited:set[UOp]):
+  def _walk(node:UOp, visited:set[UOp]) -> Iterator[UOp]:
     visited.add(node)
     if node.op is Ops.DETACH: return
     if is_in_target_path(node):
@@ -54,7 +56,7 @@ def _deepwalk(root:UOp, targets:list[UOp]):
       yield node
   return list(_walk(root, set()))

-def compute_gradient(root:UOp, root_grad:UOp, targets:list[UOp]) -> dict[UOp, UOp]:
+def compute_gradient(root:UOp, root_grad:UOp, targets:set[UOp]) -> dict[UOp, UOp]:
   grads = {root: root_grad}
   for t0 in reversed(_deepwalk(root, targets)):
     if t0 not in grads: continue
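Note: compute_gradient is textbook reverse-mode autodiff over the _deepwalk toposort; the signature change to set[UOp] only affects membership tests in is_in_target_path. As a sketch of the loop's accumulation semantics (per_op_gradient is a hypothetical stand-in for the pm_gradient rules):

    grads = {root: root_grad}
    for t0 in reversed(_deepwalk(root, targets)):    # visit each node after all its consumers
        if t0 not in grads: continue                 # no gradient path from root to this node
        for src, g in zip(t0.src, per_op_gradient(t0, grads[t0])):
            if g is None: continue                   # e.g. the new BITCAST rule returns (None,)
            grads[src] = g if src not in grads else grads[src] + g  # fan-out gradients sum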
@@ -14,11 +14,10 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
   if DEBUG >= 2: print(f"{'RING ALLREDUCE' if use_ring else 'NAIVE ALLREDUCE'} {n_lbs}x{numel} | {lbs[0].dtype}")
   if not use_ring: return [functools.reduce(lambda x,y: x.alu(bop, y), [x.copy_to_device(lb.device) for x in lbs]) for lb in lbs]

-  factor = next(f for f in [32, 16, 8, 4, 2, 1] if numel % f == 0)
+  factor = next((f for f in [32, 16, 8, 4, 2] if numel % f == 0), 1)
   base, left = (numel // factor) // n_lbs, (numel // factor) % n_lbs
   chunk_sizes = [(base + 1) * factor] * left + [base * factor] * (n_lbs - left)
-  acc = 0
-  chunks = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]
+  chunks = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))
   chunked = [[lb.reshape((numel,)).shrink(((s,e),)) for s,e in chunks] for lb in lbs]

   # scatter-reduce
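Note: the old and new chunk computations agree whenever every chunk size is positive (the old comprehension additionally skipped empty chunks), and the factor change is behavior-preserving since numel % 1 == 0 always made 1 the fallback. A quick check (not part of the diff):

    import itertools

    chunk_sizes = [6, 6, 5, 5]                      # e.g. numel=22 over 4 devices, factor=1

    acc = 0                                          # old: running sum via walrus
    old = [(acc, (acc := acc + i)) for i in chunk_sizes if i > 0]

    # new: adjacent pairs of the prefix sums [0, 6, 12, 17, 22]
    new = list(itertools.pairwise(itertools.accumulate(chunk_sizes, initial=0)))

    assert old == new == [(0, 6), (6, 12), (12, 17), (17, 22)]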
@@ -38,7 +37,7 @@ def all_reduce(bop: Ops, lbs: list[UOp]) -> list[UOp]:
   return [functools.reduce(operator.add, [c.pad(pad) for pad,c in zip(pads,lb_c)]).reshape(shape) for lb_c in chunked]

 def to_sharded(lbs:list[UOp], axis:int, bounds: tuple[tuple[int, int], ...]) -> list[UOp]:
-  if DEBUG >= 3 and lbs[0].shape[axis] % len(lbs) != 0: print(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
+  if lbs[0].shape[axis] % len(lbs) != 0: raise RuntimeError(f"multi axis uneven: {lbs[0].shape=} {axis=} {len(lbs)=}, bounds={bounds}")
   return [lb.shrink(tuple((0,s) if a != axis else bound for a,s in enumerate(lb.shape))) for i, (bound, lb) in enumerate(zip(bounds, lbs))]

 class MultiLazyBuffer(MathTrait):
|
||||
assert all(isinstance(x, UOp) for x in lbs) and len(lbs), "all lbs must be LazyBuffers, and we need at least one of them"
|
||||
assert all_same([x.dtype for x in lbs]), f"all multilazybuffer needs same dtype, getting {[x.dtype for x in lbs]}"
|
||||
self.lbs, self.axis, self.dtype, self.device, self.real = lbs, axis, lbs[0].dtype, tuple(x.device for x in lbs), real or [True]*len(lbs)
|
||||
if axis is not None:
|
||||
splits = list(itertools.accumulate([lb.shape[axis] for lb in lbs], initial=0))
|
||||
self.bounds = tuple(zip(splits, splits[1:]))
|
||||
|
||||
@property
|
||||
def shape(self): return tuple(sum(y.shape[a] for y in self.real_lbs) if a == self.axis else s for a,s in enumerate(self.real_lbs[0].shape))
|
||||
@@ -59,20 +55,16 @@ class MultiLazyBuffer(MathTrait):
   @property
   def real_lbs(self): return [lb for lb,r in zip(self.lbs, self.real) if r]

+  @property
+  def bounds(self):
+    if self.axis is None: raise RuntimeError("bounds is not defined when axis is None")
+    return tuple(itertools.pairwise(itertools.accumulate([lb.shape[self.axis] for lb in self.lbs], initial=0)))
+
   def __repr__(self): return f"<MLB {self.axis=} {self.real=} {chr(10)}{chr(10).join([f'{x.device} {x.st}' for x in self.lbs])}>"

   @staticmethod
   def from_sharded(lb:UOp, devices:tuple[str, ...], axis:int|None, bounds:tuple[tuple[int, int], ...]|None):
     assert (axis is None) == (bounds is None), "must specify bounds iff axis is specified"
     lbs = [lb] * len(devices)
     sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(to_sharded(lbs, axis, bounds) if axis is not None and bounds is not None else lbs, devices)]
     # NOTE: this contiguous is making it impossible for the scheduler to do late const folding
     return MultiLazyBuffer([lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)

   def copy_to_device(self, device:str) -> UOp:
-    if self.axis is None:
-      # if we already have a copy on the device, return that
-      return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
+    # if we already have a copy on the device, return that
+    if self.axis is None: return next((lb for lb in self.real_lbs if lb.device == device), self.real_lbs[0].copy_to_device(device))
     # copy lbs to device, pad to final shape, and sum
     llbs:list[UOp] = []
     for lb,real,(start,end) in zip(self.lbs, self.real, self.bounds):
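Note: moving bounds from an eagerly-stored __init__ attribute to a property means it is derived on demand from the per-device shapes, with the same prefix-sum pairing the deleted __init__ lines computed. A standalone check of the arithmetic (not the tinygrad API):

    import itertools
    axis_lens = [4, 4, 4, 4]                         # lb.shape[axis] for each device
    bounds = tuple(itertools.pairwise(itertools.accumulate(axis_lens, initial=0)))
    assert bounds == ((0, 4), (4, 8), (8, 12), (12, 16))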
@@ -84,13 +76,14 @@ class MultiLazyBuffer(MathTrait):
   # passthroughs
   @property
   def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs)
-  def cast(self, dtype:DType, bitcast:bool=False):
-    return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
+  def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis, self.real)
   def const_like(self, b) -> MultiLazyBuffer: return MultiLazyBuffer([x.const_like(b) for x in self.lbs], self.axis, self.real)
   def assign(self, x:MultiLazyBuffer): return MultiLazyBuffer([s.assign(d) for s,d in zip(self.lbs, x.lbs)], self.axis, self.real)
   def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis, self.real)
+  def clone(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.clone() for lb in self.lbs], self.axis, self.real)
+  def detach(self) -> MultiLazyBuffer: return MultiLazyBuffer([lb.detach() for lb in self.lbs], self.axis, self.real)
   @property
   def toposort(self) -> dict[UOp, None]: return {l:None for x in self.lbs for l in x.toposort}

   # elementwise is simple
   def alu(self, op:Ops, *in_srcs:MultiLazyBuffer) -> MultiLazyBuffer:
@@ -106,12 +99,10 @@ class MultiLazyBuffer(MathTrait):
|
||||
assert any(new_real), "output contains no real lb"
|
||||
for mlb in msrcs:
|
||||
if (mlb.axis == axis and (mlb.axis is None or mlb.bounds == bounds)) or not_all_real: srcs.append(mlb.lbs)
|
||||
elif mlb.axis is None and axis is not None:
|
||||
assert bounds is not None
|
||||
srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else:
|
||||
assert axis is not None and bounds is not None
|
||||
srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
if mlb.axis is None: srcs.append(to_sharded(mlb.lbs, axis, bounds))
|
||||
else: srcs.append(to_sharded([mlb.copy_to_device(lb.device) for lb in mlb.lbs], axis, bounds))
|
||||
new_real_lbs:dict[int,UOp] = {i:lsrcs[0].alu(op, *lsrcs[1:]) for i,(lsrcs,r) in enumerate(zip(zip(*srcs), new_real)) if r}
|
||||
# NOTE: const dtype should match real
|
||||
new_dtype = next(iter(new_real_lbs.values())).dtype
|
||||
|
||||
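`to_sharded` is referenced in both hunks but not shown in this diff; under the bounds convention above it plausibly shrinks each lazybuffer to its (start, end) slice along the sharded axis. A numpy stand-in (the helper's semantics here are an assumption, not the real UOp implementation):

```python
import numpy as np

def to_sharded(lbs, axis, bounds):  # hypothetical stand-in for the real helper
  # each device keeps only its (start, end) slice along the sharded axis
  return [np.take(lb, range(start, end), axis=axis) for lb, (start, end) in zip(lbs, bounds)]

x = np.arange(8).reshape(2, 4)
shards = to_sharded([x, x], axis=1, bounds=((0, 2), (2, 4)))
assert shards[0].tolist() == [[0, 1], [4, 5]] and shards[1].tolist() == [[2, 3], [6, 7]]
```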
@@ -361,17 +361,6 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  def cast(self, dtype:DType, bitcast=False):
    if bitcast: return self.bitcast(dtype)
    if self._device is not None and self._device.startswith("DISK"): raise RuntimeError("CAST isn't supported on DISK")
    if getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base:
      # NOTE: we have to apply the movementops here, we can't use VIEW (yet)
      # TODO: move this to the scheduler
      ret = self.base.cast(dtype, bitcast)
      op_arg = []
      mop = self
      while mop is not self.base:
        op_arg.append((mop.op, mop.arg))
        mop = mop.src[0]
      for op,arg in reversed(op_arg): ret = UOp(op, ret.dtype, (ret,), arg)
      return ret
    return UOp(Ops.CAST, dtype, (self,))
  def bitcast(self, dtype:DType):
    if self.st is not None and self.shape and ((self.shape[-1]*self.dtype.itemsize)%dtype.itemsize != 0):
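The deleted `CAST_BEFORE_VIEW` branch walked from a view back to its base, cast the base, then replayed the collected movement ops on top; after this change `Ops.CAST` is always emitted directly on the viewed UOp. A toy replay of the same walk (the `Node` class and op names are illustrative, not tinygrad API):

```python
from dataclasses import dataclass

@dataclass(frozen=True)
class Node:
  op: str
  src: tuple
  arg: object = None

base = Node("base", ())
chain = Node("reshape", (Node("pad", (base,), arg=(1, 1)),), arg=(2, 3))

# walk down to the base, collecting (op, arg) in view-to-base order
op_arg, mop = [], chain
while mop is not base:
  op_arg.append((mop.op, mop.arg))
  mop = mop.src[0]

# rebuild the same chain on top of a freshly "cast" base
new_base = Node("cast_base", ())
ret = new_base
for op, arg in reversed(op_arg): ret = Node(op, (ret,), arg)
assert ret.op == "reshape" and ret.src[0].op == "pad" and ret.src[0].src[0] is new_base
```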
@@ -440,7 +429,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  # *** from LazyBuffer ***

  @staticmethod
  def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None, src:tuple[UOp, ...]=()) -> UOp:
  def metaop(op:Ops, shape:tuple[sint, ...], dtype:DType, device:str, arg=None) -> UOp:
    from tinygrad.shape.shapetracker import ShapeTracker
    # Tensor const is CONST(VIEW(DEVICE)) -> RESHAPE -> EXPAND
    if op is Ops.CONST:
@@ -477,13 +466,8 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
  @property
  def base(self) -> UOp:
    if self.op in GroupOp.Movement: return self.src[0].base
    return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
  def view(self, new_st:ShapeTracker) -> UOp:
    if self.st is None: return UOp(Ops.VIEW, self.dtype.base if not isinstance(self.dtype, ImageDType) else self.dtype, (self,), new_st)
    ret = UOp(Ops.VIEW, self.dtype, (self.base,), new_st)
    # instant folding rules
    if new_st.contiguous and self.base.shape == new_st.shape: return self.base
    return ret
    return self.src[0].base if self.op is Ops.VIEW and len(self.src) == 1 and self.src[0].op is not Ops.BUFFER else self
  def view(self, new_st:ShapeTracker) -> UOp: return UOp(Ops.VIEW, self.dtype, (self.base,), new_st)

  def _mop(self, op:Ops, arg):
    ret = UOp(op, self.dtype, (self,), arg)

@@ -1304,7 +1288,10 @@ ConstLike = Union[ConstType, Variable, tuple[ConstType, ...]]

# *** uop swizzling ***

merge_views = PatternMatcher([(UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st))])
merge_views = PatternMatcher([
  (UPat(Ops.VIEW, name="s0").view(name="s1"), lambda s0,s1: s0.replace(arg=s0.st+s1.st)),
  (UPat(Ops.VIEW, name="mv", src=(UPat.var("x"),)), lambda mv,x: x if mv.st.contiguous and x.st is not None and x.shape == mv.shape else None),
])

# push VIEW to loads
view_left = merge_views+PatternMatcher([
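The second `merge_views` rule moves the old in-place folding out of `UOp.view` and into the rewrite pass: a contiguous VIEW that doesn't change its source's shape is dropped entirely. The predicate in isolation, as a sketch on plain shapes rather than real ShapeTrackers:

```python
def should_fold(view_contiguous: bool, src_shape, view_shape, src_has_st: bool) -> bool:
  # mirrors the lambda's condition for rewriting VIEW(x) -> x
  return view_contiguous and src_has_st and src_shape == view_shape

assert should_fold(True, (2, 4), (2, 4), True)
assert not should_fold(True, (2, 4), (4, 2), True)   # shape changed: keep the VIEW
assert not should_fold(False, (2, 4), (2, 4), True)  # permuted/padded: keep the VIEW
```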
@@ -291,17 +291,23 @@ class MetalRenderer(CStyleLanguage):
    return super().render_kernel(function_name, kernel, bufs, uops, prefix)

_nms = "xyzwabcdefghijkl"
cuda_tc_opts = ("u0","l0","l0","l1","l1","l1","u1") # shared by all shapes with M=16 N=8

class CUDARenderer(CStyleLanguage):
  device = "CUDA"
  global_max = (2147483647, 65535, 65535)
  local_max = (1024, 1024, 64)
  shared_max = 49152
  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-fragment-mma-16816-float
  tensor_cores = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di, dtype_out=do,
    opts=("u0","l0","l0","l1","l1","l1","u1"), swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8))))
    for di,do in ([(dtypes.half,dtypes.float),(dtypes.bfloat16,dtypes.float)])]
  def __init__(self, arch:str): self.tensor_cores, self.arch = CUDARenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
  # https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-multiply-accumulate-instructions
  tc_81616 = [TensorCore(dims=(8,16,16), threads=32, elements_per_thread=(8,4,4), dtype_in=di,dtype_out=do, opts=cuda_tc_opts,
    swizzle=(((6,7,2,3,4),(0,1,9,5,10,8)), ((6,7,9,0,1),(2,3,4,10,5,8)))) for di,do in [(dtypes.half,dtypes.float), (dtypes.bfloat16,dtypes.float)]]
  tc_8168_f16 = [TensorCore(dims=(8,16,8), threads=32, elements_per_thread=(4,2,4), dtype_in=dtypes.half, dtype_out=dtypes.float, opts=cuda_tc_opts,
    swizzle=(((6,7,2,3,4),(0,1,8,5,9)), ((6,7,8,0,1),(2,3,4,9,5))))]

  tc_sm80 = tc_81616 + tc_8168_f16
  tc_sm75 = tc_8168_f16
  def __init__(self, arch:str):
    self.tensor_cores, self.arch = CUDARenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else [], arch
  def __reduce__(self): return self.__class__, (self.arch,)

  # language options
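A sketch of the arch gating in the new `__init__`: parse the compute capability out of an arch string like "sm_86" and pick the tensor-core set (set names follow the class attributes above):

```python
def pick_tensor_cores(arch: str) -> str:
  cc = int(arch[3:])  # "sm_86"[3:] -> "86"
  if cc >= 80: return "tc_sm80"  # 16x8x16 half/bf16 plus 16x8x8 half
  if cc >= 75: return "tc_sm75"  # 16x8x8 half only
  return "none"

assert pick_tensor_cores("sm_86") == "tc_sm80"
assert pick_tensor_cores("sm_75") == "tc_sm75"
assert pick_tensor_cores("sm_70") == "none"
```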
@@ -124,11 +124,12 @@ class PTXRenderer(Renderer):
  device = "CUDA"
  suffix = "PTX"
  global_max, local_max, shared_max = CUDARenderer.global_max, CUDARenderer.local_max, CUDARenderer.shared_max
  tensor_cores = [tc for tc in CUDARenderer.tensor_cores if tc.dtype_in == dtypes.half]
  tc_sm80 = [tc for tc in CUDARenderer.tc_sm80 if tc.dtype_in == dtypes.half]
  code_for_op = asm_for_op
  extra_matcher = ptx_matcher
  def __init__(self, arch:str, device="CUDA"):
    self.device, self.tensor_cores, self.arch = device, PTXRenderer.tensor_cores if int(arch[3:]) >= 80 else [], arch
    self.device, self.arch = device, arch
    self.tensor_cores = PTXRenderer.tc_sm80 if int(arch[3:]) >= 80 else CUDARenderer.tc_sm75 if int(arch[3:]) >= 75 else []
  def __reduce__(self): return self.__class__, (self.arch, self.device)

  # language options
@@ -138,31 +138,38 @@ class PythonProgram:
        ul[i] = wmma_helper(32, 8, 2, 2, 2, a_b_elem, a_b_elem, c_map)
      elif arg[4] == "AMD":
        # A (16 elements on 32 threads): col major, lane 16-32 == lane 0-15
        def a_elem(x, i, j, goff):
          assert x[i][goff+j] == x[i][goff+j+16], "warp elements not duplicated properly across lanes"
          return x[i][goff+j]
        def a_elem(x, k, row, goff):
          assert x[k][goff+row] == x[k][goff+row+16], "warp elements not duplicated properly across lanes"
          return x[k][goff+row]
        # B (16 elements on 32 threads): row major, lane 16-32 == lane 0-15
        def b_elem(x, i, j, goff): return a_elem(x, j, i, goff)  # pylint: disable=arguments-out-of-order
        def b_elem(x, col, k, goff): return a_elem(x, k, col, goff)  # pylint: disable=arguments-out-of-order
        def c_map(lane, elem): return (lane%16, lane//16+elem*2) # (i, j), C, D (8 elements on 32 threads): row major
        ul[i] = wmma_helper(32, 16, 16, 16, 8, a_elem, b_elem, c_map)
      elif arg[4] == "CUDA":
        # A (8 elements on 32 threads)
        def a_elem(x, i, j, goff): return x[(i%2)+(j//8)*2+(i//8)*4][goff+((i//2)%4)+(j%8)*4]
        # B (4 elements on 32 threads)
        def b_elem(x, i, j, goff): return x[(j%2)+(j//8)*2][goff+(j//2)%4+(i)*4]
        # (i, j), C, D (4 elements on 32 threads)
        def c_map(lane, elem): return ((elem%2)+(lane%4)*2, (lane//4)+(elem//2)*8)
        ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)
        # (col, row) given (lane, elem) for C & D (4 elements on 32 threads); shared by all tc shapes with M=16 N=8
        def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)

        if arg[1] == (8,16,16):
          def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2 + (k//8)*4][goff + (k//2)%4 + (row%8)*4]
          def b_elem(x, col, k, goff): return x[k%2 + (k//8)*2][goff + (k//2)%4 + col*4]
          ul[i] = wmma_helper(32, 16, 8, 4, 4, a_elem, b_elem, c_map)

        elif arg[1] == (8,16,8):
          def a_elem(x, k, row, goff): return x[k%2 + (row//8)*2][goff + k//2 + (row%8)*4]
          def b_elem(x, col, k, goff): return x[k%2][goff + k//2 + col*4]
          ul[i] = wmma_helper(32, 8, 4, 2, 4, a_elem, b_elem, c_map)

        else: raise NotImplementedError(f"unimplemented tensor core {arg}")
      elif arg[4] == "INTEL":
        # A (16 elements on 8 threads)
        def a_elem(x, i, j, goff): return x[i%2+j*2][goff+i//2]
        def a_elem(x, k, row, goff): return x[k%2+row*2][goff+k//2]
        # B (16 elements on 8 threads)
        def b_elem(x, i, j, goff): return x[j][goff+i]
        def b_elem(x, col, k, goff): return x[k][goff+col]
        # C, D (8 elements on 8 threads)
        def c_map(lane, elem): return (lane, elem)
        ul[i] = wmma_helper(8, 16, 16, 16, 8, a_elem, b_elem, c_map)
      elif arg[4] == "CLANG":
        def elem(x, i, j, _): return x[i+j][0]
        def elem(x, col, row, _): return x[col+row][0]  # k is always 0
        def c_map(_, elem): return (elem%16, elem//16)
        ul[i] = wmma_helper(1, 1, 16, 16, 256, elem, elem, c_map)
      else: raise NotImplementedError(f"unimplemented tensor core {arg}")

@@ -179,7 +186,8 @@ class PythonRenderer(Renderer):
  def __init__(self):
    if getenv("EMULATE_METAL"): self.device, self.tensor_cores = "METAL", MetalRenderer.tensor_cores
    if getenv("EMULATE_AMD"): self.device, self.tensor_cores = "AMD", AMDRenderer.tensor_cores
    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tensor_cores
    if getenv("EMULATE_CUDA"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm80
    if getenv("EMULATE_CUDA_SM75"): self.device, self.tensor_cores = "CUDA", CUDARenderer.tc_sm75
    if getenv("EMULATE_INTEL"): self.device, self.suffix, self.tensor_cores = "INTEL", "INTEL", IntelRenderer.tensor_cores
    if getenv("EMULATE_AMX"): self.device, self.tensor_cores = "CLANG", ClangRenderer.tensor_cores
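The refactored CUDA branch shares one `c_map` across both mma shapes. A quick standalone check that it covers the full 8-wide by 16-tall C fragment exactly once over 32 lanes x 4 elements:

```python
def c_map(lane, elem): return (elem%2 + (lane%4)*2, lane//4 + (elem//2)*8)

coords = {c_map(lane, elem) for lane in range(32) for elem in range(4)}
# 32 lanes x 4 elements = 128 distinct (col, row) pairs, i.e. all of 8 x 16
assert len(coords) == 128 and coords == {(c, r) for c in range(8) for r in range(16)}
```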
@@ -14,6 +14,7 @@ class TLSFAllocator:
  def __init__(self, size:int, base:int=0, block_size:int=16, lv2_cnt:int=16):
    self.size, self.base, self.block_size, self.l2_cnt = size, base, block_size, lv2_cnt.bit_length()
    self.storage:list = [collections.defaultdict(list) for _ in range(size.bit_length() + 1)]
    self.lv1_entries:list[int] = [0] * len(self.storage)

    # self.blocks is more like a linked list, where each entry is a contiguous block.
    self.blocks:dict[int, tuple[int, int|None, int|None, bool]] = {0: (size, None, None, True)} # size, next, prev, is_free

@@ -25,12 +26,14 @@ class TLSFAllocator:
  def _insert_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].append(start)
    self.lv1_entries[self.lv1(size)] += 1
    self.blocks[start] = (size, start + size, prev, True)
    return self

  def _remove_block(self, start:int, size:int, prev:int|None=None):
    if prev is None: prev = self.blocks[start][2]
    self.storage[self.lv1(size)][self.lv2(size)].remove(start)
    self.lv1_entries[self.lv1(size)] -= 1
    self.blocks[start] = (size, start + size, prev, False)
    return self

@@ -67,6 +70,7 @@ class TLSFAllocator:

    # Search for the smallest block that can fit the requested size. Start with its bucket and go up until any block is found.
    for l1 in range(self.lv1(size), len(self.storage)):
      if self.lv1_entries[l1] == 0: continue
      for l2 in range(self.lv2(size) if l1 == size.bit_length() else 0, (1 << self.l2_cnt)):
        if len(self.storage[l1][l2]) > 0:
          nsize = self.blocks[self.storage[l1][l2][0]][0]
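The new `lv1_entries` counters let the search skip whole first-level classes that hold no free blocks. The two-level indexing itself isn't shown in this hunk; a sketch of the usual TLSF scheme (these `lv1`/`lv2` formulas are an assumption about the scheme, not necessarily tinygrad's exact ones):

```python
def lv1(size: int) -> int:
  return size.bit_length()  # power-of-two class of the size

def lv2(size: int, l2_cnt: int = 4) -> int:
  top = 1 << (size.bit_length() - 1)      # start of this power-of-two class
  return ((size - top) << l2_cnt) // top  # linear sub-bucket within the class

assert lv1(16) == 5 and lv1(31) == 5 and lv1(32) == 6
assert lv2(16) == 0 and lv2(24) == 8 and lv2(31) == 15
```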
@@ -1,6 +1,6 @@
from __future__ import annotations
import ctypes, collections, time, dataclasses, pathlib
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG
import ctypes, collections, time, dataclasses, pathlib, fcntl, os, signal
from tinygrad.helpers import to_mv, mv_address, getenv, round_up, DEBUG, temp
from tinygrad.runtime.autogen.am import am, mp_11_0, mp_13_0_0, nbio_4_3_0, mmhub_3_0_0, gc_11_0_0, osssys_6_0_0
from tinygrad.runtime.support.allocator import TLSFAllocator
from tinygrad.runtime.support.am.ip import AM_SOC21, AM_GMC, AM_IH, AM_PSP, AM_SMU, AM_GFX, AM_SDMA

@@ -255,6 +255,15 @@ class AMDev:
    self.pcidev, self.devfmt = pcidev, devfmt
    self.vram, self.doorbell64, self.mmio = vram_bar, doorbell_bar, mmio_bar

    os.umask(0) # Set umask to 0 to allow creating files with 0666 permissions

    # Avoid O_CREAT because we don't want to re-create/replace an existing file (triggers extra perms checks) when opening as non-owner.
    if os.path.exists(lock_name:=temp(f"am_{self.devfmt}.lock")): self.lock_fd = os.open(lock_name, os.O_RDWR)
    else: self.lock_fd = os.open(lock_name, os.O_RDWR | os.O_CREAT, 0o666)

    try: fcntl.flock(self.lock_fd, fcntl.LOCK_EX | fcntl.LOCK_NB)
    except OSError: raise RuntimeError(f"Failed to open AM device {self.devfmt}. It's already in use.")

    self._run_discovery()
    self._build_regs()
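A standalone sketch of the locking pattern above: a non-blocking exclusive `flock` on a well-known temp file makes a second opener of the same device fail fast instead of blocking. The path is an example, and the real code skips `O_CREAT` when the file already exists:

```python
import fcntl, os, tempfile

lock_path = os.path.join(tempfile.gettempdir(), "am_0000:01:00.0.lock")  # example name
fd = os.open(lock_path, os.O_RDWR | os.O_CREAT, 0o666)
try:
  fcntl.flock(fd, fcntl.LOCK_EX | fcntl.LOCK_NB)  # raises OSError if another process holds it
except OSError:
  raise RuntimeError("device already in use")
```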
@@ -284,14 +293,17 @@ class AMDev:
    self.sdma:AM_SDMA = AM_SDMA(self)

    if self.partial_boot and (self.reg("regCP_MEC_RS64_CNTL").read() & gc_11_0_0.CP_MEC_RS64_CNTL__MEC_HALT_MASK == 0):
      print(f"am {self.devfmt}: MEC is active. Someone might be using the GPU? Issue a full reset.")
      if DEBUG >= 2: print(f"am {self.devfmt}: MEC is active. Issue a full reset.")
      self.partial_boot = False

    if not self.partial_boot:
      if self.psp.is_sos_alive(): self.smu.mode1_reset()
      for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
        ip.init()
        if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
      try: # do not interrupt the boot process
        signal.signal(signal.SIGINT, signal.SIG_IGN)
        if self.psp.is_sos_alive(): self.smu.mode1_reset()
        for ip in [self.soc21, self.gmc, self.ih, self.psp, self.smu]:
          ip.init()
          if DEBUG >= 2: print(f"am {self.devfmt}: {ip.__class__.__name__} initialized")
      finally: signal.signal(signal.SIGINT, signal.default_int_handler)

    # Booting done
    self.is_booting = False
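A minimal sketch of the uninterruptible-section pattern used in the boot path: ignore SIGINT for the duration of a critical sequence, then restore the default handler even if the sequence raises:

```python
import signal, time

def critical_sequence():
  try:  # Ctrl-C during this block is dropped, not queued
    signal.signal(signal.SIGINT, signal.SIG_IGN)
    time.sleep(0.1)  # stand-in for the hardware init steps
  finally:
    signal.signal(signal.SIGINT, signal.default_int_handler)

critical_sequence()
```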
@@ -41,9 +41,9 @@ class Function:

import tinygrad.function as F

def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None, src:tuple[UOp, ...]=()):
  if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg, src)
  return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg, src) for d in device], None)
def _metaop(op, shape:tuple[sint,...], dtype:DType, device:Union[str, tuple[str, ...]], arg=None):
  if isinstance(device, str): return UOp.metaop(op, shape, dtype, device, arg)
  return MultiLazyBuffer([UOp.metaop(op, shape, dtype, d, arg) for d in device], None)

def _from_np_dtype(npdtype:'np.dtype') -> DType: # type: ignore [name-defined] # noqa: F821
  import numpy as np
@@ -179,7 +179,7 @@ class Tensor(SimpleMathTrait):
      # data might be on a different device
      if isinstance(device, str): self.lazydata:Union[UOp, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device)
      # if device is a tuple, we should have/construct a MultiLazyBuffer
      elif isinstance(data, UOp): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None)
      elif isinstance(data, UOp): self.lazydata = Tensor(data).shard(device).lazydata
      else:
        assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}"
        self.lazydata = data
@@ -394,33 +394,33 @@ class Tensor(SimpleMathTrait):
    if self.grad is not None and real.grad is not None: self.grad.replace(real.grad)
    return self.replace(real)

  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None) -> Tensor:
  def shard(self, devices:tuple[str, ...], axis:Optional[int]=None) -> Tensor:
    """
    Shards the tensor across the given devices. Optionally specify which axis to shard on, and how to split it across devices.
    Shards the tensor across the given devices. Optionally specify which axis to shard on.

    ```python exec="true" source="above" session="tensor" result="python"
    t = Tensor.empty(2, 3)
    print(t.shard((t.device, t.device), axis=1, splits=(2, 1)).lazydata)
    t = Tensor.empty(2, 4)
    print(t.shard((t.device, t.device), axis=1).lazydata)
    ```

    """
    assert isinstance(self.lazydata, UOp), "can't shard a MultiLazyBuffer"
    devices, bounds = tuple(Device.canonicalize(x) for x in devices), None
    if axis is not None:
    devices = tuple(Device.canonicalize(x) for x in devices)
    if axis is None: lbs = [self.lazydata] * len(devices)
    else:
      axis = self._resolve_dim(axis)
      if splits is None:
        if not isinstance(total:=self.shape[axis], int): raise RuntimeError(f"cannot shard symbolic shape {self.shape=}, {axis=}")
        sz = ceildiv(total, len(devices))
        splits = tuple([max(0, min(sz, total - sz*i)) for i in range(len(devices))])
      assert sum(splits) == self.shape[axis], "specified splits do not sum up to axis shape"
      bounds = tuple(itertools.pairwise(itertools.accumulate(splits, initial=0)))
    return Tensor(MultiLazyBuffer.from_sharded(self.lazydata, devices, axis, bounds), device=devices, requires_grad=self.requires_grad)
      sz = ceildiv(self.shape[axis], len(devices))
      sizes = [max(0, min(sz, self.shape[axis] - sz*i)) for i in range(len(devices))]
      lbs = [cast(UOp, t.lazydata) for t in self.split(sizes, axis)]
    sharded_lbs = [lb.copy_to_device(d) for lb,d in zip(lbs, devices)]
    # NOTE: this contiguous is making it impossible for the scheduler to do late const folding
    mlb = MultiLazyBuffer([lb.contiguous(allow_buffer_view=False) for lb in sharded_lbs], axis)
    return Tensor(mlb, device=devices, requires_grad=self.requires_grad)

  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None, splits:Optional[tuple[int, ...]]=None):
  def shard_(self, devices:tuple[str, ...], axis:Optional[int]=None):
    """
    Shards the tensor across the given devices in place.
    """
    return self.replace(self.shard(devices, axis, splits))
    return self.replace(self.shard(devices, axis))

  @staticmethod
  def from_uop(y:UOp, **kwargs) -> Tensor:
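The even-split arithmetic the rewritten `shard()` feeds into `self.split` is worth checking at the edges, since the last device can get a zero-sized shard:

```python
def split_sizes(total: int, n_dev: int) -> list[int]:
  sz = -(-total // n_dev)  # ceildiv
  return [max(0, min(sz, total - sz*i)) for i in range(n_dev)]

assert split_sizes(8, 2) == [4, 4]
assert split_sizes(7, 3) == [3, 3, 1]
assert split_sizes(4, 3) == [2, 2, 0]  # last device holds an empty shard
```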
@@ -915,23 +915,29 @@ class Tensor(SimpleMathTrait):
    print(dy.tolist()) # dz/dy
    ```
    """
    assert isinstance(self.lazydata, UOp), "multi isn't supported yet"
    target_uops: list[UOp] = [x.lazydata for x in targets if isinstance(x.lazydata, UOp)]
    assert gradient is not None or self.shape == tuple(), "when no gradient is provided, backward must be called on a scalar tensor"
    grads = compute_gradient(self.lazydata, self.lazydata.const_like(1) if gradient is None else cast(UOp, gradient.lazydata), target_uops)
    ret = []
    for x in target_uops:
      if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{self.lazydata}")
      ret.append(Tensor(y, device=x.device))
    return ret
    if gradient is None: gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)
    rets = []
    for i,(uop,grad) in enumerate(zip(self.lazydata.lbs, gradient.lazydata.lbs)):
      target_uops = [x.lazydata.lbs[i] for x in targets]
      grads = compute_gradient(uop, grad, set(target_uops))
      ret = []
      for x in target_uops:
        if (y:=grads.get(x)) is None: raise RuntimeError(f"{x}\n\nnot found in\n\n{uop}")
        ret.append(y)
      rets.append(ret)
    # create returned Tensors
    if isinstance(self.lazydata, UOp): return [Tensor(u, device=t.device) for t,u in zip(targets, rets[0])]
    return [Tensor(MultiLazyBuffer(list(u), cast(MultiLazyBuffer, t.lazydata).axis, cast(MultiLazyBuffer, t.lazydata).real),
                   device=t.device) for t,u in zip(targets, zip(*rets))]
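The rewritten gradient path runs the single-device `compute_gradient` once per shard and then regroups the per-shard results into one value per target, which is what the final `zip(*rets)` does. A shape-only sketch with strings standing in for lazybuffers (`compute_grads_one_device` is a placeholder, not a real API):

```python
def compute_grads_one_device(loss_shard, grad_shard, target_shards):
  return [f"d({t})" for t in target_shards]  # placeholder per-target grads

loss_lbs, grad_lbs = ["l0", "l1"], ["g0", "g1"]  # one entry per device
targets_lbs = [["a0", "a1"], ["b0", "b1"]]       # two targets, each sharded

rets = [compute_grads_one_device(l, g, [t[i] for t in targets_lbs])
        for i, (l, g) in enumerate(zip(loss_lbs, grad_lbs))]
per_target = list(zip(*rets))  # regroup: one tuple of shards per target
assert per_target == [("d(a0)", "d(a1)"), ("d(b0)", "d(b1)")]
```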
  def _deepwalk(self):
    def _walk(node, visited):
  def _deepwalk(self) -> list[Tensor]:
    def _walk(node:Tensor, visited:set[Tensor]):
      visited.add(node)
      # if tensor is not leaf, reset grad
      if (ctx := getattr(node, "_ctx", None)) is not None and len(ctx.parents) != 0: node.grad = None
      if ctx:
        for i in node._ctx.parents:
        for i in cast(Function, node._ctx).parents:
          if i not in visited: yield from _walk(i, visited)
      yield node
    return list(_walk(self, set()))
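`_walk` is a post-order DFS: a node is yielded only after everything it depends on, so the list is a topological order and `backward` iterates it reversed. The same generator on a plain dict DAG:

```python
def deepwalk(node, parents, visited=None):
  if visited is None: visited = set()
  visited.add(node)
  for p in parents.get(node, ()):
    if p not in visited: yield from deepwalk(p, parents, visited)
  yield node  # post-order: all parents already yielded

dag = {"loss": ["w", "x"], "w": [], "x": ["w"]}
order = list(deepwalk("loss", dag))
assert order.index("w") < order.index("x") < order.index("loss")
```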
@@ -954,18 +960,22 @@ class Tensor(SimpleMathTrait):
      # this is "implicit gradient creation"
      gradient = Tensor(1.0, dtype=self.dtype, device=self.device, requires_grad=False)

    toposort_uop = self.lazydata.toposort
    assert self.shape == gradient.shape, f"grad shape must match tensor shape, {gradient.shape!r} != {self.shape!r}"
    self.grad = gradient
    for t0 in reversed(toposorted):
      if t0.grad is None: raise RuntimeError(f"tensor {t0} has no grad")
      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := t0._ctx.metadata) is not None else None)
      grads = t0._ctx.backward(t0.grad.lazydata)
      ctx = cast(Function, t0._ctx)
      token = _METADATA.set(dataclasses.replace(md, backward=True) if (md := ctx.metadata) is not None else None)
      grads = ctx.backward(t0.grad.lazydata)
      _METADATA.reset(token)
      grads = [Tensor(g, device=self.device, requires_grad=False) if g is not None else None
        for g in ([grads] if len(t0._ctx.parents) == 1 else grads)]
      for t, g in zip(t0._ctx.parents, grads):
        for g in ([grads] if len(ctx.parents) == 1 else grads)]
      for t, g in zip(ctx.parents, grads):
        if g is not None and t.requires_grad:
          assert g.shape == t.shape, f"grad shape must match tensor shape, {g.shape!r} != {t.shape!r}"
          assert t.lazydata in toposort_uop or (isinstance(t.lazydata, MultiLazyBuffer) and any(x in toposort_uop for x in t.lazydata.lbs)), \
            f"grad uop must have a path from self\ngrad uop: {t.lazydata}"
          t.grad = g if t.grad is None else (t.grad + g)
      if not retain_graph: del t0._ctx
    return self