exclude GPU on tiny (#766)

Author: George Hotz
Date: 2023-05-05 10:07:23 -07:00
Committed by: GitHub
Parent: f2a964f447
Commit: 81aa3e546b

4 changed files with 18 additions and 6 deletions

@@ -21,7 +21,7 @@ repos:
     pass_filenames: false
   - id: tests
     name: subset of (CPU) tests
-    entry: env CPU=1 pytest test/unit/ test/test_ops.py
+    entry: env CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py
     language: system
     always_run: true
     pass_filenames: false
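
The tests hook now sets EXCLUDE_DEVICES=GPU, so the pre-commit CPU test subset skips GPU sub-tests on machines without a GPU (such as the tiny CI box the commit title refers to). The hook's behavior can be reproduced from a shell with the same entry command:

    CPU=1 EXCLUDE_DEVICES=GPU pytest test/unit/ test/test_ops.py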

@@ -2,11 +2,16 @@ import unittest
 import numpy as np
 from tinygrad.lazy import Device
 from tinygrad.tensor import Tensor
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import getenv

 def multidevice_test(fxn):
+  exclude_devices = getenv("EXCLUDE_DEVICES", "").split(",")
   def ret(self):
     for device in Device._buffers:
+      print(device)
+      if device in exclude_devices:
+        print(f"WARNING: {device} test is excluded")
+        continue
       with self.subTest(device=device):
         try:
           Device[device]
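
The hunk is truncated after Device[device], so here is a self-contained sketch of the decorator pattern it extends. This is illustrative, not the tinygrad implementation: DEVICES stands in for Device._buffers, the fxn(self, device) calling convention is an assumption, and os.environ.get replaces tinygrad's getenv.

    import os, unittest

    DEVICES = ["CPU", "GPU", "CLANG"]  # stand-in for Device._buffers

    def multidevice_test(fxn):
      # EXCLUDE_DEVICES is a comma-separated list, e.g. "GPU" or "GPU,METAL"
      exclude_devices = os.environ.get("EXCLUDE_DEVICES", "").split(",")
      def ret(self):
        for device in DEVICES:
          if device in exclude_devices:
            print(f"WARNING: {device} test is excluded")
            continue
          with self.subTest(device=device):  # one sub-test per device
            fxn(self, device)  # assumed calling convention
      return ret

    class TestDevices(unittest.TestCase):
      @multidevice_test
      def test_name_is_string(self, device):
        self.assertIsInstance(device, str)

    if __name__ == "__main__":
      unittest.main()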

@@ -460,6 +460,13 @@ class Linearizer:
       else:
         break

-    # if last dim <= 5 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
-    if self.first_reduce < (self.shape_len-self.upcasted) and self.full_unupcasted_shape[-1] <= 32 and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
-      self.upcast()
+    # if last dim <= 16 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
+    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
+      if self.full_unupcasted_shape[-1] <= 16:
+        self.upcast()
+      else:
+        for splits in [16,8,4]:
+          if self.full_unupcasted_shape[-1]%splits == 0:
+            self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
+            self.upcast()
+            break
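
In prose: the old heuristic fully upcast the reduce dim only when it was at most 32; the new one fully upcasts dims up to 16, and for larger dims shifts off the largest factor among 16, 8, 4 that divides the dim evenly, then upcasts just that factor. A minimal sketch of the factor selection (pick_split is a hypothetical helper, not in the commit):

    def pick_split(last_dim: int) -> int | None:
      # mirror the loop above: prefer the largest unroll factor that divides evenly
      if last_dim <= 16:
        return last_dim          # small enough to upcast the whole dim
      for splits in [16, 8, 4]:
        if last_dim % splits == 0:
          return splits
      return None                # no even split; leave the loop rolled

    assert pick_split(12) == 12   # <= 16: upcast the whole dim
    assert pick_split(64) == 16   # 64 splits evenly by 16
    assert pick_split(24) == 8    # 24 % 16 != 0, but 24 % 8 == 0
    assert pick_split(50) is None # 50 has no divisor among 16/8/4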

@@ -95,7 +95,7 @@ class ASTRunner:
     if getenv("OPTLOCAL") and self.global_size is not None and self.local_size is None: self.local_size = self.optimize_local_size(rawbufs, allow_cache=(getenv("OPTLOCAL") >= 2))
     if et := self.clprg(self.global_size, self.local_size, *rawbufs, wait=force_wait or DEBUG>=2): GlobalCounters.time_sum_s += et
     if DEBUG >= 2:
-      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(27-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {self.op_estimate/1e6:7.1f}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
+      print(f"{colored(f'*** {GlobalCounters.kernel_count:4d}', 'magenta' if jit else None)} {(self.display_name+' '*(28-len(self.name))) if self.display_name is not None else self.name:26s} arg {len(rawbufs):3d} sz {str(self.global_size):18s} {str(self.local_size):12s} OPs {int(self.op_estimate/1e6):6d}M/{GlobalCounters.global_ops/1e9:7.2f}G mem {GlobalCounters.mem_used/1e9:5.2f} GB " +
             (str() if et is None else f"tm {et*1e6:9.2f}us/{GlobalCounters.time_sum_s*1e3:9.2f}ms ({self.op_estimate/(et*1e9):8.2f} GFLOPS, {self.mem_estimate/(et*1e9):7.2f} GB/s)"))
     GlobalCounters.kernel_count += 1
     GlobalCounters.global_ops += self.op_estimate
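
The only change in this hunk is cosmetic: the OPs field switches from a fixed-point float ({:7.1f}M) to an integer ({:6d}M), dropping the decimal and one column, and the display_name padding widens from 27 to 28 to keep the columns aligned. A quick comparison of the two formats (the op count is an illustrative value, not from the commit):

    op_estimate = 123_456_789
    print(f"OPs {op_estimate/1e6:7.1f}M")     # old: 'OPs   123.5M'
    print(f"OPs {int(op_estimate/1e6):6d}M")  # new: 'OPs    123M'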