mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
exclude GPU on tiny (#766)
This commit is contained in:
@@ -21,7 +21,7 @@ repos:
|
|||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
- id: tests
|
- id: tests
|
||||||
name: subset of (CPU) tests
|
name: subset of (CPU) tests
|
||||||
entry: env CPU=1 pytest test/unit/ test/test_ops.py
|
entry: env CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py
|
||||||
language: system
|
language: system
|
||||||
always_run: true
|
always_run: true
|
||||||
pass_filenames: false
|
pass_filenames: false
|
||||||
|
|||||||
@@ -2,11 +2,16 @@ import unittest
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
from tinygrad.lazy import Device
|
from tinygrad.lazy import Device
|
||||||
from tinygrad.tensor import Tensor
|
from tinygrad.tensor import Tensor
|
||||||
from tinygrad.helpers import dtypes
|
from tinygrad.helpers import getenv
|
||||||
|
|
||||||
def multidevice_test(fxn):
|
def multidevice_test(fxn):
|
||||||
|
exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
|
||||||
def ret(self):
|
def ret(self):
|
||||||
for device in Device._buffers:
|
for device in Device._buffers:
|
||||||
|
print(device)
|
||||||
|
if device in exclude_devices:
|
||||||
|
print(f"WARNING: {device} test is excluded")
|
||||||
|
continue
|
||||||
with self.subTest(device=device):
|
with self.subTest(device=device):
|
||||||
try:
|
try:
|
||||||
Device[device]
|
Device[device]
|
||||||
|
|||||||
@@ -460,6 +460,13 @@ class Linearizer:
|
|||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
|
|
||||||
# if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
# if last dim <= 16 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
|
||||||
if self.first_reduce < (self.shape_len-self.upcasted) and self.full_unupcasted_shape[-1] <= 32 and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
|
if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
|
||||||
self.upcast()
|
if self.full_unupcasted_shape[-1] <= 16:
|
||||||
|
self.upcast()
|
||||||
|
else:
|
||||||
|
for splits in [16,8,4]:
|
||||||
|
if self.full_unupcasted_shape[-1]%splits == 0:
|
||||||
|
self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
|
||||||
|
self.upcast()
|
||||||
|
break
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ class ASTRunner:
|
|||||||
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
|
if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
|
||||||
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
|
if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
|
||||||
if DEBUG >= 2:
|
if DEBUG >= 2:
|
||||||
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(27-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(28-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
|
||||||
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
|
(str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
|
||||||
GlobalCounters.kernel_count += 1
|
GlobalCounters.kernel_count += 1
|
||||||
GlobalCounters.global_ops += self.op_estimate
|
GlobalCounters.global_ops += self.op_estimate
|
||||||
|
|||||||
Reference in New Issue
Block a user