mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
assign is contiguous (#10066)
* assign is contiguous
* disable process replay for SDXL
.github/workflows/benchmark.yml (7 changed lines, vendored)
@@ -54,8 +54,9 @@ jobs:
       run: JIT=1 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
     - name: Run Stable Diffusion v2
       run: JIT=1 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
+    # process replay can't capture this, the graph is too large
     - name: Run SDXL
-      run: JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run model inference benchmark
       run: METAL=1 python3.11 test/external/external_model_benchmark.py
     - name: Test speed vs torch
@@ -192,7 +193,7 @@ jobs:
     - name: Run Stable Diffusion
       run: NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: NV=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 NV=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA
       run: |
         NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -390,7 +391,7 @@ jobs:
     - name: Run Stable Diffusion
       run: AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA 7B
       run: |
         AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
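Why the SDXL benchmarks opt out: process replay is tinygrad's CI check that re-captures the kernels a run generates and diffs them against master, and the new comment in the workflow says the SDXL graph is too large for that capture. A minimal sketch of how such an env-var gate reads, using tinygrad's real helpers.getenv but with a hypothetical call site and default:

# Sketch only: getenv is tinygrad's real env-var helper, but where process
# replay checks CAPTURE_PROCESS_REPLAY, and its default, are assumptions here.
from tinygrad.helpers import getenv

def should_capture() -> bool:
  # the workflows above export CAPTURE_PROCESS_REPLAY=0 to skip capture for SDXL
  return bool(getenv("CAPTURE_PROCESS_REPLAY", 1))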
@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
      loss.backward()
      optimizer.step()

-   helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92)
+   helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93)

  @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU", "LLVM"}, "slow")
  def test_train_cifar(self):
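Reading helper_test's trailing arguments as a memory budget and a kernel-count budget (an assumption from the surrounding test, not stated in the diff), the 92 -> 93 bump allows the contiguous assign to schedule one extra kernel per MNIST training step. Kernel counts like this are observable via tinygrad's real GlobalCounters; the tie to helper_test is my assumption:

# Count kernels a step schedules; GlobalCounters and reset() are real tinygrad
# helpers, the specific workload here is illustrative only.
from tinygrad import Tensor
from tinygrad.helpers import GlobalCounters

GlobalCounters.reset()
out = (Tensor.randn(4, 4) @ Tensor.randn(4, 4)).realize()
print(GlobalCounters.kernel_count)  # kernels launched since reset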
@@ -350,7 +350,7 @@ class TestJit(unittest.TestCase):
    assert len(res3) == 10, "All values should be different, rand works in jit."
    assert res3 != res2, "Jit rand is diff with diff seeds"

-  @unittest.expectedFailure # requires contiguous folding
+  #@unittest.expectedFailure # requires contiguous folding
  def test_jit_random_after_unrealized_random(self):
    @TinyJit
    def f(): return Tensor.rand()
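Commenting out expectedFailure says this pattern now works: a random tensor created before a jitted rand no longer trips the JIT's RNG handling, presumably because the contiguous assign provides the folding the old marker comment ("requires contiguous folding") asked for. A rough sketch of the exercised pattern, assuming only the names visible in the test body; the real test's assertions may differ:

from tinygrad import Tensor, TinyJit

@TinyJit
def f(): return Tensor.rand()

r = Tensor.rand()                      # unrealized random made before the jitted calls
vals = [f().item() for _ in range(5)]  # jitted rand should still produce fresh values
print(r.item(), vals)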
@@ -292,7 +292,7 @@ class Tensor(SimpleMathTrait):
    assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
    assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
    assert not x.requires_grad  # self requires_grad is okay?
-   if not self.lazydata.is_realized: return self.replace(x)
+   if not self.lazydata.is_realized: return self.replace(x.contiguous())
    self.lazydata = self.lazydata.assign(x.lazydata)
    return self
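The behavioral core of the commit in one sketch: when the target tensor is still unrealized, assign now stores a contiguous copy of the source instead of aliasing whatever lazy view was passed in. The names below are public tinygrad API; the comments are my reading of the one-line diff:

from tinygrad import Tensor

base = Tensor.randn(4, 4)
view = base.permute(1, 0)  # lazy movement op: a non-contiguous view of base

t = Tensor.empty(4, 4)     # target not yet realized
t.assign(view)             # before: self.replace(view); now: self.replace(view.contiguous())
t.realize()                # t ends up owning its own contiguous buffer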