assign is contiguous (#10066)

* assign is contiguous

* disable process replay for SDXL
George Hotz
2025-04-27 08:40:33 -04:00
committed by GitHub
parent 1253819151
commit b6d2effaf5
4 changed files with 7 additions and 6 deletions

.github/workflows/benchmark.yml

@@ -54,8 +54,9 @@ jobs:
       run: JIT=1 python3.11 examples/stable_diffusion.py --seed 0 --noshow --timing | tee sd_no_fp16.txt
     - name: Run Stable Diffusion v2
       run: JIT=1 python3.11 examples/sdv2.py --fp16 --seed 0 --noshow --timing | tee sdv2.txt
+    # process replay can't capture this, the graph is too large
     - name: Run SDXL
-      run: JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 JIT=1 python3.11 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run model inference benchmark
       run: METAL=1 python3.11 test/external/external_model_benchmark.py
     - name: Test speed vs torch
@@ -192,7 +193,7 @@ jobs:
     - name: Run Stable Diffusion
       run: NV=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: NV=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 NV=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA
       run: |
         NV=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
@@ -390,7 +391,7 @@ jobs:
     - name: Run Stable Diffusion
       run: AMD=1 python3 examples/stable_diffusion.py --fp16 --seed 0 --noshow --timing | tee sd.txt
     - name: Run SDXL
-      run: AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
+      run: CAPTURE_PROCESS_REPLAY=0 AMD=1 python3 examples/sdxl.py --seed 0 --noshow --timing | tee sdxl.txt
     - name: Run LLaMA 7B
       run: |
         AMD=1 JIT=0 python3 examples/llama.py --gen 1 --prompt "Hello." --count 10 --temperature 0 --timing | tee llama_unjitted.txt
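
All three workflow hunks make the same change: the SDXL benchmark runs now export CAPTURE_PROCESS_REPLAY=0, so CI skips capturing their kernel graph, which the new comment notes is too large for process replay. As a rough sketch of how an env-var gate like this is typically read on the Python side (the helper below is illustrative, not tinygrad's actual replay code):

    import os

    # Illustrative gate: capture defaults to on, and CAPTURE_PROCESS_REPLAY=0
    # (as exported by the benchmark steps above) turns it off entirely.
    CAPTURE_PROCESS_REPLAY = int(os.getenv("CAPTURE_PROCESS_REPLAY", "1"))

    def maybe_capture(step_name: str, payload: bytes) -> None:
      # hypothetical helper: drop all recording when the gate is off
      if not CAPTURE_PROCESS_REPLAY: return
      with open(f"/tmp/replay_{step_name}.bin", "wb") as fh:
        fh.write(payload)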

test/models/test_real_world.py

@@ -111,7 +111,7 @@ class TestRealWorld(unittest.TestCase):
       loss.backward()
       optimizer.step()
-    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 92)
+    helper_test("train_mnist", lambda: (Tensor.randn(BS, 1, 28, 28),), train, 0.07, 93)
 
   @unittest.skipIf(CI and Device.DEFAULT in {"CPU", "GPU", "LLVM"}, "slow")
   def test_train_cifar(self):
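
The train_mnist kernel budget moves from 92 to 93 as a direct consequence of the tensor.py change below: the extra .contiguous() on assign schedules one more kernel per training step. A minimal sketch of checking such a count yourself (assuming the last helper_test argument is the allowed kernel count, and using tinygrad's public counters):

    from tinygrad import Tensor
    from tinygrad.helpers import GlobalCounters

    # Count the kernels one realized computation launches; this is the
    # quantity the 92 -> 93 budget above bounds for a full training step.
    GlobalCounters.reset()
    (Tensor.randn(8, 8) @ Tensor.randn(8, 8)).contiguous().realize()
    print(GlobalCounters.kernel_count)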

test/test_jit.py

@@ -350,7 +350,7 @@ class TestJit(unittest.TestCase):
     assert len(res3) == 10, "All values should be different, rand works in jit."
     assert res3 != res2, "Jit rand is diff with diff seeds"
 
-  @unittest.expectedFailure # requires contiguous folding
+  #@unittest.expectedFailure # requires contiguous folding
   def test_jit_random_after_unrealized_random(self):
     @TinyJit
     def f(): return Tensor.rand()
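
Commenting out the expectedFailure marker means the contiguous folding this test was waiting on now exists. A plausible reading of the old failure mode: Tensor.rand advances a persistent RNG-counter tensor via assign, and when an earlier rand was left unrealized, the jitted rand could fold into its pending graph. A minimal reproduction of the scenario the test covers (abridged sketch; only the two lines shown in the hunk are verbatim):

    from tinygrad import Tensor, TinyJit

    @TinyJit
    def f(): return Tensor.rand()

    Tensor.rand()  # deliberately left unrealized
    vals = [f().item() for _ in range(3)]  # JIT warm-up, capture, replay
    print(vals)    # distinct draws once assign folds through contiguous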

tinygrad/tensor.py

@@ -292,7 +292,7 @@ class Tensor(SimpleMathTrait):
     assert self.device == x.device, f"assign device mismatch {self.device} != {x.device}"
     assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}"
     assert not x.requires_grad # self requires_grad is okay?
-    if not self.lazydata.is_realized: return self.replace(x)
+    if not self.lazydata.is_realized: return self.replace(x.contiguous())
     self.lazydata = self.lazydata.assign(x.lazydata)
     return self
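
This is the core of the PR: when assign's destination is still unrealized, the incoming tensor is pinned with .contiguous() before the replace, so the assigned value gets its own buffer instead of remaining an arbitrary view of the source's graph. A minimal sketch of the case this covers, using only the public Tensor API (the variable names are illustrative):

    from tinygrad import Tensor

    target = Tensor.empty(2, 2)                             # unrealized destination
    src = Tensor.arange(4.0).reshape(2, 2).transpose(0, 1)  # lazy non-contiguous view

    # Before this commit the unrealized path did replace(src); now it does
    # replace(src.contiguous()), materializing the value into its own buffer.
    target.assign(src)
    print(target.realize().numpy())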