diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 97b0264ea1..4e65939fdb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -192,7 +192,7 @@ jobs: - name: Test dtype with Python emulator run: DEBUG=1 PYTHONPATH=. PYTHON=1 python3 -m pytest -n=auto test/test_dtype.py test/test_dtype_alu.py - name: Test ops with Python emulator - run: DEBUG=2 PYTHON=1 python3 -m pytest -n=auto test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_maxpool2d or test_maxpool2d_simple or test_maxpool2d_bigger_stride or test_avgpool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int or test_interpolate_bilinear or test_interpolate_bilinear_corners_aligned)" --durations=20 + run: DEBUG=2 PYTHON=1 python3 -m pytest -n=auto test/test_ops.py -k "not (test_split or test_simple_cumsum or test_cumsum or test_einsum or test_dot or test_dot_1d or test_big_gemm or test_broadcastdot or test_multidot or test_var_axis or test_std_axis or test_broadcast_full or test_broadcast_partial or test_simple_conv3d or test_dilated_conv_transpose2d or test_simple_conv_transpose3d or test_large_input_conv2d or test_max_pool2d or test_max_pool2d_simple or test_max_pool2d_bigger_stride or test_avg_pool2d or test_cat or test_scaled_product_attention or test_scaled_product_attention_causal or test_slice_fancy_indexing_dim_inject_none or test_slice_fancy_indexing_list_indices or test_slice_fancy_indexing_no_dim_collapse or test_slice_fancy_indexing_tuple_indices or test_slice_fancy_indexing_list_with_tensors or test_slice_fancy_indexing_dim_collapse_int or test_interpolate_bilinear or test_interpolate_bilinear_corners_aligned)" --durations=20 - name: Test uops with Python emulator run: PYTHON=1 python3 -m pytest test/test_uops.py --durations=20 - name: Test symbolic with Python emulator diff --git a/docs/abstractions2.py b/docs/abstractions2.py index e2690baab8..09bdcdc466 100644 --- a/docs/abstractions2.py +++ b/docs/abstractions2.py @@ -39,7 +39,7 @@ DEVICE = "CLANG" # NOTE: you can change this! import struct from tinygrad.dtype import dtypes from tinygrad.device import Buffer, Device -from tinygrad.ops import BinaryOps, MetaOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker # allocate some buffers + load in values @@ -81,15 +81,15 @@ from tinygrad.engine.realize import run_schedule from tinygrad.engine.schedule import create_schedule # allocate some values + load in values -a = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE) -b = LazyBuffer.metaop(MetaOps.EMPTY, (1,), dtypes.int32, DEVICE) +a = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE) +b = LazyBuffer.metaop(Ops.EMPTY, (1,), dtypes.int32, DEVICE) a.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 2)))) b.buffer.allocate().copyin(memoryview(bytearray(struct.pack("I", 3)))) del a.srcs del b.srcs # describe the computation -out = a.alu(BinaryOps.ADD, b) +out = a.alu(Ops.ADD, b) # schedule the computation as a list of kernels sched = create_schedule([out]) diff --git a/docs/tensor/movement.md b/docs/tensor/movement.md index a68770b230..acbc90819c 100644 --- a/docs/tensor/movement.md +++ b/docs/tensor/movement.md @@ -17,9 +17,9 @@ ::: tinygrad.Tensor.repeat_interleave ::: tinygrad.Tensor.split ::: tinygrad.Tensor.chunk +::: tinygrad.Tensor.meshgrid ::: tinygrad.Tensor.squeeze ::: tinygrad.Tensor.unsqueeze -::: tinygrad.Tensor.pad2d ::: tinygrad.Tensor.T ::: tinygrad.Tensor.transpose ::: tinygrad.Tensor.flatten diff --git a/examples/hlb_cifar10.py b/examples/hlb_cifar10.py index cea7953a74..78c59bdb18 100644 --- a/examples/hlb_cifar10.py +++ b/examples/hlb_cifar10.py @@ -114,7 +114,7 @@ class SpeedyResNet: def __call__(self, x, training=True): # pad to 32x32 because whitening conv creates 31x31 images that are awfully slow to compute with # TODO: remove the pad but instead let the kernel optimize itself - forward = lambda x: x.conv2d(self.whitening).pad2d((1,0,0,1)).sequential(self.net) + forward = lambda x: x.conv2d(self.whitening).pad((1,0,0,1)).sequential(self.net) return forward(x) if training else (forward(x) + forward(x[..., ::-1])) / 2. # hyper-parameters were exactly the same as the original repo diff --git a/examples/llm.c/export.py b/examples/llm.c/export.py index 46e36dddfc..c0d52f32cd 100755 --- a/examples/llm.c/export.py +++ b/examples/llm.c/export.py @@ -8,7 +8,7 @@ from tinygrad.helpers import dedup, to_function_name, flatten, getenv, GlobalCou from tinygrad.engine.schedule import create_schedule from tinygrad.engine.realize import get_kernel, run_schedule from tinygrad.engine.memory import memory_planner -from tinygrad.ops import MetaOps, Ops +from tinygrad.ops import Ops TIMING = getenv("TIMING") diff --git a/examples/serious_mnist.py b/examples/serious_mnist.py index b0c4c69ae5..752f5dab41 100644 --- a/examples/serious_mnist.py +++ b/examples/serious_mnist.py @@ -45,7 +45,7 @@ class ConvBlock: def __call__(self, input): x = input.reshape(shape=(-1, self.inp, self.w, self.h)) for cweight, cbias in zip(self.cweights, self.cbiases): - x = x.pad2d(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu() + x = x.pad(padding=[1,1,1,1]).conv2d(cweight).add(cbias).relu() x = self._bn(x) x = self._seb(x) return x diff --git a/examples/so_vits_svc.py b/examples/so_vits_svc.py index 9cab45f0ab..00c07e9f5a 100644 --- a/examples/so_vits_svc.py +++ b/examples/so_vits_svc.py @@ -361,21 +361,21 @@ class SineGen: self.dim = self.harmonic_num + 1 def _f02uv(self, f0): return (f0 > self.voiced_threshold).float() #generate uv signal def _f02sine(self, f0_values): - def padDiff(x : Tensor): return (x.pad2d((0,0,-1,1)) - x).pad2d((0,0,0,-1)) + def padDiff(x : Tensor): return (x.pad((0,0,-1,1)) - x).pad((0,0,0,-1)) def mod(x: Tensor, n: int) -> Tensor: return x - n * x.div(n).floor() # this is what the % operator does in pytorch. rad_values = mod((f0_values / self.sampling_rate) , 1) # convert to F0 in rad rand_ini = Tensor.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device) # initial phase noise #rand_ini[:, 0] = 0 - m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad2d((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool) + m = Tensor.ones(f0_values.shape[0]).unsqueeze(1).pad((0,f0_values.shape[2]-1,0,0)).cast(dtypes.bool) m = tilde(m) rand_ini = m.where(rand_ini, 0) #rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini tmp = rad_values[:, 0, :] + rand_ini - m = Tensor.ones(tmp.shape).pad2d((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool) + m = Tensor.ones(tmp.shape).pad((0,0,0,rad_values.shape[1]-1,0)).cast(dtypes.bool) m = tilde(m) - tmp = tmp.unsqueeze(1).pad2d((0,0,0,rad_values.shape[1]-1,0)) + tmp = tmp.unsqueeze(1).pad((0,0,0,rad_values.shape[1]-1,0)) rad_values = m.where(rad_values, tmp) tmp_over_one = mod(rad_values.cumsum(1), 1) @@ -383,7 +383,7 @@ class SineGen: cumsum_shift = Tensor.zeros_like(rad_values) #cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0 - tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad2d((0,0,1,0)) + tmp_over_one_idx = (tmp_over_one_idx * -1.0).pad((0,0,1,0)) cumsum_shift = tmp_over_one_idx sines = ((rad_values + cumsum_shift).cumsum(1) * 2 * np.pi).sin() diff --git a/examples/sovits_helpers/preprocess.py b/examples/sovits_helpers/preprocess.py index 88b4332192..17a265040a 100644 --- a/examples/sovits_helpers/preprocess.py +++ b/examples/sovits_helpers/preprocess.py @@ -137,7 +137,7 @@ class Resample: waveform = waveform.reshape(-1, shape[-1]) # pack batch num_wavs, length = waveform.shape target_length = int(math.ceil(new_freq * length / orig_freq)) - waveform = waveform.pad2d((self.width, self.width + orig_freq)) + waveform = waveform.pad((self.width, self.width + orig_freq)) resampled = waveform[:, None].conv2d(self.kernel, stride=orig_freq) resampled = resampled.transpose(1, 2).reshape(num_wavs, -1) resampled = resampled[..., :target_length] diff --git a/examples/yolov8.py b/examples/yolov8.py index 68e701b422..37396c7501 100644 --- a/examples/yolov8.py +++ b/examples/yolov8.py @@ -282,7 +282,7 @@ class SPPF: self.cv2 = Conv_Block(c_ * 4, c2, 1, 1, padding=None) # TODO: this pads with 0s, whereas torch function pads with -infinity. This results in a < 2% difference in prediction which does not make a difference visually. - self.maxpool = lambda x : x.pad2d((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1) + self.maxpool = lambda x : x.pad((k // 2, k // 2, k // 2, k // 2)).max_pool2d(kernel_size=k, stride=1) def __call__(self, x): x = self.cv1(x) diff --git a/extra/debug_sd_speed.py b/extra/debug_sd_speed.py new file mode 100644 index 0000000000..9d7af29cff --- /dev/null +++ b/extra/debug_sd_speed.py @@ -0,0 +1,99 @@ +# NOTE: this is written in a way that checkout back to old commit still works +# fast SD 297ms step on M1 Max, 4444e6d https://github.com/tinygrad/tinygrad/pull/2129 +# lazy rewrite, 1765849 https://github.com/tinygrad/tinygrad/pull/2878 +# SD 415ms step on M1 Max on master around 11/15/2024 + +import time +from typing import Optional +try: from tinygrad.jit import TinyJit +except ImportError: from tinygrad import TinyJit +from tinygrad.tensor import Tensor, Device +from tinygrad.helpers import GlobalCounters +from tinygrad.nn import Linear, LayerNorm +from tinygrad.nn.state import get_parameters + +class CrossAttention: + def __init__(self, query_dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.to_q = Linear(query_dim, n_heads*d_head, bias=False) + self.to_k = Linear(ctx_dim, n_heads*d_head, bias=False) + self.to_v = Linear(ctx_dim, n_heads*d_head, bias=False) + self.num_heads = n_heads + self.head_size = d_head + self.to_out = [Linear(n_heads*d_head, query_dim)] + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + ctx = x if ctx is None else ctx + q,k,v = self.to_q(x), self.to_k(ctx), self.to_v(ctx) + q,k,v = [y.reshape(x.shape[0], -1, self.num_heads, self.head_size).transpose(1,2) for y in (q,k,v)] + attention = Tensor.scaled_dot_product_attention(q, k, v).transpose(1,2) + h_ = attention.reshape(x.shape[0], -1, self.num_heads * self.head_size) + return h_.sequential(self.to_out) + +class GEGLU: + def __init__(self, dim_in:int, dim_out:int): + self.proj = Linear(dim_in, dim_out * 2) + self.dim_out = dim_out + + def __call__(self, x:Tensor) -> Tensor: + x, gate = self.proj(x).chunk(2, dim=-1) + return x * gate.gelu() + +class FeedForward: + def __init__(self, dim:int, mult:int=4): + self.net = [ + GEGLU(dim, dim*mult), + lambda x: x, # needed for weights loading code to work + Linear(dim*mult, dim) + ] + + def __call__(self, x:Tensor) -> Tensor: + return x.sequential(self.net) + +class BasicTransformerBlock: + def __init__(self, dim:int, ctx_dim:int, n_heads:int, d_head:int): + self.attn1 = CrossAttention(dim, dim, n_heads, d_head) + self.ff = FeedForward(dim) + self.attn2 = CrossAttention(dim, ctx_dim, n_heads, d_head) + self.norm1 = LayerNorm(dim) + self.norm2 = LayerNorm(dim) + self.norm3 = LayerNorm(dim) + + def __call__(self, x:Tensor, ctx:Optional[Tensor]=None) -> Tensor: + x = x + self.attn1(self.norm1(x)) # 5.4 before, # 6.8 master + x = x + self.attn2(self.norm2(x), ctx=ctx) # 12 before, 12 master + x = x + self.ff(self.norm3(x)) # 23 before, # 27 master + return x + +def helper_test(gen, model): + tms = [] + for _ in range(5): + early_gen = [x.realize() if isinstance(x, Tensor) else x for x in gen()] + GlobalCounters.reset() + Device[Device.DEFAULT].synchronize() + st = time.perf_counter_ns() + model(*early_gen) + Device[Device.DEFAULT].synchronize() + tms.append(time.perf_counter_ns() - st) + print(f"{min(tms)/1e6=:.2f} ms") + +def derandomize_model(model): + for p in get_parameters(model): + p.lazydata = Tensor.empty(*p.shape, device=p.device, dtype=p.dtype).lazydata + p.realize() + +def test_transformer_block(): + # dim, d_head, x = 320, 40, (4096, 320) # 137ms 4444e6d 115ms master + # dim, d_head, x = 640, 80, (1024, 640) # 36ms 4444e6d, 31ms master + dim, d_head, x = 1280, 160, (256, 1280) # 23ms 4444e6d, 28ms master, 31ms on 176584993 + + model = [BasicTransformerBlock(dim, 768, 8, d_head) for _ in range(4)] + + derandomize_model(model) + @TinyJit + def test(t, t2): + for l in model: t = l(t, t2) + return t.realize() + helper_test(lambda: (Tensor.empty(2, *x), Tensor.empty(2, 77, 768)), test) + +if __name__ == "__main__": + test_transformer_block() diff --git a/extra/models/resnet.py b/extra/models/resnet.py index 334ca1bea3..60b7cdedcb 100644 --- a/extra/models/resnet.py +++ b/extra/models/resnet.py @@ -107,7 +107,7 @@ class ResNet: is_feature_only = self.fc is None if is_feature_only: features = [] out = self.bn1(self.conv1(x)).relu() - out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) + out = out.pad([1,1,1,1]).max_pool2d((3,3), 2) out = out.sequential(self.layer1) if is_feature_only: features.append(out) out = out.sequential(self.layer2) diff --git a/extra/models/retinanet.py b/extra/models/retinanet.py index 415dd5de25..8e188787dc 100644 --- a/extra/models/retinanet.py +++ b/extra/models/retinanet.py @@ -170,7 +170,7 @@ class ResNetFPN: def __call__(self, x): out = self.body.bn1(self.body.conv1(x)).relu() - out = out.pad2d([1,1,1,1]).max_pool2d((3,3), 2) + out = out.pad([1,1,1,1]).max_pool2d((3,3), 2) out = out.sequential(self.body.layer1) p3 = out.sequential(self.body.layer2) p4 = p3.sequential(self.body.layer3) diff --git a/extra/onnx_ops.py b/extra/onnx_ops.py index 252d10ba08..f8067511c1 100644 --- a/extra/onnx_ops.py +++ b/extra/onnx_ops.py @@ -282,7 +282,7 @@ def MaxUnpool(xT: Tensor, xI: Tensor, outshape: Optional[Tensor]=None, kernel_sh if outshape is not None and (outshape := to_python_const(outshape)) != ret.shape: diff = [outshape[2] - ret.shape[2], outshape[3] - ret.shape[3]] pad_args = [diff[0]//2, diff[1]//2, diff[0]-diff[0]//2, diff[1]-diff[1]//2] - ret = ret.pad2d((pad_args[1], pad_args[3], pad_args[0], pad_args[2])) + ret = ret.pad((pad_args[1], pad_args[3], pad_args[0], pad_args[2])) return ret def Conv(X: Tensor, W: Tensor, B:Optional[Tensor]=None, auto_pad="NOTSET", dilations=1, group=1, kernel_shape=None, pads=None, strides=1): @@ -334,7 +334,7 @@ def Dropout(data: Tensor, ratio=0.5, training_mode=False, seed=None): def LRN(x: Tensor, size, alpha=1e-4, beta=0.75, bias=1.0): bs, c, iy, ix = x.shape - return x / x.mul(x).reshape(bs,1,c,iy*ix).pad2d((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta) + return x / x.mul(x).reshape(bs,1,c,iy*ix).pad((0,0,(size-1)//2, size//2)).avg_pool2d((size, 1), 1).reshape(bs,c,iy,ix).mul(alpha).add(bias).pow(beta) def MeanVarianceNormalization(x: Tensor, axis=(0, 2, 3)): mean = x.mean(axis, keepdim=True) @@ -563,29 +563,13 @@ def ImageDecoder(encoded_stream: Tensor, pixel_format="RGB"): raise ValueError(f"pixel_format={pixel_format!r} is not supported.") def AffineGrid(theta: Tensor, size: Tensor, align_corners=0): - _, _, *data_sz = to_python_const(size) - size_zeros, original_grid = Tensor.zeros(data_sz), Tensor.ones(data_sz) - stackable = [original_grid] - for dim, dim_sz in enumerate(data_sz): - a = Tensor.arange(-1, 1.0001, 2/(dim_sz-1)) if align_corners == 1 else Tensor.arange(-1+1/dim_sz, 1, 2/dim_sz) - if dim == 0: stackable = [a.reshape(dim_sz, *[1]*(len(data_sz)-1)) + size_zeros, *stackable] - elif dim == 1: stackable = [a.reshape(1, dim_sz, *[1]*(len(data_sz)-2)) + size_zeros, *stackable] - else: stackable = [a.reshape(1, dim_sz) + size_zeros, *stackable] - original_grid = Tensor.stack(*stackable, dim=len(data_sz)) - if original_grid.ndim == 3: - N, dim_2d, dim_homo = theta.shape - assert dim_2d == 2 and dim_homo == 3 - H, W, dim_homo = original_grid.shape - assert dim_homo == 3 - original_grid = original_grid.reshape(H*W, dim_homo).transpose() - return theta.matmul(original_grid).permute(0,2,1).reshape(N, H, W, dim_2d) - assert original_grid.ndim == 4 - N, dim_3d, dim_homo = theta.shape - assert dim_3d == 3 and dim_homo == 4 - D, H, W, dim_homo = original_grid.shape - assert dim_homo == 4 - original_grid = original_grid.reshape(D*H*W, dim_homo).transpose() - return theta.matmul(original_grid).permute(0,2,1).reshape(N, D, H, W, dim_3d) + N, _, *spatial_dims = to_python_const(size) + def generate_grid(steps): + return Tensor.linspace(-1, 1, steps, device=theta.device) if align_corners else Tensor.linspace(-1+1/steps, 1-1/steps, steps, device=theta.device) + grids = Tensor.meshgrid(*(generate_grid(d) for d in spatial_dims)) + base_grid = Tensor.stack(*reversed(grids), Tensor.ones_like(grids[0], device=theta.device), dim=-1) + base_grid = base_grid.reshape(1, prod(spatial_dims), len(grids)+1).expand(N, -1, -1) + return (base_grid @ theta.transpose(1, 2)).reshape(N, *spatial_dims, -1) # **************** com.microsoft Ops **************** diff --git a/extra/optimization/helpers.py b/extra/optimization/helpers.py index 342c4c27b6..7de5a24d33 100644 --- a/extra/optimization/helpers.py +++ b/extra/optimization/helpers.py @@ -2,7 +2,7 @@ from typing import Tuple from tinygrad import Variable from tinygrad.codegen.kernel import Opt, OptOps -from tinygrad.ops import UOp, Ops, KernelInfo, TernaryOps, BinaryOps, UnaryOps, MetaOps +from tinygrad.ops import UOp, Ops, KernelInfo from tinygrad.dtype import dtypes, PtrDType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View diff --git a/test/external/external_test_nv.py b/test/external/external_test_nv.py index d4d924287e..58e556b44e 100644 --- a/test/external/external_test_nv.py +++ b/test/external/external_test_nv.py @@ -9,7 +9,7 @@ from tinygrad.engine.realize import get_runner, CompiledRunner from test.external.fuzz_linearizer import get_fuzz_rawbufs from tinygrad.codegen.kernel import Kernel -from tinygrad.ops import LazyOp, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer +from tinygrad.ops import LazyOp, Ops, ReduceOps, BufferOps, MemBuffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -26,12 +26,12 @@ class TestNV(unittest.TestCase): TestNV.addr = struct.pack("QQ", TestNV.b.lazydata.buffer._buf.va_addr, TestNV.a.lazydata.buffer._buf.va_addr) def test_oor_kernels(self): - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 256, 1, 512, 4, 16, 4, 16), strides=(0, 100352, 0, 196, 0, 14, 0, 1), offset=-15, mask=((0, 1), (0, 256), (0, 1), (0, 512), (0, 4), (1, 15), (0, 4), (1, 15)), contiguous=False), View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(2097152, 0, 0, 128, 2, 4096, 1088, 17), offset=0, mask=None, contiguous=False))))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(256, 1, 512, 7, 7, 512, 3, 3), strides=(25088, 0, 49, 7, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),))))), arg=None),), arg=(dtypes.float, False)),), arg=((0, 3, 4), dtypes.float)),), arg=(dtypes.half, False)),), arg=MemBuffer(idx=0, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 512, 1, 1, 512, 3, 3), strides=(0, 0, 4608, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)))) # noqa: E501 opts = [Opt(op=OptOps.TC, axis=6, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=4), Opt(op=OptOps.UPCAST, axis=3, amt=0), Opt(op=OptOps.LOCAL, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=3), Opt(op=OptOps.UPCAST, axis=1, amt=2)] # noqa: E501 helper_test_lin(Kernel(ast), opts=opts, failed_platforms=["NV"]) def test_error_on_huge_dims(self): - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BinaryOps.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=ReduceOps.SUM, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=Ops.MUL, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 0, 1), offset=0, mask=None, contiguous=False),)))), LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=2, dtype=dtypes.half, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 683), strides=(0, 0, 683, 1), offset=0, mask=None, contiguous=True),))))), arg=None),), arg=dtypes.float),), arg=(3,)),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)))) # noqa: E501 opts = [Opt(op=OptOps.GROUP, axis=0, amt=0), Opt(op=OptOps.PADTO, axis=1, amt=32), Opt(op=OptOps.UNROLL, axis=0, amt=4), Opt(op=OptOps.LOCAL, axis=0, amt=2), Opt(op=OptOps.LOCAL, axis=0, amt=2)] # noqa: E501 with self.assertRaises(RuntimeError) as cm: lin = Kernel(ast) @@ -43,7 +43,7 @@ class TestNV(unittest.TestCase): def test_buf4_usage(self): TestNV.along = Tensor([105615], device="NV").realize() - ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=UnaryOps.SIN, src=(LazyOp(op=UnaryOps.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 + ast = LazyOp(op=BufferOps.STORE, src=(LazyOp(op=Ops.SIN, src=(LazyOp(op=Ops.CAST, src=(LazyOp(op=BufferOps.LOAD, src=(), arg=MemBuffer(idx=1, dtype=dtypes.ulong, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))),), arg=dtypes.float),), arg=None),), arg=MemBuffer(idx=0, dtype=dtypes.float, st=ShapeTracker(views=(View(shape=(3,), strides=(1,), offset=0, mask=None, contiguous=True),)))) # noqa: E501 temp_runner = get_runner(TestNV.d0.dname, (ast,)) temp_runner([TestNV.b.lazydata.buffer, TestNV.along.lazydata.buffer], var_vals={}) val = TestNV.b.lazydata.buffer.as_buffer().cast("f")[0] diff --git a/test/external/external_test_valid_remove.py b/test/external/external_test_valid_remove.py index 24f94c640a..1467139daf 100644 --- a/test/external/external_test_valid_remove.py +++ b/test/external/external_test_valid_remove.py @@ -2,7 +2,7 @@ import unittest from tinygrad import Device -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.search import Opt, OptOps from tinygrad.dtype import dtypes from tinygrad.shape.shapetracker import ShapeTracker @@ -20,7 +20,7 @@ class TestOpenpilotValidhack(unittest.TestCase): UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MAX, dtypes.float, arg=None, src=( x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 8, 9, 10)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 8, 9, 10)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( diff --git a/test/external/fuzz_linearizer.py b/test/external/fuzz_linearizer.py index 296ec2c60e..49199c11c5 100644 --- a/test/external/fuzz_linearizer.py +++ b/test/external/fuzz_linearizer.py @@ -25,7 +25,7 @@ from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.engine.search import get_kernel_actions, bufs_from_lin from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import getenv, from_mv, prod, colored, Context, DEBUG, Timing -from tinygrad.ops import UnaryOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.device import is_dtype_supported def on_linearizer_will_run(): pass @@ -252,7 +252,7 @@ def fuzz_linearizer(lin: Kernel, rtol=1e-2, atol=1e-2, opts_list=None): def _is_simple(lin: Kernel) -> bool: if len(lin.ast.src) > 1: return False ast:UOp = lin.ast.src[0] - if ast.src[0].op is UnaryOps.CAST and ast.src[0].src[0].op is Ops.LOAD: return True + if ast.src[0].op is Ops.CAST and ast.src[0].src[0].op is Ops.LOAD: return True return False if __name__ == "__main__": diff --git a/test/external/fuzz_schedule.py b/test/external/fuzz_schedule.py index b9ab2fed86..059edb37b8 100644 --- a/test/external/fuzz_schedule.py +++ b/test/external/fuzz_schedule.py @@ -6,7 +6,7 @@ from tinygrad.engine.realize import capturing, lower_schedule_item from tinygrad.helpers import DEBUG, MULTIOUTPUT, colored, getenv from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.schedule import LBScheduleItem, _graph_schedule, ScheduleItem -from tinygrad.ops import MetaOps +from tinygrad.ops import Ops from tinygrad.tensor import Tensor, _to_np_dtype ctx_vars = { MULTIOUTPUT: (0, 1) } @@ -33,7 +33,7 @@ def fuzz_schedule(outs:List[LazyBuffer]): for lsi in ts: for out in lsi.outputs: # freeze assign state before exec - if out.op is MetaOps.ASSIGN: + if out.op is Ops.ASSIGN: prerealized[out] = out.buffer.as_buffer() assign_targets[out.srcs[1]] = out for x in lsi.inputs: @@ -50,9 +50,9 @@ def fuzz_schedule(outs:List[LazyBuffer]): rawbufs: Dict[LazyBuffer, Buffer] = {} for lsi in ts: for out in lsi.outputs: - base = rawbufs[lsi.inputs[0]].base if out.op is MetaOps.BUFFER_VIEW else None + base = rawbufs[lsi.inputs[0]].base if out.op is Ops.BUFFER_VIEW else None rawbufs[out] = Buffer(out.buffer.device, out.buffer.size, out.buffer.dtype, base=base) - if out.op is MetaOps.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) + if out.op is Ops.ASSIGN: rawbufs[out].ensure_allocated().copyin(prerealized[out]) for x in lsi.inputs: if x not in rawbufs: # override the assign_target after ASSIGN diff --git a/test/external/fuzz_symbolic.py b/test/external/fuzz_symbolic.py index bee91cfa57..42798bd181 100644 --- a/test/external/fuzz_symbolic.py +++ b/test/external/fuzz_symbolic.py @@ -42,9 +42,9 @@ def gt(expr, rng=None): return expr > rng, rng # NOTE: you have to replace these for this test to pass -from tinygrad.ops import python_alu, BinaryOps -python_alu[BinaryOps.MOD] = lambda x,y: x%y -python_alu[BinaryOps.IDIV] = lambda x,y: x//y +from tinygrad.ops import python_alu, Ops +python_alu[Ops.MOD] = lambda x,y: x%y +python_alu[Ops.IDIV] = lambda x,y: x//y if __name__ == "__main__": ops = [add_v, div, mul, add_num, mod] diff --git a/test/test_compile_failures.py b/test/test_compile_failures.py index 23d3414397..5ab87c0b62 100644 --- a/test/test_compile_failures.py +++ b/test/test_compile_failures.py @@ -11,7 +11,6 @@ class TestCompileFailures(unittest.TestCase): def test_interpolate_atari(self): self.compile(Tensor.empty(210, 160, dtype='uint8').interpolate((64, 64))) - @unittest.skip("FIXME: broken on METAL") def test_add_max_uchar(self): self.compile((Tensor.empty(1024, dtype='uint8') + Tensor.empty(1024, dtype='uint8')).max()) diff --git a/test/test_const_folding.py b/test/test_const_folding.py index 6d4b681148..3763e864f6 100644 --- a/test/test_const_folding.py +++ b/test/test_const_folding.py @@ -154,7 +154,7 @@ class TestReduceOpsConstFolding(unittest.TestCase): _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).sum()) np.testing.assert_equal(Tensor.ones(4).pad(((1, 1),)).sum().numpy(), 4) - # NOTE: cannot just count the non-padded area because some UnaryOps f do not have f(0) = 0. + # NOTE: cannot just count the non-padded area because some Ops f do not have f(0) = 0. _check_ast_count(1, Tensor.ones(4).pad(((1, 1),)).exp().sum()) np.testing.assert_allclose(Tensor.ones(4).pad(((1, 1),)).exp().sum().numpy(), 4 * math.e + 2) @@ -251,7 +251,6 @@ class TestTautologicalCompare(unittest.TestCase): np.testing.assert_equal((Tensor(True) < Tensor(False)).numpy(), False) np.testing.assert_equal((Tensor(True) < Tensor(True)).numpy(), False) - @unittest.skip("not implemented yet") def test_a_eq_a(self): # self eq is always true for int or bool a = Tensor([1, 2, 3]) @@ -261,7 +260,6 @@ class TestTautologicalCompare(unittest.TestCase): a = Tensor([math.nan, 1.0, 2.0]) np.testing.assert_equal((a == a).numpy(), [False, True, True]) - @unittest.skip("not implemented yet") def test_a_ne_a(self): # self not eq is always false for int or bool a = Tensor([1, 2, 3]) diff --git a/test/test_fusion_op.py b/test/test_fusion_op.py index 2322ab6204..b979ef2b2f 100644 --- a/test/test_fusion_op.py +++ b/test/test_fusion_op.py @@ -46,5 +46,24 @@ class TestFusionOp(unittest.TestCase): with self.assertRaises(AssertionError): self.assertEqual(sched1[-1].ast, sched3[-1].ast) self.assertLess(time.perf_counter()-st, 2.0) + def test_recursive_pad(self): + st = time.perf_counter() + val = 1.0 + a = Tensor(val).realize() + for _ in range(24): a = Tensor.stack(a, a)[0] + r = a.item() + self.assertEqual(r, val) + self.assertLess(time.perf_counter()-st, 2.0) + + def test_recursive_reshape(self): + st = time.perf_counter() + a = Tensor.empty(32, 32).realize() + b = Tensor.empty(16, 2).realize() + r = a.sum(1) + for _ in range(24): r = r.reshape(16, 2) + b + sched = r.schedule() + self.assertEqual(len(sched), 1) + self.assertLess(time.perf_counter()-st, 2.0) + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/test/test_lazybuffer.py b/test/test_lazybuffer.py index 168720e2d2..6ea598c2ec 100644 --- a/test/test_lazybuffer.py +++ b/test/test_lazybuffer.py @@ -2,8 +2,9 @@ import numpy as np import unittest from tinygrad import Tensor, Device, dtypes +from tinygrad.engine.realize import run_schedule from tinygrad.ops import Ops -from tinygrad.engine.lazy import LazyBuffer, MetaOps +from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.schedule import create_schedule class TestLazyBuffer(unittest.TestCase): @@ -69,6 +70,25 @@ class TestLazyBuffer(unittest.TestCase): assert lb.const_like(1).base.arg == 1.0 assert type(lb.const_like(1).base.arg) is float + def test_forced_realized_alu(self): + a = Tensor.randn(2, 2).realize() + b = Tensor.randn(2, 2).realize() + add = a + b + add.lazydata.forced_realize = True + out = add+2 + sched = create_schedule([out.lazydata]) + self.assertEqual(len(sched), 2) + run_schedule(sched) + np.testing.assert_allclose(out.numpy(), a.numpy()+b.numpy()+2) + + def test_forced_realized_metaop(self): + empty = Tensor.empty(1) + empty.lazydata.forced_realize = True + sched = create_schedule([empty.lazydata]) + self.assertEqual(len(sched), 1) + self.assertIs(sched[0].ast.op, Ops.EMPTY) + run_schedule(sched) + class TestReduceOp(unittest.TestCase): def test_no_split_reduce_kernel(self): a = Tensor.rand(4, 4).realize() @@ -95,24 +115,24 @@ class TestReduceOp(unittest.TestCase): class TestView(unittest.TestCase): def test_all_masked_out(self): - # start with non CONST MetaOps + # start with non CONST Ops a = Tensor.rand(10, 10) - assert a.lazydata.base.op is not MetaOps.CONST + assert a.lazydata.base.op is not Ops.CONST # all masked out, degrades to const 0 b = a.pad(((0, 10), None))[10:] assert b.shape == (10, 10) - assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0 + assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0 # mask out dim = 1 works too b = a.pad((None, (0, 10)))[:, 10:] assert b.shape == (10, 10) - assert b.lazydata.base.op is MetaOps.CONST and b.lazydata.base.arg == 0 + assert b.lazydata.base.op is Ops.CONST and b.lazydata.base.arg == 0 # partial masked out does not degrade into CONST b = a.pad(((0, 5), None))[5:] assert b.shape == (10, 10) - assert b.lazydata.base.op is not MetaOps.CONST + assert b.lazydata.base.op is not Ops.CONST if __name__ == "__main__": unittest.main() diff --git a/test/test_linearizer.py b/test/test_linearizer.py index 957f529c1b..d4bf0b3a55 100644 --- a/test/test_linearizer.py +++ b/test/test_linearizer.py @@ -6,7 +6,7 @@ from dataclasses import replace from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError, Kernel from tinygrad.codegen.lowerer import get_grouped_dims -from tinygrad.ops import UOp, Ops, BinaryOps, TernaryOps, UnaryOps, GroupOp +from tinygrad.ops import UOp, Ops, GroupOp from tinygrad.device import Device, Buffer from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -109,10 +109,10 @@ class TestLinearizer(unittest.TestCase): st_x = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((1, 32)).expand((32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (1,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (1,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((32, 1)).to_uop())) diff = second_x + first_reduce*ast_const(dtypes.float, -1, (32, 1)) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (0,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (0,))) store = UOp(Ops.STORE, dtypes.void, (g0, ShapeTracker.from_shape((1, 1)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -145,10 +145,10 @@ class TestLinearizer(unittest.TestCase): st_x = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 1, 32, 5)).expand((27, 32, 32, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st_x.reshape((27, 32, 1, 5)).to_uop())) diff = second_x + first_reduce*ast_const(dtypes.float, -1, (27, 32, 1, 5)) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -207,13 +207,13 @@ class TestLinearizer(unittest.TestCase): x2 = Tensor.randn(27, 32, 5, dtype=dtypes.float).realize() g0, g1, g2, g3 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(4)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x0.lazydata.st.reshape((27, 1, 1, 32, 5)).expand((27, 32, 32, 32, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x1.lazydata.st.reshape((27, 1, 32, 1, 5)).expand((27, 32, 32, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 32, 32, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (2,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (2,))) third_x = UOp(Ops.LOAD, dtypes.float, (g3, x2.lazydata.st.reshape((27, 32, 1, 1, 5)).to_uop())) mul = (third_x*second_reduce) - third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (BinaryOps.ADD, (1,))) + third_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (mul,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 1, 5)).to_uop(), third_reduce)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = (x2.numpy()*(x1.numpy()-x0.numpy().sum(axis=1, keepdims=True)).sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,1,5) @@ -234,11 +234,11 @@ class TestLinearizer(unittest.TestCase): st = x.lazydata.st g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 1, 32, 8, 1, 16)).expand((8, 32, 32, 8, 16, 16)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2, 5))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2, 5))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, st.reshape((8, 32, 1, 8, 16, 1)).to_uop())) neg_first_reduce = first_reduce * ast_const(dtypes.float, -1, (8, 32, 1, 8, 16, 1)) squares = (second_x+neg_first_reduce) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1, 4))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1, 4))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((8, 1, 1, 8, 1, 1)).to_uop(), squares_sum,)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = (x.numpy()-x.numpy().sum(axis=(1,3), keepdims=True)).sum(axis=(1,3)).reshape((8,1,1,8,1,1)) @@ -285,10 +285,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).softmax(1).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -317,11 +317,11 @@ class TestLinearizer(unittest.TestCase): g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) first_x_p = UOp(Ops.LOAD, dtypes.float, (g2, x_p.lazydata.st.reshape((4, 1, 32)).expand((4, 32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) - first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(UnaryOps.EXP2),), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) + first_reduce_p = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x_p.alu(Ops.EXP2),), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1)).to_uop())) diff = (second_x+(first_reduce + first_reduce_p)*ast_const(dtypes.float, -1, (4, 32, 1))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4, 1, 1)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [ @@ -352,10 +352,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) second_out = second_reduce * ast_const(dtypes.float, 1/15, (27, 1, 1, 5)) store1 = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_out)) @@ -375,10 +375,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 15, 5, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 1, 15, 5)).expand((27, 15, 15, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((27, 15, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 15, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store0 = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) store1 = UOp(Ops.STORE, src=(g1, ShapeTracker(views=(View(shape=(27,15,1,5), strides=(5,0,1,1), offset=0, mask=None, contiguous=False),)).to_uop(), first_reduce)) # noqa: E501 wanna_output0 = (x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(27,1,1,5) @@ -399,10 +399,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.UNROLL, 0, 3), Opt(OptOps.UNROLL, 0, 3)]] @@ -415,10 +415,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 3, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 3, 5)).expand((27, 3, 3, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 3, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 3, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.UPCAST, 0, 3)]] @@ -434,10 +434,10 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(27, 12, 5, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 1, 12, 5)).expand((27, 12, 12, 5)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((27, 12, 1, 5)).to_uop())) diff = (second_x+first_reduce*ast_const(dtypes.float, -1, (27, 12, 1, 5))) - second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (BinaryOps.ADD, (1,))) + second_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (diff,), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((27, 1, 1, 5)).to_uop(), second_reduce)) sink = UOp(Ops.SINK, src=(store,)) opts = [[Opt(OptOps.GROUPTOP, 0, 3), Opt(OptOps.GROUPTOP, 1, 3)]] @@ -450,13 +450,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = x.numpy().std(axis=2, ddof=0).reshape((15,25,1,1)) @@ -468,13 +468,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 1, 25, 35)).expand((15, 25, 25, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (2,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (2,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.04, (15, 25, 1, 35)) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((15, 25, 1, 35)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (1,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (1,))) variance = squares_sum * ast_const(dtypes.float, 0.04, (15, 1, 1, 35)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 1, 1, 35)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store,)) wanna_output = x.numpy().std(axis=1, ddof=0).reshape((15,1,1,35)) @@ -488,13 +488,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.randn(15, 25, 35, dtype=dtypes.float).realize() g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] first_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 1, 35)).expand((15, 25, 35, 35)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -1/35, (15, 25, 35, 1)) second_x = UOp(Ops.LOAD, dtypes.float, (g2, x.lazydata.st.reshape((15, 25, 35, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 1/35, (15, 25, 1, 1)) - std = variance.alu(UnaryOps.SQRT) + std = variance.alu(Ops.SQRT) store_mean = UOp(Ops.STORE, src=(g1, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), neg_mean)) store_std = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((15, 25, 1, 1)).to_uop(), std)) sink = UOp(Ops.SINK, src=(store_std, store_mean)) @@ -511,13 +511,13 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] # push reduce (3, 27, 32) -> (3, 27, 1) -> (3, 27, 32) expand to LOAD first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 1, 32)).expand((3, 27, 32, 32)).to_uop())) - first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.ADD, (3,))) + first_reduce = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.ADD, (3,))) neg_mean = first_reduce * ast_const(dtypes.float, -0.03125, (3, 27, 32, 1)) # store = UOp(UOps.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 32, 1)).to_uop(), mean)) # verify_lazyop(store) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((3, 27, 32, 1)).to_uop())) squares = (second_x+neg_mean)*(second_x+neg_mean) - squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (BinaryOps.ADD, (2,))) + squares_sum = UOp(Ops.REDUCE_AXIS, dtypes.float, (squares,), (Ops.ADD, (2,))) variance = squares_sum * ast_const(dtypes.float, 0.03125, (3, 27, 1, 1)) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((3, 27, 1, 1)).to_uop(), variance)) sink = UOp(Ops.SINK, src=(store,)) @@ -532,13 +532,13 @@ class TestLinearizer(unittest.TestCase): x = Tensor.rand(4, 32).realize() g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] first_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 1, 32,)).expand((4, 32, 32)).to_uop())) - max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (BinaryOps.MAX, (2,))) + max_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (first_x,), (Ops.MAX, (2,))) second_x = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((4, 32, 1,)).to_uop())) centered_x = second_x+max_x*ast_const(dtypes.float, -1, (4, 32, 1)) - exp_x = centered_x.alu(UnaryOps.EXP2) - sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (BinaryOps.ADD, (1,))) - # y = exp_x * sum_exp_x.alu(UnaryOps.RECIP) # kernels cannot do a return to full shape - recip_sum_exp_x = sum_exp_x.alu(UnaryOps.RECIP) + exp_x = centered_x.alu(Ops.EXP2) + sum_exp_x = UOp(Ops.REDUCE_AXIS, dtypes.float, (exp_x,), (Ops.ADD, (1,))) + # y = exp_x * sum_exp_x.alu(Ops.RECIP) # kernels cannot do a return to full shape + recip_sum_exp_x = sum_exp_x.alu(Ops.RECIP) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((4,1,1)).to_uop(), recip_sum_exp_x)) sink = UOp(Ops.SINK, src=(store,)) expected = 1/np.exp2(x.numpy() - x.numpy().max(axis=-1, keepdims=True)).sum(axis=-1, keepdims=True).reshape(4,1,1) @@ -556,7 +556,7 @@ class TestLinearizer(unittest.TestCase): View(shape=(16384, 16384), strides=(1, 32768), offset=0, mask=None, contiguous=False))) arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) - arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) + arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis)) output_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) out = arange+ast_const(dtypes.int, -1, output_shape) store = UOp(Ops.STORE, src=(UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0), ShapeTracker.from_shape(output_shape).to_uop(), out)) @@ -573,7 +573,7 @@ class TestLinearizer(unittest.TestCase): # TODO: do this arange broadcast in the scheduler arange_input_st = arange_input_st.reshape((1, 16384, 1, 16384)).expand((4, 16384, 256, 16384)) arange_axis = (3,) - arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (BinaryOps.ADD, arange_axis)) + arange = UOp(Ops.REDUCE_AXIS, dtypes.int, (ast_const(dtypes.int, 1, st=arange_input_st),), (Ops.ADD, arange_axis)) arange_out_shape = tuple(1 if i in arange_axis else s for i,s in enumerate(arange_input_st.shape)) arange = arange+ast_const(dtypes.int, -1, arange_out_shape) # p2: the indexing @@ -581,10 +581,10 @@ class TestLinearizer(unittest.TestCase): data1 = (g1, ShapeTracker.from_shape(dataset.shape).reshape((1, 16384, 256, 1)).expand(arange_out_shape).to_uop()) idxs = Tensor([0,3,5,6]).realize() data2 = (g2, ShapeTracker.from_shape((4,)+(1,)*(len(arange_out_shape)-1)).expand(arange_out_shape).to_uop()) - arange_eq = arange.alu(BinaryOps.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(BinaryOps.CMPNE, ast_const(dtypes.bool, True, arange_out_shape)) + arange_eq = arange.alu(Ops.CMPNE, UOp(Ops.LOAD, dtypes.int, data2)).alu(Ops.CMPNE, ast_const(dtypes.bool, True, arange_out_shape)) reduce_input = UOp(Ops.LOAD, dataset.dtype, data1)*UOp(Ops.CAST, dataset.dtype.scalar(), src=(arange_eq,)) out_axis = (1,) - out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (BinaryOps.ADD, out_axis)) + out = UOp(Ops.REDUCE_AXIS, reduce_input.dtype, (reduce_input,), (Ops.ADD, out_axis)) output_shape = tuple(1 if i in out_axis else s for i,s in enumerate(arange_out_shape)) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape(output_shape).to_uop(), out)) sink = UOp(Ops.SINK, src=(store,)) @@ -605,7 +605,7 @@ class TestLinearizer(unittest.TestCase): ast_const(dtypes.int, st=ShapeTracker(views=(View(shape=(1, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),)), val=10), UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 20, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=( UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -618,7 +618,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa E501 ast_const(dtypes.bool, True, st=ShapeTracker(views=(View(shape=(10, 20, 1), strides=(0, 0, 0), offset=0, mask=None, contiguous=False),))),)),)), # noqa E501 UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 20, 10), strides=(1, 0, 20), offset=0, mask=None, contiguous=False)))),)), # noqa E501 ast_const(dtypes.int, 10, (10, 20, 1)))),)),)),)),)), ast_const(dtypes.int, -1, (1, 20, 1)),)),)),)) @@ -637,7 +637,7 @@ class TestLinearizer(unittest.TestCase): ast_const(dtypes.int, 200, (1, 1)), UOp(Ops.MUL, dtypes.int, arg=None, src=( ast_const(dtypes.int, -1, (1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.MAX, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.MAX, (0,)), src=( UOp(Ops.MUL, dtypes.int, arg=None, src=( UOp(Ops.CAST, dtypes.int, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -650,7 +650,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(200, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), # noqa: E501 ast_const(dtypes.bool, True, (200, 1)),)),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st=ShapeTracker(views=(View(shape=(201, 399), strides=(0, 0), offset=0, mask=((0, 201), (199, 399)), contiguous=False), View(shape=(200, 200), strides=(1, 400), offset=0, mask=None, contiguous=False)))),)), # noqa: E501 ast_const(dtypes.int, 200, (200, 1)),)),)),)),)),)), ast_const(dtypes.int, -1, (1, 1)),)),)),)) @@ -672,16 +672,16 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (1,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(BinaryOps.ADD, (0,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (1,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),),(Ops.ADD, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=0, keepdims=True)).sum(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.ADD, (2,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.ADD, (1,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.ADD, (2,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.ADD, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().sum(axis=1, keepdims=True)).sum(axis=1).reshape(N,1,1)], opts=opts) @@ -699,16 +699,16 @@ class TestLinearizer(unittest.TestCase): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(2)] x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((1, N, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (1,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (BinaryOps.MAX, (0,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (1,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, 1, N)),), (Ops.MAX, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,1,N)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=0, keepdims=True)).max(axis=0).reshape(1,1,N)], opts=opts) x_ld0 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, 1, N)).expand((N,N,N)).to_uop())) x_ld1 = UOp(Ops.LOAD, dtypes.float, (g1, x.lazydata.st.reshape((N, N, 1)).to_uop())) - r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (BinaryOps.MAX, (2,))) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (BinaryOps.MAX, (1,))) + r0 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld0,), (Ops.MAX, (2,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.float, (x_ld1+r0*ast_const(dtypes.float, -1, (N, N, 1)),), (Ops.MAX, (1,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((N,1,1)).to_uop(), r1)) sink = UOp(Ops.SINK, src=(store,)) helper_linearizer_ast(sink, [x], wanna_output=[(x.numpy()-x.numpy().max(axis=1, keepdims=True)).max(axis=1).reshape(N,1,1)], opts=opts) @@ -735,7 +735,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (N, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -743,7 +743,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), ld0.to_uop(),)),)),)), @@ -768,7 +768,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, N)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -776,7 +776,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, 1, N)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), ld0.to_uop(),)),)),)), @@ -804,7 +804,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.5*N, (1, 1, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -812,7 +812,7 @@ class TestLinearizer(unittest.TestCase): UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPLT, dtypes.bool, arg=None, src=( ast_const(dtypes.float, 0.75*N, (N, N, 1, 1)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(N, N, N, N), strides=(0, 0, N, 1), offset=0, mask=None, contiguous=False),))),)),)),)), # noqa: E501 @@ -831,7 +831,7 @@ class TestLinearizer(unittest.TestCase): def test_end_local(self): g0, g1 = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=i) for i in range(2)] load = UOp(Ops.LOAD, dtypes.int, (g1, ShapeTracker.from_shape((32,)).to_uop())) - reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (BinaryOps.ADD, (0,))) + reduce = UOp(Ops.REDUCE_AXIS, dtypes.int, (load,), (Ops.ADD, (0,))) store = UOp(Ops.STORE, src=(g0, ShapeTracker.from_shape((1,)).to_uop(), reduce)) sink = UOp(Ops.SINK, src=(store,)) load_t = Tensor.full(load.st_arg.shape, 1).contiguous().realize() @@ -1219,20 +1219,20 @@ class TestLinearizer(unittest.TestCase): assert len(sched) == 1 lin = Kernel(sched[0].ast) - assert sum(u.op in {UnaryOps.RECIP, BinaryOps.FDIV} for u in lin.linearize().uops) == max_ops, msg + assert sum(u.op in {Ops.RECIP, Ops.FDIV} for u in lin.linearize().uops) == max_ops, msg a = Tensor.empty((4,4)) b = Tensor.empty((4,4)) d = Tensor.empty((4,4)) c = (a*b)/b - helper(c, "found UnaryOps.RECIP in (a*b)/b operation") + helper(c, "found Ops.RECIP in (a*b)/b operation") c = a/a - helper(c, "found UnaryOps.RECIP in (a/a) operation") + helper(c, "found Ops.RECIP in (a/a) operation") c = (a/b)/d - helper(c, "found multiple UnaryOps.RECIP in (a/b)/d operation", 1) + helper(c, "found multiple Ops.RECIP in (a/b)/d operation", 1) def test_sum_collapse(self): t = Tensor([2]).reshape(1, 1).expand(256, 256).sum() @@ -1260,7 +1260,7 @@ class TestLinearizer(unittest.TestCase): lin = Kernel(sched_copy[-1].ast) lin.hand_coded_optimizations() lin.linearize() - assert not any(u.op == TernaryOps.WHERE for u in lin.uops), "found where where where should be folded" + assert not any(u.op == Ops.WHERE for u in lin.uops), "found where where where should be folded" def test_phi_simplification(self): def helper(t, max_ops=0): @@ -1272,7 +1272,7 @@ class TestLinearizer(unittest.TestCase): assert len(set([u.op for u in uops if u.op in {Ops.RANGE, Ops.SPECIAL}])) == 1, "has either specials or ranges, not both" assert len([u for u in uops if u.op is Ops.ASSIGN]) == 0, "ASSIGN should have been simplified" # TODO: once uops track min/max this will be fixed - #assert len([u for u in uops if u.op is BinaryOps.MAX]) <= max_ops, "no unnecessary MAX ops" + #assert len([u for u in uops if u.op is Ops.MAX]) <= max_ops, "no unnecessary MAX ops" helper(Tensor.arange(5.5, (3.5*300), 3.5), max_ops=2) helper(Tensor.arange(-1, -100, -5), max_ops=2) @@ -1602,7 +1602,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 3, 32000, 1), strides=(0, 32000, 1, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, src=( @@ -1632,7 +1632,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 1, 128, 512, 512, 1, 1, 1), strides=(0, 0, 262144, 512, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1), @@ -1662,7 +1662,7 @@ class TestFloat4(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256, 1, 64, 1, 114, 1, 114), strides=(0, 831744, 0, 12996, 0, 114, 0, 1), offset=0, mask=None, contiguous=True),))), # noqa: E501 UOp(Ops.CAST, dtypes.half, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (4, 6)), src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.LOAD, dtypes.half, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1), @@ -1949,7 +1949,7 @@ class TestKernelOpts(unittest.TestCase): UOp(Ops.STORE, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, arg=ShapeTracker(views=(View(shape=(1, 256), strides=(0, 1), offset=0, mask=None, contiguous=True),))), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -2138,7 +2138,7 @@ class TestKernelOpts(unittest.TestCase): g0, g1, g2 = [UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=i) for i in range(3)] ld0 = UOp(Ops.LOAD, dtypes.float, (g1, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 ld1 = UOp(Ops.LOAD, dtypes.float, (g2, ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)).to_uop())) # noqa: E501 - store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (BinaryOps.ADD, (0, 2, 4, 6)),))) # noqa: E501 + store = UOp(Ops.STORE, src=(g0, ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)).to_uop(), UOp(Ops.REDUCE_AXIS, dtypes.float, (ld0*ld1,), (Ops.ADD, (0, 2, 4, 6)),))) # noqa: E501 sink = UOp(Ops.SINK, src=(store,)) data1 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() data2 = Tensor.randn(2, 1, 4, 1, 3, 4, 2, 6, 1, 3).realize() diff --git a/test/test_linearizer_dumb.py b/test/test_linearizer_dumb.py index b0830f7004..d9e7a0e058 100644 --- a/test/test_linearizer_dumb.py +++ b/test/test_linearizer_dumb.py @@ -5,7 +5,7 @@ import unittest from test.helpers import ast_const from tinygrad import Device, dtypes -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.helpers import getenv from tinygrad.shape.shapetracker import ShapeTracker, View from tinygrad.engine.search import Opt, OptOps @@ -21,7 +21,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.MAX, dtypes.half, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -64,7 +64,7 @@ class TestLinearizerDumb(unittest.TestCase): ast_const(dtypes.bool, True, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1000, 1), strides=(0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1001, 1999), strides=(0, 0), offset=0, mask=((0, 1001), (999, 1999)), contiguous=False), View(shape=(1000, 1000), strides=(1, 2000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 1000, st_src=( @@ -75,7 +75,7 @@ class TestLinearizerDumb(unittest.TestCase): for opt in opts: k.apply_opt(opt) prg = k.to_program() print(prg.src) - assert prg.uops is not None and not any(uop.op is BinaryOps.MAX for uop in prg.uops), "leftover MAX" + assert prg.uops is not None and not any(uop.op is Ops.MAX for uop in prg.uops), "leftover MAX" @unittest.skipUnless(Device[Device.DEFAULT].renderer.has_local, "need local") def test_expander_new_srcs(self): @@ -83,7 +83,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -105,14 +105,14 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4096, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32001, 63999), strides=(0, 0), offset=0, mask=((0, 32001), (31999, 63999)), contiguous=False), View(shape=(4096, 32000, 32000), strides=(0, 1, 64000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -136,7 +136,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -168,7 +168,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1), strides=(0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 1)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 1)), src=( UOp(Ops.WHERE, dtypes.float, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( @@ -200,7 +200,7 @@ class TestLinearizerDumb(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5, 13, 1, 1, 1, 1, 1, 4, 3, 3), strides=(2340, 468, 36, 0, 0, 0, 0, 0, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), diff --git a/test/test_linearizer_failures.py b/test/test_linearizer_failures.py index 26dd7b23af..6dc6e4a97b 100644 --- a/test/test_linearizer_failures.py +++ b/test/test_linearizer_failures.py @@ -3,7 +3,7 @@ import unittest, random import numpy as np from tinygrad.codegen.kernel import Kernel, KernelOptError from tinygrad.device import is_dtype_supported -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.search import Opt, OptOps from tinygrad import Device, dtypes, Tensor from tinygrad.helpers import CI @@ -47,7 +47,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 1), strides=(16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 16, 16), strides=(16, 1, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), @@ -64,7 +64,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 37, 9, 1, 1), strides=(666, 333, 9, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (4, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (4, 5)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 2, 111, 27), strides=(6160, 3080, 28, 1), offset=0, mask=((0, 32), (0, 2), (0, 110), (0, 27)), contiguous=False), View(shape=(32, 2, 37, 9, 2, 2), strides=(5994, 2997, 81, 3, 27, 1), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -76,7 +76,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 1), strides=(128, 16, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(32, 8, 16, 16), strides=(2048, 256, 16, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) @@ -89,7 +89,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 1, 1, 1), strides=(0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -111,7 +111,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, -1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(11, 19), strides=(0, 0), offset=0, mask=((0, 11), (9, 19)), contiguous=False), View(shape=(10, 10), strides=(1, 20), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, 10, st_src=( @@ -125,7 +125,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 1, 34, 1, 34), strides=(36992, 1156, 0, 34, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 4)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 32, 6, 8, 4, 6, 8, 4), strides=(2048, 64, 6291456, 8, 0, 1048576, 1, 0), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 8), (0, 1), (0, 6), (0, 8), (0, 1)), contiguous=False), View(shape=(512, 32, 6, 35, 6, 35), strides=(1179648, 36864, 6144, 192, 32, 1), offset=0, mask=((0, 512), (0, 32), (0, 6), (0, 32), (0, 6), (0, 32)), contiguous=False), View(shape=(512, 32, 238, 238), strides=(1411200, 44100, 210, 1), offset=0, mask=((0, 512), (0, 32), (0, 210), (0, 210)), contiguous=False), View(shape=(512, 32, 7, 34, 7, 34), strides=(1812608, 56644, 8092, 238, 34, 1), offset=0, mask=None, contiguous=True))), src=()),)),)),)),)) @@ -142,7 +142,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.RECIP, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( x9:=UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( @@ -166,7 +166,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 3, 1, 1, 1, 1, 5, 15, 5, 3, 4), strides=(0, 0, 0, 4500, 0, 0, 0, 0, 900, 60, 12, 4, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -183,7 +183,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1024, 1), strides=(0, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), @@ -202,7 +202,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 64, 1, 1), strides=(0, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.RECIP, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -277,7 +277,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -299,7 +299,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.ADD, dtypes.float, arg=None, src=( x6:=UOp(Ops.MUL, dtypes.float, arg=None, src=( @@ -313,7 +313,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 3, 4, 2, 6, 1, 3), strides=(0, 0, 0, 0, 0, 0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)),)), x6,)), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 8)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 8)), src=( x5,)),)),)),)),)) opts = [Opt(op=OptOps.PADTO, axis=0, amt=32), Opt(op=OptOps.GROUP, axis=0, amt=4)] helper_test_lin(Kernel(ast), opts, failed_platforms=[]) @@ -325,7 +325,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), @@ -344,7 +344,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 1, 1, 1, 4, 1, 6, 1, 3), strides=(0, 0, 0, 0, 0, 18, 0, 3, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 4, 6)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 4, 6)), src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( x5:=UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -370,7 +370,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -405,7 +405,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1), strides=(0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 13, 1024), strides=(0, 1024, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)), @@ -420,7 +420,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 40, 1, 28, 28, 1, 1), strides=(31360, 0, 784, 0, 28, 1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -442,7 +442,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 384, 1), strides=(384, 0, 1, 0), offset=0, mask=None, contiguous=True),)), src=()),)), UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=2, src=()), @@ -462,7 +462,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 1, 4, 1, 9, 7, 3, 3), strides=(2268, 0, 567, 0, 63, 9, 3, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -511,7 +511,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -624,7 +624,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1025, 2047), strides=(0, 0), offset=0, mask=((0, 1025), (1023, 2047)), contiguous=False), View(shape=(1024, 1024), strides=(1, 2048), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -639,7 +639,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(129, 255), strides=(0, 0), offset=0, mask=((0, 129), (127, 255)), contiguous=False), View(shape=(128, 128), strides=(1, 256), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -678,7 +678,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(BinaryOps.MAX, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.half, arg=(Ops.MAX, (3,)), src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 13), strides=(0, 169, 13, 1), offset=0, mask=None, contiguous=True),)), src=()),)),)),)),)) @@ -731,7 +731,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(128, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -749,7 +749,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 12, 31, 31, 1, 1, 1), strides=(11532, 0, 961, 31, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -767,7 +767,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 16, 13, 1), strides=(0, 13, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.EXP2, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( @@ -791,7 +791,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 256, 14, 14, 1, 1, 1), strides=(50176, 0, 196, 14, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -809,7 +809,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1,), strides=(0,), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( x5:=UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -851,7 +851,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 1, 6, 10, 3, 1, 1, 1), strides=(180, 0, 30, 3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -875,7 +875,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(5, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.uchar, arg=None, src=( UOp(Ops.ADD, dtypes.uint, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.uint, arg=(Ops.ADD, (1,)), src=( UOp(Ops.CAST, dtypes.uint, arg=None, src=( ast_const(dtypes.uchar, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(6, 9), strides=(0, 0), offset=0, mask=((0, 6), (4, 9)), contiguous=False), View(shape=(5, 5), strides=(1, 10), offset=0, mask=None, contiguous=False))), src=()),)),)),)), @@ -895,7 +895,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -920,7 +920,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 32, 1, 1, 1, 5, 5, 256), strides=(0, 0, 6400, 0, 0, 0, 1280, 256, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (0, 3, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (0, 3, 4)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -943,7 +943,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10000, 1, 32, 24, 24, 1, 1, 1), strides=(18432, 0, 576, 24, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MAX, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (6, 7)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( @@ -969,7 +969,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -987,7 +987,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 128, 28, 28, 1, 1, 1), strides=(100352, 0, 784, 28, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1007,7 +1007,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1021,7 +1021,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1035,7 +1035,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(25, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(26, 49), strides=(0, -1), offset=48, mask=((0, 26), (24, 49)), contiguous=False), View(shape=(25, 25), strides=(1, 50), offset=0, mask=None, contiguous=False))), src=()),)),)),)),)) @@ -1052,7 +1052,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 1, 1, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2, 3)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2, 3)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -1065,7 +1065,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(0, 0, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(3, 3), strides=(0, 0), offset=0, mask=((0, 3), (1, 3)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 1, 0, 4), offset=0, mask=((0, 2), (0, 3), (0, 2), (0, 3), (0, 2)), contiguous=False))), src=()),)),)), x19:=ast_const(dtypes.int, -1, st_src=( @@ -1078,7 +1078,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=3, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(2, 3, 2, 3, 1), strides=(3, 1, 0, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (4,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (4,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(4, 5), strides=(0, 0), offset=0, mask=((0, 4), (2, 5)), contiguous=False), View(shape=(2, 3, 2, 3, 3), strides=(0, 0, 0, 1, 6), offset=0, mask=None, contiguous=False))), src=()),)),)), x19,)),)), @@ -1093,7 +1093,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(512, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.MUL, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.bool, arg=None, src=( @@ -1127,7 +1127,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60000, 1), strides=(1, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (1,)), src=( ast_const(dtypes.int, 1, st_src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(60001, 119999), strides=(0, 0), offset=0, mask=((0, 60001), (59999, 119999)), contiguous=False), View(shape=(60000, 60000), strides=(1, 120000), offset=0, mask=None, contiguous=False))), src=()),)),)), ast_const(dtypes.int, -1, st_src=( @@ -1142,7 +1142,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 64, 1, 1, 256, 1, 1, 256), strides=(0, 0, 65536, 0, 0, 256, 0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3, 4)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3, 4)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1160,7 +1160,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(10, 6, 1), strides=(6, 1, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -1178,7 +1178,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 20, 1, 20), strides=(0, 0, 20, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CMPNE, dtypes.bool, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.bool, arg=(Ops.ADD, (3,)), src=( UOp(Ops.MUL, dtypes.bool, arg=None, src=( UOp(Ops.LOAD, dtypes.bool, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.bool.ptr(), arg=1, src=()), @@ -1220,7 +1220,7 @@ class TestLinearizerFailures(unittest.TestCase): x9,)), UOp(Ops.ADD, dtypes.half, arg=None, src=( UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (2,)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1249,7 +1249,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1267,7 +1267,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 1, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(BinaryOps.ADD, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.uchar, arg=(Ops.ADD, (1,)), src=( UOp(Ops.MUL, dtypes.uchar, arg=None, src=( UOp(Ops.LOAD, dtypes.uchar, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.uchar.ptr(), arg=1, src=()), @@ -1279,7 +1279,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), arg=2, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1024, 50000, 1), strides=(1, 0, 0), offset=0, mask=None, contiguous=False),)), src=()),)), UOp(Ops.ADD, dtypes.int, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(BinaryOps.ADD, (2,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.int, arg=(Ops.ADD, (2,)), src=( UOp(Ops.WHERE, dtypes.int, arg=None, src=( UOp(Ops.VALID, dtypes.bool, arg=None, src=( UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(50001, 99999), strides=(0, 0), offset=0, mask=((0, 50001), (49999, 99999)), contiguous=False), View(shape=(1024, 50000, 50000), strides=(0, 1, 100000), offset=0, mask=None, contiguous=False))), src=()),)), @@ -1306,7 +1306,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(256, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( @@ -1326,7 +1326,7 @@ class TestLinearizerFailures(unittest.TestCase): UOp(Ops.DEFINE_GLOBAL, dtypes.half.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(W, 1, 64, 56, 56, 1, 1, 1), strides=(200704, 0, 3136, 56, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), UOp(Ops.CAST, dtypes.half, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (5, 6, 7)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (5, 6, 7)), src=( UOp(Ops.CAST, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.half, arg=None, src=( UOp(Ops.LOAD, dtypes.half, arg=None, src=( diff --git a/test/test_linearizer_overflows.py b/test/test_linearizer_overflows.py index 5ff6e0ca74..2e7265f652 100644 --- a/test/test_linearizer_overflows.py +++ b/test/test_linearizer_overflows.py @@ -8,7 +8,7 @@ from tinygrad.engine.search import Opt, OptOps from tinygrad.engine.search import time_linearizer, bufs_from_lin # stuff needed to unpack a kernel -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View @@ -33,7 +33,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -68,7 +68,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(512, 1, 64, 32, 32, 1, 1, 1), strides=(65536, 0, 1024, 32, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -85,7 +85,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(16, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -102,7 +102,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(4, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -119,7 +119,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(2, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -136,7 +136,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -153,7 +153,7 @@ class TestLinearizerOverflow(unittest.TestCase): UOp(Ops.STORE, None, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, None, arg=ShapeTracker(views=(View(shape=(3, 1, 128, 128, 128, 1, 1, 1), strides=(2097152, 0, 16384, 128, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)), src=()), - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (7, 6, 5)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (7, 6, 5)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), @@ -175,7 +175,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.LOCAL, axis=2, amt=2), Opt(op=OptOps.UPCAST, axis=0, amt=2)] _test_overflow(ast, opts) @@ -187,7 +187,7 @@ class TestLinearizerOverflowAlt(unittest.TestCase): in_st_2 = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 3, 7, 7), strides=(0, 0, 147, 0, 0, 49, 7, 1), offset=0, mask=None, contiguous=False),)).to_uop() ot_st = ShapeTracker(views=(View(shape=(BS, 1, 64, 112, 112, 1, 1, 1), strides=(802816, 0, 12544, 112, 1, 0, 0, 0), offset=0, mask=None, contiguous=True),)).to_uop() prod = UOp(Ops.LOAD, dtypes.float, (g1, in_st_1)) * UOp(Ops.LOAD, dtypes.float, (g2, in_st_2)) - store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (BinaryOps.ADD, (7, 6, 5))))) + store = UOp(Ops.STORE, src=(g0, ot_st, UOp(Ops.REDUCE_AXIS, dtypes.float, (prod,), (Ops.ADD, (7, 6, 5))))) ast = UOp(Ops.SINK, src=(store,)) opts = [Opt(op=OptOps.LOCAL, axis=3, amt=16), Opt(op=OptOps.UPCAST, axis=1, amt=4), Opt(op=OptOps.LOCAL, axis=2, amt=16), Opt(op=OptOps.UPCAST, axis=4, amt=4), Opt(op=OptOps.UPCAST, axis=1, amt=2), Opt(op=OptOps.UPCAST, axis=5, amt=2)] _test_overflow(ast, opts) diff --git a/test/test_multitensor.py b/test/test_multitensor.py index 5b1704c364..af67972a79 100644 --- a/test/test_multitensor.py +++ b/test/test_multitensor.py @@ -1,7 +1,7 @@ import unittest, functools, random from typing import List from tinygrad import Tensor, Device, nn, GlobalCounters, TinyJit, dtypes -from tinygrad.ops import MetaOps, BinaryOps, Ops +from tinygrad.ops import Ops from tinygrad.helpers import CI, getenv, prod, Context from tinygrad.nn.state import get_parameters, get_state_dict from tinygrad.engine.schedule import create_schedule @@ -51,6 +51,15 @@ class TestMultiTensor(unittest.TestCase): assert lb.shape == (128,) (X + X).realize() + def test_tensor_from_multi(self): + X = Tensor([1, 2], dtype=dtypes.int).shard_(devices_2, 0) + Y = Tensor(X.lazydata) + self.assertEqual(Y.device, Device.DEFAULT) + np.testing.assert_equal(X.numpy(), Y.numpy()) + + with self.assertRaises(AssertionError): + _ = Tensor(X.lazydata, dtype=dtypes.float) + def test_sharded_arange(self): sharded_arange = Tensor.arange(1000).shard(devices_2, 0) sharded_arange.realize() @@ -481,7 +490,7 @@ class TestMultiTensor(unittest.TestCase): for p in get_parameters(bn): p.shard_(devices_4).realize() out = bn(t) - scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not MetaOps.COPY] + scheds = [sched for sched in create_schedule(out.lazydata.lbs) if sched.outputs[0].device in devices_4 and sched.ast.op is not Ops.COPY] assert set(out.device for sched in scheds for out in sched.outputs) == set(devices_4), "should have ast on each shard device" asts = [sched.ast for sched in scheds] assert len(asts) @@ -640,21 +649,21 @@ class TestMultiTensor(unittest.TestCase): for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.ADD + assert ast.src[2].op is Ops.ADD assert ast.src[2].src[0].op is Ops.LOAD assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 1 t = 2 * t for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.MUL + assert ast.src[2].op is Ops.MUL assert ast.src[2].src[0].src[1].op is Ops.CONST and ast.src[2].src[0].src[1].arg == 2 assert ast.src[2].src[1].op is Ops.LOAD t = t + t.full_like(3) for si in t.schedule(): ast = si.ast.src[0] assert ast.op is Ops.STORE - assert ast.src[2].op is BinaryOps.ADD + assert ast.src[2].op is Ops.ADD assert ast.src[2].src[0].op is Ops.LOAD assert ast.src[2].src[1].src[1].op is Ops.CONST and ast.src[2].src[1].src[1].arg == 3 diff --git a/test/test_ops.py b/test/test_ops.py index ba2fa20bae..b781d30286 100644 --- a/test/test_ops.py +++ b/test/test_ops.py @@ -216,6 +216,29 @@ class TestOps(unittest.TestCase): for i in range(len(tor)): helper_test_op([], lambda: tor[i], lambda: ten[i], forward_only=True) + def test_meshgrid(self): + x, xt = torch.tensor([0.,1.,2.], requires_grad=True), Tensor([0.,1.,2.], requires_grad=True) + y, yt = torch.tensor([3.,4.,5.,6.], requires_grad=True), Tensor([3.,4.,5.,6.], requires_grad=True) + z, zt = torch.tensor([7.,8.,9.], requires_grad=True), Tensor([7.,8.,9.], requires_grad=True) + for indexing in ("ij", "xy"): + tor = torch.meshgrid(x, indexing=indexing) + ten = xt.meshgrid(indexing=indexing) + self.assertEqual(len(tor), len(ten)) + for tor_i, ten_i in zip(tor, ten): + helper_test_op([], lambda: tor_i, lambda: ten_i) + tor = torch.meshgrid(x, y, indexing=indexing) + ten = xt.meshgrid(yt, indexing=indexing) + self.assertEqual(len(tor), len(ten)) + for tor_i, ten_i in zip(tor, ten): + helper_test_op([], lambda: tor_i, lambda: ten_i) + tor = torch.meshgrid(x, torch.tensor(10., requires_grad=True), y, z, indexing=indexing) + ten = xt.meshgrid(Tensor(10., requires_grad=True), yt, zt, indexing=indexing) + self.assertEqual(len(tor), len(ten)) + for tor_i, ten_i in zip(tor, ten): + helper_test_op([], lambda: tor_i, lambda: ten_i) + + self.helper_test_exception([], lambda: torch.meshgrid(x, indexing="bad"), lambda: xt.meshgrid(indexing="bad"), expected=RuntimeError) + def test_arange(self): helper_test_op([], lambda: torch.arange(10, dtype=torch.int32), lambda: Tensor.arange(10), forward_only=True) helper_test_op([], lambda: torch.arange(36, dtype=torch.int32), lambda: Tensor.arange(36), forward_only=True) @@ -879,15 +902,15 @@ class TestOps(unittest.TestCase): helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot) helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot) helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot) - self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) + self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError) + self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError) + self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError) def test_dot(self): helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5) helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5) - self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError)) - with self.assertRaises(AssertionError): + self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError) + self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError) + with self.assertRaises(RuntimeError): a = Tensor(3.14) a.matmul(a) def test_mulacc_with_zero_strides(self): @@ -954,7 +977,7 @@ class TestOps(unittest.TestCase): helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7) def test_broadcastdot(self): helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4) - with self.assertRaises(AssertionError): + with self.assertRaises(RuntimeError): a = Tensor(3.14) b = Tensor.ones(3,3) a @ b @@ -988,6 +1011,12 @@ class TestOps(unittest.TestCase): self.helper_test_exception([()], lambda x: x.sum(1), lambda x: x.sum(1), expected=IndexError) self.helper_test_exception([()], lambda x: x.sum((1,)), lambda x: x.sum((1,)), expected=IndexError) + def test_sum_acc_dtype(self): + helper_test_op([(45,3)], lambda x: x.sum(), lambda x: x.sum(acc_dtype=dtypes.float32)) + if is_dtype_supported(dtypes.float64): helper_test_op([(45,3)], lambda x: x.sum(dtype=torch.float64), lambda x: x.sum(acc_dtype=dtypes.float64)) + + with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).sum(acc_dtype="") + def test_sum_with_zeros_shape(self): helper_test_op([(4, 0)], lambda x: x.sum(axis=(0,))) helper_test_op([(4, 0)], lambda x: x.sum(axis=(1,))) @@ -1003,6 +1032,9 @@ class TestOps(unittest.TestCase): helper_test_op([()], lambda x: x.prod(0)) helper_test_op([()], lambda x: x.prod(-1)) + def test_prod_acc_dtype(self): + with self.assertRaises(AttributeError): Tensor([1.0, 2.0]).prod(acc_dtype="") + def test_min(self): helper_test_op([(3,3)], lambda x: x.min()) helper_test_op([(45,3)], lambda x: x.min()) @@ -1339,27 +1371,36 @@ class TestOps(unittest.TestCase): helper_test_op([(4,4)], lambda x: x[:, 1:2][0:1]) helper_test_op([(4,4)], lambda x: x[:, 1:2][:, 0:1]) - def test_pad2d(self): - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad2d(padding=(1,2,3,4))) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad2d(padding=(-1,2,-3,4))) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad2d(padding=(1,2,3,4),value=5)) - helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad2d(padding=(-1,2,-3,4),value=5)) - def test_pad(self): + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)), lambda x: x.pad(padding=(1,2,3,4))) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(padding=(-1,2,-3,4))) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(padding=(1,2,3,4),value=5)) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=5), lambda x: x.pad(padding=(-1,2,-3,4),value=5)) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(padding=(1,2,3,4),value=math.inf)) + helper_test_op([(3,3,3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4), value=-math.inf), + lambda x: x.pad(padding=(-1,2,-3,4),value=-math.inf)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4)),lambda x: x.pad(((3,4),(1,2)))) + helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (-1,2,-3,4)), lambda x: x.pad(((-3,4), (-1,2)))) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=5), lambda x: x.pad(((3,4), (1,2)), value=5)) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=math.inf), lambda x: x.pad(((3,4), (1,2)), value=math.inf)) - helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (1,2,3,4), value=-math.inf), lambda x: x.pad(((3,4), (1,2)), value=-math.inf)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,3,4), value=1), lambda x: x.pad(((3,4), None), value=1)) helper_test_op([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0), value=1), lambda x: x.pad((None, None), value=1)) + # raise error for uneven pads + self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (2,0,2)), lambda x: x.pad((2,0,2)), + expected=(RuntimeError, ValueError)) + # raise error for too many or too little pads + self.helper_test_exception([(3,3)], lambda x: torch.nn.functional.pad(x, (0,0,0,0,1,0,3,0)), lambda x: x.pad((0,0,0,0,1,0,3,0)), + expected=(RuntimeError, ValueError)) + x = Tensor.ones(3,3) + with self.assertRaises(ValueError): x.pad((None,(0,1),(3,0))) + with self.assertRaises(ValueError): x.pad(((0,1),)) def test_pad_reshape(self): helper_test_op([(1, 2)], lambda x: torch.nn.functional.pad(x, (0, 1, 1, 0)).reshape((3, 2)), - lambda x: x.pad2d((0, 1, 1, 0)).reshape((3, 2)), forward_only=True) + lambda x: x.pad((0, 1, 1, 0)).reshape((3, 2)), forward_only=True) helper_test_op([(1, 2)], lambda x: torch.nn.functional.pad(x, (0, 2, 1, 1)).reshape((4, 3)), - lambda x: x.pad2d((0, 2, 1, 1)).reshape((4, 3)), forward_only=True) + lambda x: x.pad((0, 2, 1, 1)).reshape((4, 3)), forward_only=True) helper_test_op([(1, 1, 1, 2)], lambda x: torch.nn.functional.pad(x, (0, 4, 2, 2, 1, 2, 0, 2)).reshape((4, 3, 6, 5)), lambda x: x.pad(((0, 2), (1, 2), (2, 2), (0, 4))).reshape((4, 3, 6, 5)), forward_only=True) @@ -1832,7 +1873,7 @@ class TestOps(unittest.TestCase): def test_padding_add(self): helper_test_op([(64,64), (60,60)], lambda x,w: x+torch.nn.functional.pad(w, (2,2,2,2)), - lambda x,w: x+w.pad2d((2,2,2,2))) + lambda x,w: x+w.pad((2,2,2,2))) def test_dilated_conv2d(self): bs = 4 @@ -1844,34 +1885,40 @@ class TestOps(unittest.TestCase): lambda x,w: torch.nn.functional.conv2d(x,w,dilation=dilation).relu(), lambda x,w: Tensor.conv2d(x,w,dilation=dilation).relu()) - def test_maxpool2d_simple(self): + def test_max_pool2d_simple(self): ksz = (2,2) helper_test_op([(1,1,2,3)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) - def test_maxpool2d(self): + def test_max_pool2d(self): for ksz in [(2,2), (3,3), 2, 3, (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz), lambda x: Tensor.max_pool2d(x, kernel_size=ksz)) - def test_maxpool2d_padding(self): + def test_max_pool2d_padding(self): for ksz in [(2,2), (3,3), 2, 3, (3,2)]: with self.subTest(kernel_size=ksz): helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=ksz, padding=1), lambda x: Tensor.max_pool2d(x, kernel_size=ksz, padding=1)) - def test_maxpool2d_bigger_stride(self): + def test_max_pool2d_padding_int(self): + ksz = (2,2) + helper_test_op([(32,2,110,28)], + lambda x: torch.nn.functional.max_pool2d(x.int(), kernel_size=ksz, padding=1), + lambda x: Tensor.max_pool2d(x.int(), kernel_size=ksz, padding=1), forward_only=True) + + def test_max_pool2d_bigger_stride(self): for stride in [(2,3), (3,2), 2, 3]: with self.subTest(stride=stride): helper_test_op([(32,2,110,28)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(2,2), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride)) - def test_maxpool2d_bigger_stride_dilation(self): + def test_max_pool2d_bigger_stride_dilation(self): for stride, dilation in zip([(2,3), (3,2), 2, 3, 4], [(3,2), (2,3), 2, 3, 6]): with self.subTest(stride=stride): helper_test_op([(32,2,110,28)], @@ -1879,25 +1926,25 @@ class TestOps(unittest.TestCase): lambda x: Tensor.max_pool2d(x, kernel_size=(2,2), stride=stride, dilation=dilation)) @unittest.skipIf( Device.DEFAULT in {"CUDA", "NV"}, "CUDA fails on this") - def test_maxpool2d_unit_stride(self): + def test_max_pool2d_unit_stride(self): helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=1), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=1)) - def test_maxpool2d_smaller_stride(self): + def test_max_pool2d_smaller_stride(self): for stride in [(2,3), (3,2), 2, 3]: with self.subTest(stride=stride): helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), stride=stride), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), stride=stride)) - def test_maxpool2d_dilation(self): + def test_max_pool2d_dilation(self): for dilation in [(2, 3), (3, 2), 2, 3]: helper_test_op([(8, 2, 17, 14)], lambda x: torch.nn.functional.max_pool2d(x, kernel_size=(5,5), dilation=dilation), lambda x: Tensor.max_pool2d(x, kernel_size=(5,5), dilation=dilation)) - def test_avgpool2d(self): + def test_avg_pool2d(self): shape = (32,2,111,28) for ksz in [(2,2), (3,3), (3,2), (5,5), (5,1)]: with self.subTest(kernel_size=ksz): @@ -1907,12 +1954,12 @@ class TestOps(unittest.TestCase): # TODO fix edge case @unittest.expectedFailure - def test_avgpool2d_failure(self): + def test_avg_pool2d_failure(self): helper_test_op([(1,1,8,8)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), lambda x: Tensor.avg_pool2d(x, kernel_size=(1,2), padding=(0,1), stride=(5,1)), rtol=1e-5) - def test_avgpool2d_padding(self): + def test_avg_pool2d_padding(self): shape = (32,2,111,28) for ksz in [(2,2), (3,3), 2, 3, (3,2)]: with self.subTest(kernel_size=ksz): @@ -1920,7 +1967,7 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1), rtol=1e-5) - def test_avgpool2d_padding_not_counted(self): + def test_avg_pool2d_padding_not_counted(self): shape = (32,2,111,28) for ksz in [(2,2), (3,3), 2, 3, (3,2)]: with self.subTest(kernel_size=ksz): @@ -1928,7 +1975,7 @@ class TestOps(unittest.TestCase): lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), lambda x: Tensor.avg_pool2d(x, kernel_size=ksz, padding=1, count_include_pad=False), rtol=1e-5) - def test_global_avgpool2d(self): + def test_global_avg_pool2d(self): helper_test_op([(32,2,111,28)], lambda x: torch.nn.functional.avg_pool2d(x, kernel_size=(111,28)), lambda x: Tensor.avg_pool2d(x, kernel_size=(111,28)), rtol=1e-5) @@ -2032,6 +2079,8 @@ class TestOps(unittest.TestCase): helper_test_op([(3, 3)], lambda x: x.repeat_interleave(6)) helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 1)) helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, 0)) + helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -1)) + helper_test_op([(3, 3)], lambda x: x.repeat_interleave(2, -2)) def test_simple_repeat(self): repeats = [3, 3, 4] diff --git a/test/test_renderer_failures.py b/test/test_renderer_failures.py index d997cb3f37..5fa5a4b73b 100644 --- a/test/test_renderer_failures.py +++ b/test/test_renderer_failures.py @@ -8,7 +8,7 @@ from tinygrad.dtype import dtypes from tinygrad.engine.realize import CompiledRunner from tinygrad.helpers import dedup, flatten, prod from tinygrad.renderer.cstyle import CStyleLanguage -from tinygrad.ops import BinaryOps, UOp, Ops +from tinygrad.ops import UOp, Ops from tinygrad.renderer import Program from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.engine.lazy import LazyBuffer @@ -34,7 +34,7 @@ class TestCStyleFailures(unittest.TestCase): b = UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), 1) idx = UOp.const(dtypes.int, 0) ld = UOp(Ops.LOAD, dtypes.int, (b.index(idx),)) - alu = ld.alu(BinaryOps.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) + alu = ld.alu(Ops.MAX, UOp.const(dtypes.int, dtypes.min(dtypes.int)+1)) store = UOp.store(a.index(idx), alu) sink = UOp(Ops.SINK, dtypes.void, (store,)) uops = linearize_uop(full_graph_rewrite(sink, Device[Device.DEFAULT].renderer)) diff --git a/test/test_schedule.py b/test/test_schedule.py index 7b8d719904..2925e25d1a 100644 --- a/test/test_schedule.py +++ b/test/test_schedule.py @@ -13,13 +13,12 @@ from tinygrad.device import is_dtype_supported from tinygrad.dtype import DType from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View -from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, Ops, graph_rewrite, track_rewrites +from tinygrad.ops import UOp, Ops, graph_rewrite, track_rewrites from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod, Context from tinygrad.codegen.kernel import Kernel, verify_ast from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, view_right, view_left -from tinygrad.engine.realize import CompiledRunner, run_schedule +from tinygrad.engine.realize import CompiledRunner, get_runner, run_schedule from tinygrad.engine.lazy import LazyBuffer, view_supported_devices -from test.helpers import ast_const, timeit from extra.models.llama import precompute_freqs_cis class KernelCountException(Exception): pass @@ -41,9 +40,7 @@ def check_schedule(t:Union[Tensor, List[Tensor], LazyBuffer], allowed:int, to_pr # test the (sink) ops linearize for s in sched: if s.ast.op is not Ops.SINK: continue - l = Kernel(s.ast) - l.hand_coded_optimizations() - l.to_program() + get_runner(s.bufs[0].device, s.ast) return sched def _realize_weights(m): @@ -311,7 +308,6 @@ class TestSchedule(unittest.TestCase): img = Tensor.empty(64,64) x = (img.sum(0) + img.sum(1)) out = x.relu() - del x # is 3 without this check_schedule(out, 2) #@unittest.skip("failing in old lazy") @@ -335,6 +331,7 @@ class TestSchedule(unittest.TestCase): d = (a+b).reshape(16,1) check_schedule(d, 0, [c]) + @unittest.skipUnless(is_dtype_supported(dtypes.half), "need half") def test_multi_permute_should_collapse(self): a = Tensor.empty(4,4,4,4) b = Tensor.empty(16) @@ -1045,7 +1042,7 @@ class TestSchedule(unittest.TestCase): b = r.sum(0) * 4 c = r.sum(1) * 2 schedule = check_schedule([b, c], 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_simple_chase(self): @@ -1056,7 +1053,7 @@ class TestSchedule(unittest.TestCase): c = r.sum(1) + 12 np_r = (a.numpy() + (a.numpy().sum(0) + 6)).sum(0) * 2 # schedule = check_schedule([b,c], 3) - # self.assertIs(schedule[0].ast[0].src[0].arg, BinaryOps.MUL) + # self.assertIs(schedule[0].ast[0].src[0].arg, Ops.MUL) schedule = check_schedule([b,c], 4) run_schedule(schedule) np.testing.assert_allclose(b.numpy(), np_r.sum(0) + 8, atol=1e-4, rtol=1e-4) @@ -1069,7 +1066,7 @@ class TestSchedule(unittest.TestCase): d = r.T * 4 e = r * d schedule = check_schedule([d, e], 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_push_permute_chase(self): @@ -1080,7 +1077,7 @@ class TestSchedule(unittest.TestCase): d = r.T * 4 e = r * (d + a).sum(2) schedule = check_schedule([d, e], 3) # make sure it doesn't fuse - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) run_schedule(schedule) np.testing.assert_allclose(d.numpy(), (a.numpy().sum(2) + b.numpy()).T * 4, atol=1e-4, rtol=1e-4) np.testing.assert_allclose(e.numpy(), (a.numpy().sum(2) + b.numpy()) * (d.numpy() + a.numpy()).sum(2), atol=1e-4, rtol=1e-4) @@ -1092,7 +1089,7 @@ class TestSchedule(unittest.TestCase): r = a.sum(1) + c d = r[:4] * b schedule = check_schedule(d, 2) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) # multireduce spec def test_multireduce_push_shrink_chase(self): @@ -1105,7 +1102,7 @@ class TestSchedule(unittest.TestCase): out = r[:4] * b + d.sum(1)[:4] # schedule = check_schedule(out, 2) schedule = check_schedule(out, 3) - self.assertIs(schedule[0].ast.src[0].src[2].op, BinaryOps.ADD) + self.assertIs(schedule[0].ast.src[0].src[2].op, Ops.ADD) run_schedule(schedule) np.testing.assert_allclose(out.numpy(), (a.numpy().sum(1) + c.numpy())[:4] * b.numpy() + d.numpy().sum(1)[:4], atol=1e-4, rtol=1e-4) @@ -1290,16 +1287,16 @@ class TestSchedule(unittest.TestCase): @unittest.skipIf(Device.DEFAULT not in view_supported_devices, "subbuffer not supported") def test_bitcast_subbufer(self): x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) - a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=True) + a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=True) b = x.cast(dtypes.int32, True, allow_buffer_view=True) - b = a.alu(BinaryOps.ADD, b) + b = a.alu(Ops.ADD, b) check_schedule(b, 2) # this should fuse when it makes sense def test_bitcast_disable_subbufer(self): x = cast(LazyBuffer, Tensor.empty(1, dtype=dtypes.float32).realize().lazydata) - a = x.alu(UnaryOps.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) + a = x.alu(Ops.EXP2).cast(dtypes.int32, True, allow_buffer_view=False) b = x.cast(dtypes.int32, True, allow_buffer_view=False) - b = a.alu(BinaryOps.ADD, b) + b = a.alu(Ops.ADD, b) check_schedule(b, 1) def test_reduceop_reshape_dont_push(self): @@ -1533,7 +1530,7 @@ class TestIndexing(unittest.TestCase): def test_arange_view_op(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).contiguous() assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.BUFFER_VIEW) + self.assertIs(a.lazydata.base.op, Ops.BUFFER_VIEW) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[4, 5]]) @@ -1541,7 +1538,7 @@ class TestIndexing(unittest.TestCase): def test_arange_shrink_copy(self): a = Tensor.arange(12).reshape(4, 3).shrink(((1, 2), (1, 3))).to("CLANG") assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.COPY) + self.assertIs(a.lazydata.base.op, Ops.COPY) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[4, 5]]) @@ -1549,8 +1546,8 @@ class TestIndexing(unittest.TestCase): def test_arange_expand_copy(self): a = Tensor.arange(4).reshape(2, 2, 1).expand(2, 2, 2).to("CLANG") assert isinstance(a.lazydata, LazyBuffer) - self.assertIs(a.lazydata.base.op, MetaOps.COPY) - self.assertIs(a.lazydata.base.srcs[0].base.op, BinaryOps.ADD) + self.assertIs(a.lazydata.base.op, Ops.COPY) + self.assertIs(a.lazydata.base.srcs[0].base.op, Ops.ADD) self.check_schedule(a, 1) np.testing.assert_equal(a.numpy(), [[[0, 0], [1, 1]], [[2, 2], [3, 3]]]) @@ -1635,16 +1632,6 @@ class TestIndexing(unittest.TestCase): self.assertEqual(new_uop.st, ShapeTracker.from_shape((4,)).reshape((4, 1))) self.assertEqual(swizzle_cnt(new_uop), 0) - def test_strongly_connected_DAG(self): - val = 1.0 - a = Tensor(val).realize() - def f(a): - for _ in range(24): a = Tensor.stack(a, a)[0] - return a.item() - r, et = timeit(f, a) - self.assertEqual(r, val) - self.assertLess(et, 1600) - def test_no_rewrite_elementwise(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) @@ -1656,9 +1643,9 @@ class TestIndexing(unittest.TestCase): def test_simple_store_reshape(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) - r = r + ast_const(dtypes.int, 2, ()) + r = r + 2 sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) rsink = graph_rewrite(sink, view_right) # this AST first needs to swizzle, but it doesn't have implicit movementops @@ -1668,50 +1655,12 @@ class TestIndexing(unittest.TestCase): def test_no_reshape_reduceop(self): bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0, 1))) sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape((1, 1)).to_uop(), r)),)) rsink = graph_rewrite(sink, view_right) verify_ast(sink) self.assertEqual(sink.key, rsink.key) - def test_reshape_many(self): - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) - r = UOp(Ops.VIEW, dtypes.int, (r,), ShapeTracker.from_shape(())) - for _ in range(24): r = r + ast_const(dtypes.int, 2, ()) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) - rsink, et = timeit(graph_rewrite, sink, view_right) - # this AST first needs to swizzle, but it doesn't have implicit movementops - with self.assertRaisesRegex(AssertionError, "swizzle"): verify_ast(sink) - verify_ast(rsink) - self.assertLessEqual(et, 1e3) - - @unittest.skip("test is flaky") - def test_complexity(self): - SZ = 30 if getenv("BIG") else 10 - sizes = [10*(i+1) for i in range(SZ)] - tms: List[float] = [] - for sz in sizes: - bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] - ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((32, 32)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0, 1))) - for _ in range(sz): r = r + ast_const(dtypes.int, 2, ()) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), r)),)) - rsink, et = timeit(graph_rewrite, sink, view_right) - with self.assertRaisesRegex(AssertionError, "implicit reshape"): verify_ast(sink) - verify_ast(rsink) - tms.append(et) - if getenv("GRAPH_TIMING"): - import plotly.express as px - fig = px.line(x=sizes, y=tms, title="graph_rewrite time as ast grows") - fig.update_layout(paper_bgcolor="black", plot_bgcolor="black", font={"color":"white"}, - yaxis={"gridcolor":"rgba(255, 255, 255, 0.3)"}, xaxis={"gridcolor":"rgba(255, 255, 255, 0.3)"}) - fig.show() - change = tms[-1] / tms[0] - assert change <= SZ, f"bad complexity, time increased by {change:4.2f}x while input only grew {SZ}x" - - @track_rewrites(named=True) def swizzle_rewrite(u:UOp) -> UOp: return graph_rewrite(graph_rewrite(u, view_left), view_right) @@ -1749,10 +1698,9 @@ class TestSwizzle(unittest.TestCase): # LazyBuffer to pre-rewrite AST bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(2)] ld = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) - r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (BinaryOps.ADD, (0,))) + r = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld,), (Ops.ADD, (0,))) swizzle_r = UOp(Ops.VIEW, dtypes.int, (r,), unwrap(r.st).reshape(())) - const = ast_const(dtypes.int, 1, ()) - alu = swizzle_r+const + alu = swizzle_r+1 sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu,),),)) # graph rewrite sink = swizzle_rewrite(sink) @@ -1772,11 +1720,11 @@ class TestSwizzle(unittest.TestCase): # LazyBuffer to pre-rewrite AST bufs = [UOp(Ops.DEFINE_GLOBAL, dtypes.int.ptr(), (), i) for i in range(3)] ld1 = UOp(Ops.LOAD, dtypes.int, (bufs[1], ShapeTracker.from_shape((4,)).to_uop())) - r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (BinaryOps.ADD, (0,))) + r1 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld1,), (Ops.ADD, (0,))) ld2 = UOp(Ops.LOAD, dtypes.int, (bufs[2], ShapeTracker.from_shape((4,)).to_uop())) - r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (BinaryOps.ADD, (0,))) + r2 = UOp(Ops.REDUCE_AXIS, dtypes.int, (ld2,), (Ops.ADD, (0,))) alu = UOp(Ops.VIEW, r1.dtype, (r1,), ShapeTracker.from_shape(()))+UOp(Ops.VIEW, r2.dtype, (r2,), ShapeTracker.from_shape(())) - sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+ast_const(dtypes.int, 2, ()),),),)) # noqa: E501 + sink = UOp(Ops.SINK, dtypes.void, (UOp(Ops.STORE, dtypes.void, (bufs[0], ShapeTracker.from_shape(()).to_uop(), alu+2,),),)) # noqa: E501 # graph rewrite sink = swizzle_rewrite(sink) # verify output @@ -1788,7 +1736,7 @@ class TestSwizzle(unittest.TestCase): def test_swizzle_rewrite_alt(self): swizzle = UOp(Ops.VIEW, dtypes.float, arg=ShapeTracker(views=(View(shape=(2, 3, 3, 65, 3, 65), strides=(103788, 34596, 3, 558, 1, 9), offset=0, mask=((0, 2), (0, 3), (0, 3), (0, 62), (0, 3), (0, 62)), contiguous=False), View(shape=(2, 3, 256, 256), strides=(114075, 38025, 195, 1), offset=0, mask=((0, 2), (0, 3), (0, 195), (0, 195)), contiguous=False), View(shape=(1, 2, 1, 3, 4, 64, 4, 64), strides=(0, 196608, 0, 65536, 16384, 256, 64, 1), offset=0, mask=None, contiguous=True))), src=( # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.ADD, (3,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.ADD, (3,)), src=( UOp(Ops.LOAD, dtypes.float, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=1, src=()), UOp(Ops.VIEW, dtypes.void, arg=(ld_st:=ShapeTracker(views=(View(shape=(2, 1, 3, 16, 62, 62, 3, 3), strides=(0, 0, 9, 27, 0, 0, 3, 1), offset=0, mask=None, contiguous=False),))), src=()),)),)),)) # noqa: E501 diff --git a/test/test_search.py b/test/test_search.py index b91d003e54..a9e6d4e6f0 100644 --- a/test/test_search.py +++ b/test/test_search.py @@ -3,7 +3,7 @@ import unittest from test.helpers import ast_const from tinygrad.codegen.kernel import Opt, OptOps from tinygrad.codegen.kernel import Kernel -from tinygrad.ops import UOp, Ops, BinaryOps +from tinygrad.ops import UOp, Ops from tinygrad.engine.schedule import create_schedule from tinygrad.engine.search import time_linearizer, bufs_from_lin, actions, beam_search from tinygrad.device import Device, Buffer @@ -107,7 +107,7 @@ class TestBEAM(unittest.TestCase): UOp(Ops.STORE, dtypes.void, arg=None, src=( UOp(Ops.DEFINE_GLOBAL, dtypes.float.ptr(), arg=0, src=()), UOp(Ops.VIEW, dtypes.void, arg=ShapeTracker(views=(View(shape=(1, 1, 256), strides=(0, 0, 1), offset=0, mask=None, contiguous=True),)), src=()), # noqa: E501 - UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(BinaryOps.MAX, (1,)), src=( + UOp(Ops.REDUCE_AXIS, dtypes.float, arg=(Ops.MAX, (1,)), src=( UOp(Ops.MUL, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( UOp(Ops.ADD, dtypes.float, arg=None, src=( diff --git a/test/test_tensor_variable.py b/test/test_tensor_variable.py index 4d3bc83765..eb7ddeb386 100644 --- a/test/test_tensor_variable.py +++ b/test/test_tensor_variable.py @@ -55,10 +55,10 @@ class TestTensorVariable(unittest.TestCase): ret = t.var().item() assert ret == 0 - def test_symbolic_pad2d(self): + def test_symbolic_pad(self): vv = Variable("a", 1, 10).bind(2) t = Tensor.ones(2, 2).contiguous() - t = t.pad2d([vv, vv, vv, vv]).mean() + t = t.pad([vv, vv, vv, vv]).mean() ones = 4 zeros = 6+6+4+4+6+6 self.assertAlmostEqual(t.item(), ones/(ones+zeros)) diff --git a/test/test_uop_graph.py b/test/test_uop_graph.py index 502c0df635..85396d7b1a 100644 --- a/test/test_uop_graph.py +++ b/test/test_uop_graph.py @@ -2,8 +2,7 @@ from typing import List import unittest, time from tinygrad import dtypes, Device from tinygrad.helpers import DEBUG -from tinygrad.ops import BinaryOps, Ops, UOp, KernelInfo -from tinygrad.ops import UPat, PatternMatcher +from tinygrad.ops import Ops, UOp, KernelInfo, UPat, PatternMatcher from tinygrad.renderer import Renderer from tinygrad.codegen.lowerer import rewrite_shapetracker_with_index from tinygrad.codegen.uopgraph import full_graph_rewrite, graph_rewrite, expander, sym @@ -541,7 +540,7 @@ class TestExpander(unittest.TestCase): @unittest.skip("no longer supported") def test_reduce_known_axis(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (3*e1,e1), Ops.ADD) sink = expander_rewrite(sink) assert sink.op is Ops.CONST self.assertEqual(sink.arg, 3*(0+1+2+3)) @@ -549,7 +548,7 @@ class TestExpander(unittest.TestCase): @unittest.skip("no longer supported") def test_reduce_const(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (UOp.const(dtypes.int, 3), e1), Ops.ADD) sink = expander_rewrite(sink) assert sink.op is Ops.CONST self.assertEqual(sink.arg, 3*4) @@ -590,7 +589,7 @@ class TestExpander(unittest.TestCase): def test_reduce_different_axis(self): e1 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((1,4),)) e2 = UOp(Ops.EXPAND, dtypes.int, tuple(UOp.const(dtypes.int, x) for x in range(4)), ((2,4),)) - sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), BinaryOps.ADD) + sink = UOp(Ops.REDUCE, dtypes.int, (e1,e2), Ops.ADD) sink = expander_rewrite(sink) print(sink) diff --git a/test/test_uops.py b/test/test_uops.py index 13dbdb7d2b..abb9eae35e 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -6,7 +6,7 @@ from tinygrad.tensor import Tensor, _to_np_dtype from tinygrad.helpers import CI, DEBUG, getenv, Context from tinygrad.dtype import dtypes, DType from tinygrad.device import Buffer, Device -from tinygrad.ops import Ops, UOp, UPat, UnaryOps, BinaryOps, TernaryOps, KernelInfo, exec_alu, spec # noqa F401 +from tinygrad.ops import Ops, UOp, UPat, KernelInfo, exec_alu, spec # noqa F401 from tinygrad.renderer import Program from tinygrad.engine.schedule import create_schedule, to_si from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel @@ -29,7 +29,7 @@ def uop(uops:List[UOp], uop:Ops, dtype:Optional[DType], src:Tuple[UOp, ...], arg def _test_single_value(vals, op, dts): uops = [] - output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] + output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) buf_loads = [uop(uops, Ops.DEFINE_GLOBAL, dtype.ptr(), (), i+1) for i,dtype in enumerate(dts)] loads = (uop(uops, Ops.LOAD, dtype, [buf_loads[i].index(uop(uops, Ops.CONST, dtypes.int32, (), 0))]) for i, dtype in enumerate(dts)) @@ -45,7 +45,7 @@ def _test_single_value(vals, op, dts): def _test_single_value_const(vals, op, dts): uops = [] - output_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else dts[-1] + output_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else dts[-1] buf_store = uop(uops, Ops.DEFINE_GLOBAL, output_dtype.ptr(), (), 0) loads = (uop(uops, Ops.CONST, dtype, [], a) for a,dtype in zip(vals, dts)) alu = uop(uops, op, output_dtype, loads) @@ -103,49 +103,49 @@ class TestUOps(unittest.TestCase): class TestFloatUOps(TestUOps): @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_exp2(self): self._test_uop_fxn(UnaryOps.EXP2, lambda a: np.exp2(a)) + def test_exp2(self): self._test_uop_fxn(Ops.EXP2, lambda a: np.exp2(a)) @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_log2(self): self._test_uop_fxn(UnaryOps.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) + def test_log2(self): self._test_uop_fxn(Ops.LOG2, lambda a: math.log2(a) if a > 0 else float('-inf' if a==0 else 'nan')) @unittest.skipIf(Device.DEFAULT == "CLANG", 'not supported as uop') - def test_sin(self): self._test_uop_fxn(UnaryOps.SIN, lambda a: math.sin(a)) - def test_recip(self): self._test_uop_fxn(UnaryOps.RECIP, lambda a: 1/a if a != 0 else float('inf')) - def test_sqrt(self): self._test_uop_fxn(UnaryOps.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) + def test_sin(self): self._test_uop_fxn(Ops.SIN, lambda a: math.sin(a)) + def test_recip(self): self._test_uop_fxn(Ops.RECIP, lambda a: 1/a if a != 0 else float('inf')) + def test_sqrt(self): self._test_uop_fxn(Ops.SQRT, lambda a: math.sqrt(a) if a >= 0 else float('nan')) - def test_add(self): self._test_bop_fxn(BinaryOps.ADD, lambda a,b: a+b) - def test_mul(self): self._test_bop_fxn(BinaryOps.MUL, lambda a,b: a*b) - def test_max(self): self._test_bop_fxn(BinaryOps.MAX, lambda a,b: max(a,b)) - def test_cmplt(self): self._test_bop_fxn(BinaryOps.CMPLT, lambda a,b: a>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) + def test_shr_int32(self): self._test_bop_fxn(Ops.SHR, lambda a,b: int(a)>>int(b), (dtypes.int32, dtypes.int32), no_b_neg=True) @unittest.skipUnless(getenv("PTX"), "only ptx uses bitshifts") - def test_shl_int32(self): self._test_bop_fxn(BinaryOps.SHL, lambda a,b: int(a)< ALU/CONST rewrite is now instant """ - matcher = PatternMatcher([(UPat(UOps.ALU, name="x", src=(UPat(UOps.CONST), UPat(UOps.ALU))), lambda x: x)]) - c4 = UOp(UOps.ALU, dtypes.float, (c1,c3), BinaryOps.ADD) - c5 = UOp(UOps.ALU, dtypes.float, (c3,c1), BinaryOps.ADD) + matcher = PatternMatcher([(UPat(GroupOp.ALU, name="x", src=(UPat(Ops.CONST), UPat(GroupOp.ALU))), lambda x: x)]) + c4 = UOp(Ops.ADD, dtypes.float, (c1,c3)) + c5 = UOp(Ops.ADD, dtypes.float, (c3,c1)) self.assertEqual(matcher.rewrite(c3), None) self.assertEqual(matcher.rewrite(c4), c4) self.assertEqual(matcher.rewrite(c5), None) diff --git a/test/unit/test_shapetracker_math.py b/test/unit/test_shapetracker_math.py index 96af3d8288..5bfcd24685 100644 --- a/test/unit/test_shapetracker_math.py +++ b/test/unit/test_shapetracker_math.py @@ -109,7 +109,6 @@ class TestShapeTrackerAddVariable(unittest.TestCase): vm2 = View(shape=(var_i, var_j, 3), strides=(var_j*3, 3, 1), offset=0, mask=None, contiguous=True) ShapeTracker((vm1,)) + ShapeTracker((vm2,)) - @unittest.skip("two vars not supported") def test_merge_symbolic_views_2(self): var_i = Variable('i', 1, 10) var_j = Variable('j', 1, 10) diff --git a/test/unit/test_uop_symbolic.py b/test/unit/test_uop_symbolic.py index 6b88f15216..fba3f2f9d6 100644 --- a/test/unit/test_uop_symbolic.py +++ b/test/unit/test_uop_symbolic.py @@ -455,6 +455,13 @@ class TestSymbolic(unittest.TestCase): self.helper_test_variable((a+b+c*2).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) self.helper_test_variable((a+b*2+c*4).lt(1).ne(True), 0, 1, ("((((a+b)+c)<1)!=True)", "(((c+(a+b))<1)!=True)")) + def test_where_removal(self): + cond = Variable("a", 0, 3).lt(2) + u1, u0 = cond.ufix(1), cond.ufix(0) + self.helper_test_variable(cond, 0, 1, "(a<2)") + self.helper_test_variable(cond.where(u1, u0), 0, 1, "(a<2)") + self.helper_test_variable(cond.where(u1, u0).where(u1, u0), 0, 1, "(a<2)") + class TestSymbolicNumeric(unittest.TestCase): def helper_test_numeric(self, f): MIN, MAX = 0, 10 diff --git a/tinygrad/codegen/kernel.py b/tinygrad/codegen/kernel.py index 547d352979..6565dc310a 100644 --- a/tinygrad/codegen/kernel.py +++ b/tinygrad/codegen/kernel.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict, Callable, Sequence from enum import Enum, auto -from tinygrad.ops import GroupOp, BinaryOps, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \ +from tinygrad.ops import GroupOp, KernelInfo, UOp, Ops, PatternMatcher, can_pad, print_uops, type_verify, resolve, Variable, sint, \ graph_rewrite, track_rewrites from tinygrad.device import Device from tinygrad.renderer import Renderer, TensorCore, Program @@ -276,7 +276,7 @@ class Kernel: if has_cast and not (reduceop.src[0].op is Ops.CAST and reduceop.src[0].dtype == tc.dtype_out): return None mul_op = reduceop.src[0].src[0] if has_cast else reduceop.src[0] - if mul_op.op is not BinaryOps.MUL: return None + if mul_op.op is not Ops.MUL: return None def buf_index(src:UOp) -> Optional[int]: # TODO: apply tc even if the sources are not from LOAD @@ -303,7 +303,7 @@ class Kernel: return TensorCoreOptions(axes=(s0, s1, s2), axes_exist=(True, True), axis_pads=axis_pads) def _apply_tc_opt(self, use_tensor_cores:int, axis:int, opt_level:int) -> bool: - if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD: + if use_tensor_cores and self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD: for tc in self.opts.tensor_cores: tensor_core_opts = [self._create_tc_opts(reduceop, tc, axis, opt_level) for reduceop in self.reduceops] # can only fuse reduces with the same tc options @@ -338,7 +338,7 @@ class Kernel: 2: apply tensor core shape but don't use UOp.WMMA extra_opts -- additional Opt's to apply after the tensor core instead of the hand-coded additional Opt's (default None) tc_opt -- controls which kinds of kernels may be eligible for tensor cores application (default 2 during BEAM, 0 otherwise) - 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into BinaryOps.MUL + 0: applies to only kernels with a single reduce axis and direct UOps.LOAD into Ops.MUL 1: allows kernels with multiple reduce axes and also multiplication of UOps.CAST'd buffers 2: allows kernels with M, N, K axes that are not multiples of the tensor core dimensions by applying padding those axes as needed """ @@ -441,7 +441,7 @@ class Kernel: check(not self.vars, "does not work with symbolic shape") check(axis < self.first_upcast, "cannot pad upcasted") # ok to pad SUM if all parent ALU ops have f(0) = 0 - if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is BinaryOps.ADD and can_pad(r), f"cannot pad {r}") + if (r:=self.reduceop) is not None and self.first_reduce <= axis: check(r.arg[0] is Ops.ADD and can_pad(r), f"cannot pad {r}") padded = False for i,st in enumerate(self.sts): if (s:=st.shape[axis]) == 1: continue # reduced @@ -470,8 +470,8 @@ class Kernel: # should use matvec - TODO: adjust/tune based on the wide vs tall/large vs small mat MV_BLOCKSIZE, MV_THREADS_PER_ROW, MV_ROWS_PER_THREAD = getenv("MV_BLOCKSIZE", 4), getenv("MV_THREADS_PER_ROW", 8), getenv("MV_ROWS_PER_THREAD", 4) if self.opts.has_local and getenv("MV",1) != 0 and (MV_BLOCKSIZE > 1 or MV_THREADS_PER_ROW > 1 or MV_ROWS_PER_THREAD > 1) and \ - self.reduceop is not None and self.reduceop.arg[0] is BinaryOps.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \ - (mulop:=self.reduceop.src[0]).op is BinaryOps.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: + self.reduceop is not None and self.reduceop.arg[0] is Ops.ADD and len(self.full_shape) >= 2 and self.opts.has_shared and \ + (mulop:=self.reduceop.src[0]).op is Ops.MUL and mulop.src[0].op is Ops.LOAD and mulop.src[1].op is Ops.LOAD: st0, st1 = self.sts[self.bufs.index(mulop.src[0])], self.sts[self.bufs.index(mulop.src[1])] strides0, strides1 = st0.real_strides(), st1.real_strides() def has_expanded_axis(shape, strides): return any(resolve(s > 1) and not resolve(st != 0) for s,st in zip(shape,strides)) @@ -663,7 +663,7 @@ class Kernel: # for TC=2, we can't do the shapetracker fixup srcs = [fixup_ast(rsrc.src[0]), fixup_ast(rsrc.src[1])] # MUL/SUM instead of WMMA - ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(BinaryOps.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1])) + ret = UOp(Ops.REDUCE_AXIS, tc.dtype_out, (srcs[0].alu(Ops.MUL, srcs[1]).cast(tc.dtype_out),), (alu_op, wmma_arg[-1])) else: # real WMMA, use CONTRACT/EXPAND to get the vectorization right wmma_upcast_axes = wmma_arg[-2] diff --git a/tinygrad/codegen/transcendental.py b/tinygrad/codegen/transcendental.py index 42cb2abcf5..4f9aeff51a 100644 --- a/tinygrad/codegen/transcendental.py +++ b/tinygrad/codegen/transcendental.py @@ -169,7 +169,7 @@ def sin_poly_large(d:UOp, q:UOp) -> UOp: def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.SIN. + Implements a 1.0 ULP approximation for Ops.SIN. - fast=True assumes x <= switch_over. - switch_over is the threshold for switching to payne_hanek_reduction. """ @@ -192,7 +192,7 @@ def xsin(d:UOp, fast:bool=False, switch_over:float=30.0) -> UOp: def xexp2(d:UOp) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.EXP2 + Implements a 1.0 ULP approximation for Ops.EXP2 - Paper: https://arxiv.org/pdf/2001.09258 """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES @@ -218,7 +218,7 @@ def xexp2(d:UOp) -> UOp: def xlog2(d:UOp) -> UOp: """ - Implements a 1.0 ULP approximation for UnaryOps.LOG2 + Implements a 1.0 ULP approximation for Ops.LOG2 Paper: https://arxiv.org/pdf/2001.09258 5.5 """ assert d.dtype in TRANSCENDENTAL_SUPPORTED_DTYPES diff --git a/tinygrad/codegen/uopgraph.py b/tinygrad/codegen/uopgraph.py index 4113a0169c..610c36e5ad 100644 --- a/tinygrad/codegen/uopgraph.py +++ b/tinygrad/codegen/uopgraph.py @@ -3,7 +3,7 @@ from typing import Optional, Tuple, Dict, List, TYPE_CHECKING, Any, DefaultDict, import functools, itertools, operator from collections import defaultdict from tinygrad.dtype import dtypes, ImageDType, PtrDType -from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple +from tinygrad.ops import UOp, Ops, UPat, PatternMatcher, symbolic_flat, symbolic_simple from tinygrad.ops import graph_rewrite, split_uop, uop_given_valid, parse_valid, is_increasing, simplify_valid, GroupOp from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, partition, all_same from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, TRANSCENDENTAL_SUPPORTED_DTYPES @@ -23,7 +23,7 @@ def fold_expanded(ex, buf): for i,s in enumerate(new_srcs): idx = s.src[0].src[1] if s.dtype.count != 1 or (is_image and idx.dtype.count == 2): continue - if idx.op is BinaryOps.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg + if idx.op is Ops.ADD and idx.src[1].op is Ops.CONST: root_src, arg = idx.src[0], idx.src[1].arg elif idx.op is Ops.CONST: root_src, arg = "CONST", idx.arg else: root_src, arg = idx, 0 # add gates for gated @@ -92,12 +92,12 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: # can drop valid if idx is out of bound when valid is False drop_stmt = [] - for stmt in split_uop(valid, BinaryOps.AND): + for stmt in split_uop(valid, Ops.AND): X, is_upper_bound, c = parse_valid(stmt) # for X0 + X1 + ... >= 1, check if it's out of bound when Xi = 0 for all i - if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, BinaryOps.ADD)): - testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, BinaryOps.ADD), idx) + if not is_upper_bound and c == 1 and all(u.op in GroupOp.Irreducible and u.vmin == 0 for u in split_uop(X, Ops.ADD)): + testidx = functools.reduce(lambda nowidx,u: nowidx.substitute({u:u.const_like(0)}), split_uop(X, Ops.ADD), idx) testidx = testidx.simplify() if testidx.gep(0).vmax < 0 or testidx.gep(1).vmax < 0: drop_stmt.append(stmt) @@ -114,7 +114,7 @@ def simplify_valid_load(buf:UOp, start_idx:UOp, valid:UOp) -> Optional[UOp]: break if not drop_stmt and idx is start_idx: return None - new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, BinaryOps.AND) if s not in drop_stmt]) else None + new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in split_uop(valid, Ops.AND) if s not in drop_stmt]) else None return buf.index(idx, new_valid) # ***** optional patterns ***** @@ -123,23 +123,23 @@ powers_of_two = {2**i:i for i in range(64)} @functools.lru_cache(None) def get_late_rewrite_patterns(ops, force_transcendental=False): pat: List[Tuple[UPat, Callable]] = [(UPat(op, dtype=TRANSCENDENTAL_SUPPORTED_DTYPES, src=(UPat.var("d"),)), f) for op,f in \ - ((UnaryOps.EXP2, xexp2), (UnaryOps.LOG2, xlog2), (UnaryOps.SIN, xsin)) if op not in ops or force_transcendental] + ((Ops.EXP2, xexp2), (Ops.LOG2, xlog2), (Ops.SIN, xsin)) if op not in ops or force_transcendental] # rewrite MOD to AND (which should always be supported, but not for generic in tests) - if BinaryOps.AND in ops: + if Ops.AND in ops: pat += [(UPat(Ops.MOD, src=(UPat.var('base'), UPat.cvar("const"))), lambda base,const: base & (const.arg-1) if const.arg in powers_of_two else None)] # rewrite MUL/IDIV to SHL+SHR - if BinaryOps.SHL in ops and BinaryOps.SHR in ops: + if Ops.SHL in ops and Ops.SHR in ops: pat += [ (UPat(Ops.MUL, dtype=dtypes.ints, src=[UPat.cvar("const"), UPat.var("mul")]), lambda mul, const: mul << powers_of_two[const.arg] if const.arg in powers_of_two else None), # (x * (2**y)) -> shl(x,y) (UPat(Ops.IDIV, src=(UPat.var("div"), UPat.cvar("const"))), lambda div, const: div >> powers_of_two[const.arg] if const.arg in powers_of_two else None)] # (x // (2**y)) -> shr(x,y) - if UnaryOps.NEG in ops: - pat += [(UPat.var('x')*-1, lambda x: x.alu(UnaryOps.NEG))] - if BinaryOps.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(UnaryOps.NEG), lambda x,y: x.alu(BinaryOps.SUB, y))] - if TernaryOps.MULACC in ops: - pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(TernaryOps.MULACC, b, c))] + if Ops.NEG in ops: + pat += [(UPat.var('x')*-1, lambda x: x.alu(Ops.NEG))] + if Ops.SUB in ops: pat += [(UPat.var('x')+UPat.var('y').alu(Ops.NEG), lambda x,y: x.alu(Ops.SUB, y))] + if Ops.MULACC in ops: + pat += [(UPat.var('a')*UPat.var('b')+UPat.var('c'), lambda a,b,c: a.alu(Ops.MULACC, b, c))] return PatternMatcher(pat) # ***** threefry ***** @@ -225,7 +225,7 @@ def reduce_collapse(acc:UOp, ret:UOp, alu:UOp): if len(reduce_unparented) == 0: return None new_acc = acc.replace(src=acc.src[0:1]+tuple(reduce_parented)) ret = new_acc.assign(new_acc.alu(alu.op, ret)) - if alu.op is BinaryOps.ADD: + if alu.op is Ops.ADD: for r in reduce_unparented: ret = ret * (r.src[1]-r.src[0]).cast(ret.dtype.scalar()).broadcast(ret.dtype.count) return ret diff --git a/tinygrad/engine/fuse.py b/tinygrad/engine/fuse.py deleted file mode 100644 index d140b77ebb..0000000000 --- a/tinygrad/engine/fuse.py +++ /dev/null @@ -1,112 +0,0 @@ -from collections import defaultdict, deque -from typing import Set, Tuple, List, Dict, DefaultDict -from tinygrad.device import Buffer -from tinygrad.ops import UOp, Ops -from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, dedup, merge_dicts -from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.engine.lazy import LazyBuffer - -def _recursive_group(tr:LazyBuffer, st:ShapeTracker, r:LazyBuffer, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], - realizes:Dict[LazyBuffer, None], reduce_for_op:Dict[LazyBuffer, UOp], group:Dict[LazyBuffer, None], - cache:Dict[Tuple[LazyBuffer, ShapeTracker], None]) -> None: - """recursively search the LazyBuffer for groupable children, realize the LazyBuffer if a child can't group""" - if (tr, st) in cache: return - cache.setdefault((tr, st)) - if tr in realizes and tr is not r: - # can only fuse contiguous - # max one reduceop per kernel - if not st.contiguous or st.size != r.st.size or tr in reduce_for_op: group.setdefault(r) - return group.setdefault(tr) - for tr_next in children[tr]: - # max one reduceop per kernel - if tr_next.op is Ops.REDUCE_AXIS: return group.setdefault(r) - # can only fuse contiguous - if len(st_childs:=dedup(s.st for s in tr_next.srcs if s.base == tr)) > 1: return group.setdefault(r) - _recursive_group(tr_next, st+st_childs[0], r, children, realizes, reduce_for_op, group, cache) - -def _get_isolated_children(r:LazyBuffer, reduce_for_op:Dict[LazyBuffer, UOp], children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], - realizes:Dict[LazyBuffer, None], group:Dict[LazyBuffer, None]) -> Dict[LazyBuffer, None]: - rc_parents, cache = deque(group), set() - while rc_parents: - if (p:=rc_parents.pop()) in cache: continue - cache.add(p) - # max one reduceop per kernel - if p.op is Ops.REDUCE_AXIS: return {} - rc_parents.extend(x.base for x in p.srcs if x.base.realized is None and x.base is not r) - # search descendants of the reduceop that can cleanly group - descendants: Dict[LazyBuffer, None] = {} - for tr in group: _recursive_group(tr, tr.st, tr, children, realizes, reduce_for_op, descendants, cache={}) - return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) - -def get_realizes(children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], double_reduces:Dict[LazyBuffer, None], - ubuf_realizes:Dict[UOp, UOp], assigns:Set[UOp], buf_uops:Dict[Buffer, UOp]) -> List[List[UOp]]: - """search the graph for all the LazyBuffers that need to realize""" - # get all the realizes from big graph - realizes: Dict[LazyBuffer, None] = {} - for r in allbufs: - if (ubuf:=buf_uops[r.buffer]) in ubuf_realizes: realizes[r] = None - # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) - reduce_for_op: Dict[LazyBuffer, UOp] = {} - reduce_of_const: List[UOp] = [] - for r in allbufs: - if r in realizes or r.op is not Ops.REDUCE_AXIS: continue - group: Dict[LazyBuffer, None] = {} - _recursive_group(r, r.st, r, children, realizes, reduce_for_op, group, cache={}) - # max one reduceop per kernel - can_chase = all(tr not in reduce_for_op for tr in group) - # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs - forced_realize = r in group - if not forced_realize and len(group) > 1: - group = _get_isolated_children(r, reduce_for_op, children, realizes, group) - # can only fuse assign if no other assign_target is used in the kernel - if not forced_realize and any(x.op is Ops.ASSIGN for x in group): - parents = deque((r, *group)) - while parents and not forced_realize: - if (p:=parents.pop().base).is_realized() or p in realizes: - if p.is_realized() and buf_uops[(b:=p.buffer)] in assigns and not any(x.buffer is b for x in group): forced_realize, can_chase = True, False - continue - parents.extend(p.srcs) - if forced_realize or not group: - tr = r - if can_chase: - # can chase this down to contiguous children - st = tr.st - while len(children[tr]) == 1: - tr_next = next(iter(children[tr])) - st_childs = dedup(s.st for s in tr_next.srcs if s.base is tr) - if len(st_childs) > 1: break - if st.size != st_childs[0].size: break - st = st + st_childs[0] - if not st.contiguous or tr_next.op is Ops.REDUCE_AXIS: break - tr = tr_next - # don't cast to higher size before store (tr cannot be realized if forced_realize) - if tr.op is Ops.CAST and tr.dtype.base.itemsize > tr.srcs[0].dtype.base.itemsize: - tr = tr.srcs[0].base - group = {tr: None} - realizes[tr] = None - rbuf = buf_uops[r.buffer] - reduce_for_op.update((tr, rbuf) for tr in group) - if FUSE_ARANGE and r.arg[0] is Ops.ADD and r.srcs[0].base.op is Ops.CONST: reduce_of_const.append(rbuf) - - # fuse double reduces with no other child - if FUSE_CONV_BW: - for reduceop in double_reduces: - top_reduce = reduceop.base.srcs[0].base - if len(children[top_reduce]) == 1: - del realizes[top_reduce] - if (ubuf:=buf_uops[top_reduce.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] - - for rbuf in reduce_of_const: - group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} - if any(tr.forced_realize for tr in group): continue - kernel_children = {c for tr in group for c in children[tr] if c.op not in {Ops.COPY, Ops.BUFFER_VIEW}} - if len(kernel_children) == 0: continue - for tr in group: - del realizes[tr] - if (ubuf:=buf_uops[tr.buffer]) in ubuf_realizes: del ubuf_realizes[ubuf] - - output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list) - for buf in realizes: - output_groups[reduce_for_op.get(buf, ubuf:=buf_uops[buf.buffer])].append(ubuf) - ubuf_realizes[ubuf] = ubuf - return list(output_groups.values()) diff --git a/tinygrad/engine/lazy.py b/tinygrad/engine/lazy.py index 21fa3f543d..9bf5b8672f 100644 --- a/tinygrad/engine/lazy.py +++ b/tinygrad/engine/lazy.py @@ -1,8 +1,8 @@ from __future__ import annotations -from typing import Union, Optional, Any, Tuple, List, get_args +from typing import Optional, Any, Tuple, List, get_args from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP, LAZYCACHE -from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, exec_alu, python_alu +from tinygrad.ops import exec_alu, python_alu from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint, GroupOp, Ops from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.device import Buffer @@ -11,9 +11,9 @@ from weakref import ref, ReferenceType, WeakValueDictionary lazycache: WeakValueDictionary[Any, LazyBuffer] = WeakValueDictionary() def create_lazybuffer(device:str, st:ShapeTracker, dtype:DType, op:Optional[Ops]=None, arg:Any=None, srcs:Tuple[LazyBuffer, ...]=(), base:Optional[LazyBuffer]=None, enable_cache=bool(LAZYCACHE)): - if st.size == 0: op, arg, srcs, base = MetaOps.CONST, 0, (), None + if st.size == 0: op, arg, srcs, base = Ops.CONST, 0, (), None dtype = to_dtype(dtype) - if op is MetaOps.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True + if op is Ops.CONST: arg, enable_cache = dtypes.as_const(arg, dtype) if not isinstance(arg, UOp) else arg, True cache_key = (device, st, dtype, op, arg, tuple(ref(x) for x in srcs)) if base is None else (st, ref(base)) if enable_cache and (rret := lazycache.get(cache_key, None)) is not None: return rret @@ -32,13 +32,13 @@ class LazyBuffer(MathTrait): if base is None: # properties on base self.op, self.arg, self.srcs = op, arg, srcs # this is a UOp, except the src is LazyBuffers and not UOps - assert self.op is not MetaOps.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized" + assert self.op is not Ops.ASSIGN or srcs[0].base.realized is not None, "assign target must be realized" - if self.op is MetaOps.BUFFER_VIEW: + if self.op is Ops.BUFFER_VIEW: # some LazyBuffers can be processed with only a view, no AST required self.buffer: Buffer = srcs[0].base.buffer.view(st.size, self.dtype, srcs[0].st.views[0].offset * srcs[0].dtype.itemsize) else: - self.buffer = srcs[0].base.buffer if self.op is MetaOps.ASSIGN else Buffer(device, self.size, self.dtype) + self.buffer = srcs[0].base.buffer if self.op is Ops.ASSIGN else Buffer(device, self.size, self.dtype) self.buffer.ref(1) self.contiguous_child: Optional[Tuple[ReferenceType[LazyBuffer], ShapeTracker]] = None self.forced_realize = False @@ -74,14 +74,15 @@ class LazyBuffer(MathTrait): def const_like(self, b): return self.const_with_shape(b, self.shape) def const_with_shape(self, val:ConstType, shape:Tuple[sint,...]) -> LazyBuffer: assert isinstance(val, get_args(ConstType)), f"{val=} has {type(val)=}, not a ConstType" - return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape) + return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape) + @property def is_realized(self) -> bool: return self.base.realized is not None def assign(self, x:LazyBuffer) -> LazyBuffer: assert x.size == self.size, f"assign target must have same size {self.size=} != {x.size=}" - assert self.is_realized(), f"assign target must be realized {self}" - return LazyBuffer.metaop(MetaOps.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), + assert self.is_realized, f"assign target must be realized {self}" + return LazyBuffer.metaop(Ops.ASSIGN, self.shape, self.dtype, self.device, arg=() if self.st.contiguous else (self.st,), src=(self.base, x), enable_cache=True) def can_view(self): @@ -90,7 +91,7 @@ class LazyBuffer(MathTrait): def contiguous(self, allow_buffer_view=True): if not self.st.contiguous or self.size != self.base.size or self.is_unrealized_const(): - ret = self.alu(MetaOps.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(MetaOps.CONTIGUOUS) + ret = self.alu(Ops.BUFFER_VIEW) if allow_buffer_view and self.can_view() else self.alu(Ops.CONTIGUOUS) if (sti := self.st.invert(self.base.shape)) is not None: self.base.contiguous_child = ref(ret), sti return ret self.base.forced_realize = True @@ -101,7 +102,7 @@ class LazyBuffer(MathTrait): if self.dtype == dtype: return self if self.device.startswith("DISK") and not bitcast: raise RuntimeError("attempted to cast disk buffer (bitcast only)") if self.is_unrealized_unmasked_const() and not bitcast: - return create_lazybuffer(self.device, self.st, dtype, MetaOps.CONST, dtypes.as_const(self.base.arg, dtype)) + return create_lazybuffer(self.device, self.st, dtype, Ops.CONST, dtypes.as_const(self.base.arg, dtype)) new_shape = self.shape if bitcast and self.dtype.itemsize != dtype.itemsize: if not self.device.startswith("DISK"): raise RuntimeError("shape changing bitcast only supported on DISK right now") @@ -112,27 +113,27 @@ class LazyBuffer(MathTrait): elif getenv("CAST_BEFORE_VIEW", 1) and dtype.itemsize <= self.dtype.itemsize and self is not self.base: # TODO: applying this makes gpt2 slower return self.base.cast(dtype, bitcast)._view(self.st) - cast_op: Union[Ops, Ops] = (MetaOps.BUFFER_VIEW if self.can_view() and allow_buffer_view else UnaryOps.BITCAST) if bitcast else UnaryOps.CAST + cast_op: Ops = (Ops.BUFFER_VIEW if self.can_view() and allow_buffer_view else Ops.BITCAST) if bitcast else Ops.CAST return create_lazybuffer(self.device, ShapeTracker.from_shape(new_shape), dtype, cast_op, dtype, (self,)) - def is_unrealized_const(self): return self.base.realized is None and self.base.op is MetaOps.CONST and not isinstance(self.base.arg, UOp) + def is_unrealized_const(self): return self.base.realized is None and self.base.op is Ops.CONST and not isinstance(self.base.arg, UOp) def is_unrealized_unmasked_const(self): return self.is_unrealized_const() and all(v.mask is None for v in self.st.views) def _copy(self, device:str) -> LazyBuffer: assert self.st.contiguous and self.size == self.base.size, f"can only copy contig {self} {self.base}" - return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, MetaOps.COPY, self.buffer.nbytes, (self,), enable_cache=False) + return create_lazybuffer(device, ShapeTracker.from_shape(self.shape), self.dtype, Ops.COPY, self.buffer.nbytes, (self,), enable_cache=False) def copy_to_device(self, device:str, force:bool=False, clone:bool=False) -> LazyBuffer: # no COPY if self.device == device and not clone: return self # double COPY = one COPY - if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is MetaOps.COPY: + if not force and self.st.contiguous and self.size == self.base.size and not self.base.realized and self.base.op is Ops.COPY: return self.base.srcs[0].copy_to_device(device).reshape(self.st.shape) # const doesn't have to be copied (issues with disk tensor) if self.is_unrealized_const(): - return LazyBuffer.metaop(MetaOps.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st) + return LazyBuffer.metaop(Ops.CONST, tuple(), self.dtype, device, arg=self.base.arg)._view(self.st) # if it's a shrink, do the shrink before the copy with CONTIGUOUS if prod(self.st.shape) < prod(self.base.st.shape): return self.contiguous()._copy(device) @@ -149,25 +150,25 @@ class LazyBuffer(MathTrait): srcs.append(root._view(s.base.contiguous_child[1])) else: srcs.append(s) - if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is TernaryOps.WHERE else srcs)]): + if not all_same(dts:=[x.dtype.base for x in (srcs[1:] if op is Ops.WHERE else srcs)]): raise AssertionError(f"all dtypes must match {dts} on {op}") assert all_same([x.shape for x in srcs]), f"all shapes must be the same {[x.shape for x in srcs]}" - if op is TernaryOps.WHERE: assert srcs[0].dtype == dtypes.bool, "TernaryOps.WHERE must have the first arg be bool" + if op is Ops.WHERE: assert srcs[0].dtype == dtypes.bool, "Ops.WHERE must have the first arg be bool" - out_dtype = dtypes.bool if op in (BinaryOps.CMPLT, BinaryOps.CMPNE) else srcs[-1].dtype + out_dtype = dtypes.bool if op in (Ops.CMPLT, Ops.CMPNE) else srcs[-1].dtype # const folding if op in python_alu and all(s.is_unrealized_unmasked_const() for s in srcs): return self.cast(out_dtype).const_like(exec_alu(op, out_dtype, [s.base.arg for s in srcs])) if op in GroupOp.Binary: x, y = self, in_srcs[0] - if op is BinaryOps.ADD: + if op is Ops.ADD: if y.is_unrealized_unmasked_const() and y.base.arg == 0: return x if x.is_unrealized_unmasked_const() and x.base.arg == 0: return y - if op is BinaryOps.MUL: + if op is Ops.MUL: if x.is_unrealized_unmasked_const() and (val := x.base.arg) in (1, 0): return y if val == 1 else y.const_like(0) if y.is_unrealized_unmasked_const() and (val := y.base.arg) in (1, 0): return x if val == 1 else x.const_like(0) - if op is BinaryOps.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x + if op is Ops.IDIV and y.is_unrealized_unmasked_const() and y.base.arg == 1: return x return create_lazybuffer(self.device, ShapeTracker.from_shape(self.shape), out_dtype, op, None, tuple(srcs)) diff --git a/tinygrad/engine/schedule.py b/tinygrad/engine/schedule.py index 3b0c749821..23caf7743d 100644 --- a/tinygrad/engine/schedule.py +++ b/tinygrad/engine/schedule.py @@ -3,12 +3,12 @@ from collections import defaultdict, deque from dataclasses import dataclass, field from typing import Set, Tuple, List, Dict, Optional, DefaultDict, cast from tinygrad.ops import GroupOp, UOp, Ops, PatternMatcher, UPat, Variable, can_pad, graph_rewrite, resolve, track_rewrites, sint -from tinygrad.helpers import DEBUG, Context, Metadata, all_int, all_same, colored, diskcache_put, prod, dedup, getenv, unwrap +from tinygrad.helpers import Context, Metadata, all_int, all_same, colored, diskcache_put, merge_dicts, prod, dedup, getenv, unwrap +from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG from tinygrad.dtype import ImageDType, dtypes from tinygrad.shape.shapetracker import ShapeTracker from tinygrad.shape.view import View, strides_for_shape from tinygrad.engine.lazy import LazyBuffer -from tinygrad.engine.fuse import get_realizes from tinygrad.device import Buffer # creation can recurse a lot @@ -37,16 +37,20 @@ class ScheduleItem: # **** small wrapper for LazyBuffer -> UOp +def UPatSrc(*args, **kwargs): return UPat(Ops.VIEW, src=(UPat.var("b"), UPat(*args, **{**kwargs, "name":"to_store"})), name="base") +@functools.lru_cache(None) +def is_scheduled(u:UOp): return u.op is Ops.VIEW and len(u.src) == 2 + @dataclass(frozen=True) class ScheduleContext: - buf_uops: Dict[Buffer, UOp] = field(default_factory=dict) # this maps Buffers to BUFFER uops + buf_uops: Dict[UOp, Buffer] = field(default_factory=dict) # this maps BUFFER uops to Buffers ubuf_metadata: Dict[UOp, Metadata] = field(default_factory=dict) # this maps BUFFER uops to Metadata var_vals: Dict[Variable, int] = field(default_factory=dict) # this maps a BIND's DEFINE_VAR to its value - assigns: Set[UOp] = field(default_factory=set) # this holds all the UOps.BUFFERs we ASSIGN to in this schedule + assigns: Set[UOp] = field(default_factory=set) # this holds all the BUFFER uops we ASSIGN to in this schedule lazybufs: Dict[Buffer, LazyBuffer] = field(default_factory=dict) # this is a lookup for the LazyBuffers we need to mark as realized -def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[LazyBuffer, Dict[LazyBuffer, None]], allbufs:Dict[LazyBuffer, None], - double_reduces:Dict[LazyBuffer, None], cache:Dict[LazyBuffer, UOp]) -> UOp: +def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], + double_reduces:Dict[UOp, None], cache:Dict[LazyBuffer, UOp]) -> UOp: if (r:=cache.get(buf)) is not None: return r if buf is not buf.base: cache[buf] = ret = to_uop(buf.base, ctx, children, allbufs, double_reduces, cache).view(buf.st) @@ -57,34 +61,33 @@ def to_uop(buf:LazyBuffer, ctx:ScheduleContext, children:DefaultDict[LazyBuffer, if DEBUG >= 2: print(f"forcing image {buf.dtype} with shape {buf.shape} to {buf.dtype.base}") # hack the underlying buffer too buf.dtype = buf.buffer.dtype = buf.dtype.base - assert not buf.is_realized(), "can't fixup allocated buffer" + assert not buf.is_realized, "can't fixup allocated buffer" buf.buffer.options = None - dtype = buf.dtype.base if isinstance(buf.dtype, ImageDType) else buf.dtype + dtype = buf.dtype if buf.op in GroupOp.Meta else buf.dtype.base # consts are always fused and generated if buf.op is Ops.CONST: if isinstance(val:=buf.arg, UOp): ctx.var_vals.update([val.unbind()]) - return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(v:=UOp.const(dtype, buf.arg), v.const_like(0)) - # everything else has BUFFER - ubuf = ctx.buf_uops.setdefault(b:=buf.buffer, UOp(Ops.BUFFER, b.dtype.ptr(), (), (len(ctx.buf_uops), (b.device, b.size, b.dtype)))) - # if the buffer is already realized we just load it - if buf.is_realized(): return UOp(Ops.PRELOAD, dtype, (ubuf, buf.st.to_uop())) - # everything else needs sources - src = tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs) - if buf.op in {Ops.REDUCE_AXIS, Ops.CONTIGUOUS}: ret = UOp(buf.op, dtype, src, buf.arg) + return UOp(Ops.VALID, dtypes.bool, (buf.st.to_uop(),)).where(UOp.const(dtype, val), 0) + # everything else is a VIEW of BUFFER (with an optional op) + if buf.is_realized: + ctx.buf_uops[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))] = buf.buffer + op = None elif buf.op is Ops.ASSIGN: - ctx.assigns.add(ubuf) - ret = UOp(Ops.ASSIGN, dtype, (ubuf, src[1]), buf.arg) - elif buf.op in GroupOp.Meta: ret = UOp(buf.op, buf.dtype, (ubuf, *src), buf.arg) - else: ret = UOp(cast(Ops, buf.op), dtype, src) - if buf.forced_realize: ret = UOp(Ops.CONTIGUOUS, dtype, (ret,)) - cache[buf] = ret = UOp(Ops.LOAD, dtype, (ubuf, buf.st.to_uop(), UOp.store(ubuf, ShapeTracker.from_shape(buf.shape).to_uop(), ret))) - if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata - ctx.lazybufs[b] = buf - # things for fuse.py - allbufs[buf] = None - if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[buf] = None - for x in buf.srcs: - if x.base.realized is None: children[x.base][buf] = None + target, new_val = [to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs] + ctx.assigns.add(ubuf:=target.buf_uop) + op = UOp(Ops.ASSIGN, dtype, (ubuf, new_val), buf.arg) + else: + ctx.buf_uops[ubuf:=UOp.new_buffer((b:=buf.buffer).device, b.size, b.dtype, num=len(ctx.buf_uops))] = buf.buffer + op = UOp(cast(Ops, buf.op), dtype, tuple(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in buf.srcs), + None if buf.op in {Ops.CAST, Ops.BITCAST} else buf.arg) + cache[buf] = ret = UOp(Ops.VIEW, dtype.base, (ubuf,) if op is None else (ubuf, op.contiguous() if buf.forced_realize else op), buf.st) + if op is not None: + if buf.metadata is not None: ctx.ubuf_metadata[ubuf] = buf.metadata + ctx.lazybufs[buf.buffer] = buf + for x in op.src: + if is_scheduled(x.base): children[x.base.buf_uop][ubuf] = None + allbufs[ubuf] = ret + if buf.op is Ops.REDUCE_AXIS and buf.srcs[0].base.op is buf.op and buf.srcs[0] is not buf.srcs[0].base: double_reduces[ubuf] = None return ret # **** AST graph rewrite @@ -188,13 +191,14 @@ def _append_preload(ctx:ScheduleItemContext, x:UOp, b:UOp) -> UOp: to_si = PatternMatcher([ (UPat(Ops.VIEW, name="x"), _append_st_vars), (UPat(Ops.PRELOAD, src=(UPat.var("b"), UPat()), name="x"), _append_preload), - (UPat(Ops.SINK, src=(UPat.store(UPat(), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,x: x), + (UPat(Ops.SINK, src=(UPat.store(UPat.var("b"), UPat(), UPat(GroupOp.Meta, name="x")),)), lambda ctx,b,x: x.replace(src=(b, *x.src))), ]) # ** fusion lazy = PatternMatcher([ - (UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), UPat.var("v"))), lambda ctx,b,v: v), + (UPatSrc(), lambda ctx,to_store,**kwargs: to_store), + (UPat(Ops.BUFFER, name="b").view(name="view"), lambda ctx,b,view: UOp(Ops.PRELOAD, view.dtype, (b, view.st.to_uop()))), (UPat(Ops.CONTIGUOUS, src=(UPat.var("x"),)), lambda ctx,x: x), ]) @@ -226,37 +230,136 @@ if getenv("RUN_PROCESS_REPLAY"): def save_process_replay(): for x,ret in PROCESS_REPLAY_CAPTURE: diskcache_put("schedule_process_replay", str(x[0].key), (x, {}, ret)) +# **** Schedule grouping + +def uval(u:UOp) -> UOp: + assert is_scheduled(u), f"must be a scheduled op {u}" + return to_store.src[0] if (to_store:=u.src[1]).is_contiguous_base else to_store + +def recursive_group(tr:UOp, st:ShapeTracker, r:UOp, children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], realizes:Dict[UOp, UOp], + reduce_for_op:Dict[UOp, UOp], group:Dict[UOp, None], cache:Dict[Tuple[UOp, ShapeTracker], None]) -> None: + """recursively search the uop for groupable children, realize the UOp if a child can't group""" + if (tr, st) in cache: return + cache.setdefault((tr, st)) + rsize = unwrap(allbufs[r].st).size + if tr in realizes and tr is not r: + # can only fuse contiguous + # max one reduceop per kernel + if not st.contiguous or st.size != rsize or tr in reduce_for_op: group.setdefault(r) + return group.setdefault(tr) + for tr_next in children[tr]: + # max one reduceop per kernel + if (tr_next_uop:=uval(allbufs[tr_next]).base).op is Ops.REDUCE_AXIS: return group.setdefault(r) + # can only fuse contiguous + if len(st_childs:=dedup(unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop == tr)) > 1: return group.setdefault(r) + recursive_group(tr_next, st+st_childs[0], r, children, allbufs, realizes, reduce_for_op, group, cache) + +def get_isolated_children(r:UOp, reduce_for_op:Dict[UOp, UOp], children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], + realizes:Dict[UOp, UOp], group:Dict[UOp, None]) -> Dict[UOp, None]: + rc_parents, cache = deque(group), set() + while rc_parents: + if (p:=uval(allbufs[rc_parents.pop()])) in cache: continue + cache.add(p) + # max one reduceop per kernel + if p.op is Ops.REDUCE_AXIS: return {} + rc_parents.extend(x.base.buf_uop for x in p.src if is_scheduled(x.base) and x.base.buf_uop is not r) + # search descendants of the reduceop that can cleanly group + descendants: Dict[UOp, None] = {} + for tr in group: recursive_group(tr, unwrap(allbufs[tr].st), tr, children, allbufs, realizes, reduce_for_op, descendants, cache={}) + return merge_dicts([group, {} if any(tr in group for tr in descendants) else descendants]) + +def group_realizes(children:DefaultDict[UOp, Dict[UOp, None]], allbufs:Dict[UOp, UOp], double_reduces:Dict[UOp, None], + realizes:Dict[UOp, UOp], assigns:Set[UOp]) -> List[List[UOp]]: + """search the big graph for all the reduceops that need to realize, sometimes group/fuse the reduceop""" + # find all reduces, and pair them to a elementwise op. if they can't be cleanly paired, force realize the reduce (or a contig child) + reduce_for_op: Dict[UOp, UOp] = {} + reduce_of_const: List[UOp] = [] + for r, r_uop in allbufs.items(): + if r in realizes or (r_uop:=uval(r_uop)).op is not Ops.REDUCE_AXIS: continue + group: Dict[UOp, None] = {} + recursive_group(r, unwrap(r_uop.st), r, children, allbufs, realizes, reduce_for_op, group, cache={}) + # max one reduceop per kernel + can_chase = all(tr not in reduce_for_op for tr in group) + # TODO: forced_realize exists because the scheduler is incapable of checking for self-contained DAGs + forced_realize = r in group + if not forced_realize and len(group) > 1: + group = get_isolated_children(r, reduce_for_op, children, allbufs, realizes, group) + # can only fuse assign if no other assign_target is used in the kernel + if not forced_realize and any(x in assigns for x in group): + parents = deque((r, *group)) + while parents and not forced_realize: + if (p_uop:=allbufs.get(p:=parents.pop())) is None: continue + if (p_uop:=uval(p_uop)).op is Ops.ASSIGN and p not in group: forced_realize, can_chase = True, False + if p in realizes: continue + parents.extend([x.base.src[0] for x in p_uop.src if x.base.op is Ops.VIEW and len(x.base.src) != 0]) + if forced_realize or not group: + tr = r + if can_chase: + # can chase this down to contiguous children + st = unwrap(r_uop.st) + while len(children[tr]) == 1: + tr_next_uop = uval(allbufs[(tr_next:=next(iter(children[tr])))]) + st_childs = dedup([unwrap(x.st) for x in tr_next_uop.src if is_scheduled(x.base) and x.base.buf_uop is tr]) + if len(st_childs) > 1: break + if st.size != st_childs[0].size: break + st = st + st_childs[0] + if not st.contiguous or tr_next_uop.op is Ops.REDUCE_AXIS: break + tr = tr_next + # don't cast to higher size before store (tr cannot be realized if forced_realize) + if (tr_uop:=uval(allbufs[tr])).op is Ops.CAST and tr_uop.dtype.base.itemsize > tr_uop.src[0].dtype.base.itemsize: + tr = tr_uop.src[0].base.buf_uop + group = {tr: None} + realizes[tr] = tr + reduce_for_op.update((tr, r) for tr in group) + if FUSE_ARANGE and r_uop.arg[0] is Ops.ADD and r_uop.src[0].base.op is Ops.WHERE: reduce_of_const.append(r) + + # fuse double reduces with no other child + if FUSE_CONV_BW: + for reduceop in double_reduces: + top_reduce = uval(allbufs[reduceop]).src[0].base.buf_uop + if len(children[top_reduce]) == 1: del realizes[top_reduce] + + for rbuf in reduce_of_const: + group = {tr:None for tr,rop in reduce_for_op.items() if rop is rbuf} + if any(allbufs[tr].src[1].is_contiguous_base for tr in group): continue + kernel_children = {c for tr in group for c in children[tr] if uval(allbufs[c]).op not in {Ops.COPY, Ops.BUFFER_VIEW}} + if len(kernel_children) == 0: continue + for tr in group: del realizes[tr] + + output_groups: DefaultDict[UOp, List[UOp]] = defaultdict(list) + for ubuf in realizes: output_groups[reduce_for_op.get(ubuf, ubuf)].append(ubuf) + return list(output_groups.values()) + # **** Schedule creation and BFS toposort -def realize(ctx:Dict[UOp, UOp], b:UOp, load:UOp, store:UOp) -> UOp: - ctx[b] = store - return UOp(Ops.LOAD, load.dtype, (b, load.st_arg.to_uop())) +def realize(ctx:Dict[UOp, UOp], b:UOp, to_store:UOp, base:UOp) -> UOp: + ctx[b] = UOp.store(b, ShapeTracker.from_shape((st:=unwrap(base.st)).shape).to_uop(), to_store) + return UOp(Ops.LOAD, base.dtype, (b, st.to_uop())) -def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, **kwargs) -> Optional[UOp]: +def realize_view(ctx:Dict[UOp, UOp], base:UOp, view:UOp, to_store:UOp, b:UOp) -> Optional[UOp]: base_shape = unwrap(base.st).shape st = unwrap(view.st) # fold simple pads if len(st.views) == 1 and (m:=st.views[-1].mask) is not None and all_int(base_shape) and resolve(prod(base_shape) >= prod([y-x for x,y in m])): - return None if can_pad(base) else realize(ctx, **kwargs).view(st) + return None if can_pad(base) else realize(ctx, b, to_store, base).view(st) # early realize before expand - if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, **kwargs).view(st) + if resolve(prod(base_shape) < prod(st.shape)): return realize(ctx, b, to_store, base).view(st) # otherwise safety check pads - return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, **kwargs).view(st) + return None if (all(v.mask is None for v in st.views) or can_pad(base)) else realize(ctx, b, to_store, base).view(st) -def UPatLoadStore(to_store=UPat()): return UPat.load(b:=UPat.var("b"), UPat(), UPat.store(b, UPat(), to_store, name="store"), name="load") do_realize = PatternMatcher([ # always realize meta ops - (UPatLoadStore(UPat((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta))), realize), + (UPatSrc((Ops.ASSIGN, Ops.CONTIGUOUS, *GroupOp.Meta)), realize), # don't realize image to image casts - (UPatLoadStore(UPat(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float)).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st) + (UPatSrc(Ops.CAST, src=(UPat(Ops.LOAD, name="x"),), dtype=dtypes.float).view(name="v"), lambda ctx,x,v,**kwargs: r.src[2].view(v.st) if (r:=ctx.get(b:=x.buf_uop)) is not None and r.op is Ops.STORE and isinstance(b.dtype, ImageDType) and r.src[2].op not in GroupOp.Meta else None), # realize before expand or unsafe pad ops - (UPatLoadStore(UPat.var("base")).view(name="view"), realize_view), + (UPatSrc().view(name="view"), realize_view), # realize before COPY or BUFFER_VIEW - (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.var("u"), UPat.any(UPatLoadStore(), UPatLoadStore().view(name="view"))), name="root"), - lambda ctx,root,u,view=None,**kwargs: root.replace(src=(u, realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st))),), + (UPat((Ops.COPY, Ops.BUFFER_VIEW), src=(UPat.any(UPatSrc(), UPatSrc().view(name="view")),), name="root"), + lambda ctx,root,view=None,**kwargs: root.replace(src=(realize(ctx,**kwargs) if view is None else realize(ctx,**kwargs).view(view.st),)),), ]) -break_sched = PatternMatcher([(UPatLoadStore(), lambda ctx,b,store,load: realize(ctx, b, load, store) if b in ctx else None),]) +break_sched = PatternMatcher([(UPatSrc(), lambda ctx,b,to_store,base: realize(ctx, b, to_store, base) if b in ctx else None),]) @track_rewrites(named=True) def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem], Dict[Variable, int]]: @@ -266,31 +369,30 @@ def create_schedule_with_vars(outs:List[LazyBuffer]) -> Tuple[List[ScheduleItem] ctx = ScheduleContext() cache: Dict[LazyBuffer, UOp] = {} # **** TODO: delete these next 3 after big graph - children: DefaultDict[LazyBuffer, Dict[LazyBuffer, None]] = defaultdict(dict) - allbufs: Dict[LazyBuffer, None] = {} - double_reduces: Dict[LazyBuffer, None] = {} + children: DefaultDict[UOp, Dict[UOp, None]] = defaultdict(dict) + allbufs: Dict[UOp, UOp] = {} + double_reduces: Dict[UOp, None] = {} big_graph = UOp.sink(*(to_uop(x, ctx, children, allbufs, double_reduces, cache) for x in outs)) # get realizes realizes: Dict[UOp, UOp] = {} graph_rewrite(big_graph, do_realize, realizes) - store_groups = get_realizes(children, allbufs, double_reduces, realizes, ctx.assigns, ctx.buf_uops) + store_groups = group_realizes(children, allbufs, double_reduces, realizes, ctx.assigns) # split realizes into small graphs graph_rewrite(big_graph, break_sched, realizes) sinks = [UOp.sink(*(realizes[u] for u in stores)) for stores in store_groups] # preschedule all realizes - bufs = list(ctx.buf_uops) prescheduled: List[ScheduleItem] = [] for sink in sinks: - metadata = tuple({mx for x in sink.sparents if x.op in GroupOp.Buffer and len(x.src) > 2 and (mx:=ctx.ubuf_metadata.get(x.buf_uop))}) + metadata = tuple({mx for x in sink.sparents if (x.op is Ops.STORE or is_scheduled(x)) and (mx:=ctx.ubuf_metadata.get(x.buf_uop))}) ast, ast_ctx = full_ast_rewrite(sink, ctx.var_vals, ctx.assigns) - prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=bufs[u.arg[0]]).size != 0), metadata, tuple(ast_ctx.assign_preloads))) + prescheduled.append(ScheduleItem(ast, tuple(b for u in ast_ctx.bufs if (b:=ctx.buf_uops[u]).size != 0), metadata, tuple(ast_ctx.assign_preloads))) # do BFS schedule_targets = {out:si for si in prescheduled for out in si.outputs} graph: DefaultDict[ScheduleItem, List[ScheduleItem]] = defaultdict(list) in_degree: DefaultDict[ScheduleItem, int] = defaultdict(int) for si in prescheduled: # realize outputs before a parent is assigned to - parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(bufs[x.arg[0]])) and xsi is not si) + parents_assigns = dedup(xsi for x in si.assign_preloads if (xsi:=schedule_targets.get(ctx.buf_uops[x])) and xsi is not si) for assign in parents_assigns: graph[si].append(assign) in_degree[assign] += 1 diff --git a/tinygrad/multi.py b/tinygrad/multi.py index 975932a3ca..18870aa0d5 100644 --- a/tinygrad/multi.py +++ b/tinygrad/multi.py @@ -84,6 +84,7 @@ class MultiLazyBuffer(MathTrait): return functools.reduce(operator.add, llbs) # passthroughs + @property def is_realized(self) -> bool: return all(lb.base.realized is not None for lb in self.real_lbs) def cast(self, dtype:DType, bitcast:bool=False, allow_buffer_view=True): return MultiLazyBuffer([x.cast(dtype, bitcast, allow_buffer_view) for x in self.lbs], self.axis, self.real) diff --git a/tinygrad/ops.py b/tinygrad/ops.py index ee17fc1bd6..73c4d06adc 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -29,14 +29,14 @@ class SimpleMathTrait: dtype: Optional[DType] = getattr(self, 'dtype', None) assert dtype is not None, "MathTraits __neg__ requires a dtype" return self.logical_not() if dtype.scalar() == dtypes.bool else self*(-1) - def add(self, x, reverse=False): return self._binop(BinaryOps.ADD, x, reverse) - def mul(self, x, reverse=False): return self._binop(BinaryOps.MUL, x, reverse) - def bitwise_and(self, x, reverse=False): return self._binop(BinaryOps.AND, x, reverse) - def bitwise_or(self, x, reverse=False): return self._binop(BinaryOps.OR, x, reverse) - def xor(self, x, reverse=False): return self._binop(BinaryOps.XOR, x, reverse) - def idiv(self, x, reverse=False): return self._binop(BinaryOps.IDIV, x, reverse) - def sub(self, x, reverse=False): return self.ufix(x).alu(BinaryOps.ADD, -self) if reverse else self.alu(BinaryOps.ADD, self.ufix(-x)) - def div(self, x, reverse=False): return (self.ufix(x)*self.alu(UnaryOps.RECIP)) if reverse else (self*self.ufix(x).alu(UnaryOps.RECIP)) + def add(self, x, reverse=False): return self._binop(Ops.ADD, x, reverse) + def mul(self, x, reverse=False): return self._binop(Ops.MUL, x, reverse) + def bitwise_and(self, x, reverse=False): return self._binop(Ops.AND, x, reverse) + def bitwise_or(self, x, reverse=False): return self._binop(Ops.OR, x, reverse) + def xor(self, x, reverse=False): return self._binop(Ops.XOR, x, reverse) + def idiv(self, x, reverse=False): return self._binop(Ops.IDIV, x, reverse) + def sub(self, x, reverse=False): return self.ufix(x).alu(Ops.ADD, -self) if reverse else self.alu(Ops.ADD, self.ufix(-x)) + def div(self, x, reverse=False): return (self.ufix(x)*self.alu(Ops.RECIP)) if reverse else (self*self.ufix(x).alu(Ops.RECIP)) def __neg__(self): return self.neg() @@ -58,9 +58,9 @@ class SimpleMathTrait: def __ror__(self, x): return self.bitwise_or(x, True) def __rxor__(self, x): return self.xor(x, True) - def lt(self, x): return self.alu(BinaryOps.CMPLT, self.ufix(x)) - def gt(self, x): return self.ufix(x).alu(BinaryOps.CMPLT, self) - def ne(self, x): return self.alu(BinaryOps.CMPNE, self.ufix(x)) + def lt(self, x): return self.alu(Ops.CMPLT, self.ufix(x)) + def gt(self, x): return self.ufix(x).alu(Ops.CMPLT, self) + def ne(self, x): return self.alu(Ops.CMPNE, self.ufix(x)) def ge(self, x): return self.lt(x).logical_not() def le(self, x): return self.gt(x).logical_not() def eq(self, x): return self.ne(x).logical_not() @@ -74,26 +74,26 @@ class SimpleMathTrait: class MathTrait(SimpleMathTrait): # pylint: disable=abstract-method # TODO: move to Tensor when new backward is done - def lshift(self, x, reverse=False): return self._binop(BinaryOps.SHL, x, reverse) - def rshift(self, x, reverse=False): return self._binop(BinaryOps.SHR, x, reverse) + def lshift(self, x, reverse=False): return self._binop(Ops.SHL, x, reverse) + def rshift(self, x, reverse=False): return self._binop(Ops.SHR, x, reverse) def __lshift__(self, x): return self.lshift(x) def __rshift__(self, x): return self.rshift(x) def __rlshift__(self, x): return self.lshift(x, True) def __rrshift__(self, x): return self.rshift(x, True) # not in Tensor - def __mod__(self, x): return self.alu(BinaryOps.MOD, self.ufix(x)) - def __rmod__(self, x): return self.ufix(x).alu(BinaryOps.MOD, self) + def __mod__(self, x): return self.alu(Ops.MOD, self.ufix(x)) + def __rmod__(self, x): return self.ufix(x).alu(Ops.MOD, self) - def maximum(self, x): return self.alu(BinaryOps.MAX, self.ufix(x)) + def maximum(self, x): return self.alu(Ops.MAX, self.ufix(x)) def minimum(self, x): return -(-self).maximum(-x) - def where(self, x, y): return self.alu(TernaryOps.WHERE, x, y) - def threefry(self, seed): return self.alu(BinaryOps.THREEFRY, seed) - def reciprocal(self): return self.alu(UnaryOps.RECIP) - def sqrt(self): return self.alu(UnaryOps.SQRT) - def sin(self): return self.alu(UnaryOps.SIN) - def log2(self): return self.alu(UnaryOps.LOG2) - def exp2(self): return self.alu(UnaryOps.EXP2) + def where(self, x, y): return self.alu(Ops.WHERE, x, x.ufix(y)) + def threefry(self, seed): return self.alu(Ops.THREEFRY, seed) + def reciprocal(self): return self.alu(Ops.RECIP) + def sqrt(self): return self.alu(Ops.SQRT) + def sin(self): return self.alu(Ops.SIN) + def log2(self): return self.alu(Ops.LOG2) + def exp2(self): return self.alu(Ops.EXP2) # the order of these Ops controls the order of the toposort class Ops(FastEnum): @@ -182,11 +182,8 @@ class GroupOp: # do not preserve f(0) = 0 UnsafePad = {Ops.RECIP, Ops.LOG2, Ops.EXP2, Ops.IDIV} -# TODO: remove this? -UnaryOps = BinaryOps = MetaOps = TernaryOps = Ops - # https://en.wikipedia.org/wiki/Identity_element -def identity_element(op:Ops, dt:DType): return dtypes.as_const({BinaryOps.ADD:0, BinaryOps.MUL:1, BinaryOps.MAX:dtypes.min(dt)}[op], dt) +def identity_element(op:Ops, dt:DType): return dtypes.as_const({Ops.ADD:0, Ops.MUL:1, Ops.MAX:dtypes.min(dt)}[op], dt) def can_pad(u:UOp) -> bool: return not any(x.op in GroupOp.UnsafePad for x in u.sparents) @@ -325,7 +322,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def store(self, *src:UOp, **kwargs): return UOp(Ops.STORE, dtypes.void, (self,)+src, **kwargs) def alu(self, arg, *src:UOp): out_dtype = (self, *src)[-1].dtype - if arg in {BinaryOps.CMPLT, BinaryOps.CMPNE} and out_dtype is not None: + if arg in {Ops.CMPLT, Ops.CMPNE} and out_dtype is not None: out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool return UOp(arg, out_dtype, (self,)+src) @staticmethod @@ -339,19 +336,27 @@ class UOp(MathTrait, metaclass=UOpMetaClass): UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=(idx, False)) def r(self, op, axis): return UOp(Ops.REDUCE_AXIS, self.dtype, (self,), (op, axis)) def assign(self, x:UOp): return UOp(Ops.ASSIGN, self.dtype, (self,x)) + def contiguous(self): return UOp(Ops.CONTIGUOUS, self.dtype, (self,)) + @property + def is_contiguous_base(self): return self.op is Ops.CONTIGUOUS and not (self.src[0].base.op is Ops.VIEW and len(self.src[0].base.src) == 2) # *** uop movement ops *** @property - def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) != 0 else self - def view(self, st:ShapeTracker): return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st) + def base(self): return self.src[0] if self.op is Ops.VIEW and len(self.src) == 1 else self + def view(self, st:ShapeTracker): + assert self.op is not Ops.STORE, "VIEW of STORE is invalid, STORE is always base" + return self if self.st is None or self.st == st else UOp(Ops.VIEW, self.dtype, (self,), st) def reshape(self, arg:Tuple[sint, ...]): return self.view(unwrap(self.st).reshape(arg)) # *** uop Buffer stuff *** + @staticmethod + def new_buffer(device:str, size:int, dtype:DType, num=-1): return UOp(Ops.BUFFER, dtype.ptr(), (), (num, (device, size, dtype))) + @property def buf_uop(self) -> UOp: - assert self.op in {*GroupOp.Buffer, Ops.ASSIGN} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}" + assert self.op in {*GroupOp.Buffer, Ops.ASSIGN, Ops.VIEW} and self.src[0].op is Ops.BUFFER, f"buf_uop called on {self.op}" return self.src[0] # *** uop Variable stuff *** @@ -388,15 +393,15 @@ class UOp(MathTrait, metaclass=UOpMetaClass): """largest known int that divides self""" if self.op is Ops.CONST: return self.arg if self.op is Ops.VCONST: return math.gcd(*self.arg) - if self.op is BinaryOps.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) - if self.op is BinaryOps.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 + if self.op is Ops.ADD: return math.gcd(self.src[0].const_factor(), self.src[1].const_factor()) + if self.op is Ops.MUL: return self.src[0].arg if self.src[0].op is Ops.CONST else self.src[1].arg if self.src[1].op is Ops.CONST else 1 return 1 def divides(self, v) -> Optional[UOp]: if v==1: return self if self.op is Ops.CONST: return self.const_like(self.arg//v) if self.arg%v == 0 else None if self.op is Ops.VCONST: return self.const_like(tuple(x//v for x in self.arg)) if all(x%v == 0 for x in self.arg) else None - if self.op is BinaryOps.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None - if self.op is BinaryOps.MUL: + if self.op is Ops.ADD: return d0+d1 if (d0:=self.src[0].divides(v)) is not None and (d1:=self.src[1].divides(v)) is not None else None + if self.op is Ops.MUL: if (d0:=self.src[0].divides(v)) is not None: return d0 * self.src[1] if (d1:=self.src[1].divides(v)) is not None: return self.src[0] * d1 return None # generic None if we aren't sure @@ -408,18 +413,18 @@ class UOp(MathTrait, metaclass=UOpMetaClass): def _min_max(self) -> Tuple[ConstType, ConstType]: if self.op in GroupOp.Binary and not dtypes.is_float(self.dtype): (s0_vmin, s0_vmax), (s1_vmin, s1_vmax) = self.src[0]._min_max, self.src[1]._min_max - if self.op is BinaryOps.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax - if self.op is BinaryOps.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) - if self.op is BinaryOps.MOD and s1_vmin > 0: return 0, s1_vmax-1 - if self.op is BinaryOps.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST + if self.op is Ops.ADD: return s0_vmin+s1_vmin, s0_vmax+s1_vmax + if self.op is Ops.MUL: return min(vals:=(s0_vmin*s1_vmin, s0_vmin*s1_vmax, s0_vmax*s1_vmin, s0_vmax*s1_vmax)), max(vals) + if self.op is Ops.MOD and s1_vmin > 0: return 0, s1_vmax-1 + if self.op is Ops.IDIV and s1_vmin == s1_vmax: # min/max are equal in a CONST if s1_vmin > 0: return s0_vmin//s1_vmin, s0_vmax//s1_vmin if s1_vmin < 0 and s0_vmin >= 0: return -(s0_vmax//-s1_vmin), -(s0_vmin//-s1_vmin) - if self.op is BinaryOps.MAX: return max(s0_vmin, s1_vmin), max(s0_vmax, s1_vmax) - if self.op is BinaryOps.CMPLT: return (s0_vmax 0 else -math.inf if x == 0 else math.nan, UnaryOps.EXP2: hook_overflow(math.inf, lambda x: 2**x), - UnaryOps.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, UnaryOps.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), - UnaryOps.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, - UnaryOps.NEG: operator.neg, BinaryOps.ADD: operator.add, BinaryOps.SUB: operator.sub, BinaryOps.MUL: operator.mul, - BinaryOps.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], BinaryOps.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf, - BinaryOps.MAX: max, BinaryOps.CMPNE: operator.ne, BinaryOps.CMPLT: operator.lt, BinaryOps.XOR: operator.xor, - BinaryOps.OR: operator.or_, BinaryOps.AND: operator.and_, BinaryOps.SHR: operator.rshift, BinaryOps.SHL: operator.lshift, - TernaryOps.MULACC: lambda x,y,z: (x*y)+z, TernaryOps.WHERE: lambda x,y,z: y if x else z} + Ops.LOG2: lambda x: math.log2(x) if x > 0 else -math.inf if x == 0 else math.nan, Ops.EXP2: hook_overflow(math.inf, lambda x: 2**x), + Ops.SQRT: lambda x: math.sqrt(x) if x >= 0 else math.nan, Ops.RECIP: lambda x: 1/x if x != 0 else math.copysign(math.inf, x), + Ops.SIN: lambda x: math.sin(x) if not math.isinf(x) else math.nan, + Ops.NEG: operator.neg, Ops.ADD: operator.add, Ops.SUB: operator.sub, Ops.MUL: operator.mul, + Ops.MOD: lambda x,y: abs(int(x))%abs(int(y))*(1,-1)[x<0], Ops.IDIV: lambda x,y: abs(x)//abs(y)*(1,-1)[x*y<0] if y != 0 else x*math.inf, + Ops.MAX: max, Ops.CMPNE: operator.ne, Ops.CMPLT: operator.lt, Ops.XOR: operator.xor, + Ops.OR: operator.or_, Ops.AND: operator.and_, Ops.SHR: operator.rshift, Ops.SHL: operator.lshift, + Ops.MULACC: lambda x,y,z: (x*y)+z, Ops.WHERE: lambda x,y,z: y if x else z} def exec_alu(op:Ops, dtype:DType, operands, truncate_output=True): if dtype.count > 1: @@ -511,7 +516,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: elif u.op is Ops.STORE: mem += u.src[1].dtype.itemsize * mults elif u.op in GroupOp.ALU and u not in dont_count: - flops += (mults * (2 if u.op is TernaryOps.MULACC else 1)) * u.dtype.count + flops += (mults * (2 if u.op is Ops.MULACC else 1)) * u.dtype.count elif u.op is Ops.WMMA and u not in dont_count: flops += 2 * prod(u.arg[1]) // u.arg[5] * mults return flops, mem @@ -760,8 +765,8 @@ spec = PatternMatcher([ (UPat(Ops.SPECIAL, src=()), lambda: True), # TODO: confirm the args of both of these are shapetrackers - (UPat(Ops.VIEW, src=()), lambda: True), - (UPat(Ops.VIEW, src=(UPat(),)), lambda: True), + (UPat(Ops.VIEW, dtypes.void, src=()), lambda: True), + (UPat(Ops.VIEW, src=(UPat.var("src"),), name="x"), lambda x,src: src.op is not Ops.STORE and x.dtype == src.dtype), (UPat(Ops.VALID, dtypes.bool, (UPat(Ops.VIEW),)), lambda: True), (UPat(Ops.CONST, name="x"), lambda x: x.dtype == x.dtype.scalar() and (type(x.arg) is type(dtypes.as_const(x.arg, x.dtype)))), @@ -853,7 +858,7 @@ def mod_folding(x:UOp, c:int) -> Optional[UOp]: if 0 < c and 0 <= x.vmin and (quotient:=x.vmin//c) == x.vmax//c: return x-quotient*c remainder, something_changed = [], False - for u in split_uop(x, BinaryOps.ADD): + for u in split_uop(x, Ops.ADD): if (factor:=u.const_factor())%c != factor: divides = u.divides(factor)*(factor%c) assert divides is not None @@ -873,7 +878,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: if 0 <= x.vmin and x.vmax < c: return x.const_like(0) quotient, remainder, rem_const, something_changed, gcd, divisor = [], [], 0, False, c, 1 - for u in split_uop(x, BinaryOps.ADD): + for u in split_uop(x, Ops.ADD): if u.op is Ops.CONST: # add all const together first if rem_const != 0: something_changed = True @@ -907,7 +912,7 @@ def div_folding(x:UOp, c:int) -> Optional[UOp]: return quo if rem is None else cast(UOp, div_folding(rem, div))//(c//div)+quo def lt_folding(x:UOp, c:int) -> Optional[UOp]: - p, np = partition(split_uop(x, BinaryOps.ADD), lambda u: u.const_factor() == 1) + p, np = partition(split_uop(x, Ops.ADD), lambda u: u.const_factor() == 1) if np and (d:=math.gcd(*[u.const_factor() for u in np], c)) > 1 and 0 <= sum(u.vmin for u in p) and sum(u.vmax for u in p) < d: return cast(UOp, functools.reduce(operator.add, np).divides(d)).lt(c//d) return None @@ -915,7 +920,7 @@ def lt_folding(x:UOp, c:int) -> Optional[UOp]: def fold_unrolled_divs(divs:UOp): # div pattern in unrolled arange # example: (x//4+(x+1)//4+(x+2)//4+(x+3)//4 -> x - add_chain, denominator, seen_const, ans = list(split_uop(divs, BinaryOps.ADD)), None, [], None + add_chain, denominator, seen_const, ans = list(split_uop(divs, Ops.ADD)), None, [], None for u in add_chain: if not (u.op is Ops.IDIV and u.src[1].op is Ops.CONST): return None if denominator is None: denominator = u.src[1].arg @@ -937,9 +942,9 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: # (X := a0*x0 + a1*x1 + ...) > 0 is equivalent to x0 + x1 + ... > 0 if xi >= 0 and ai > 0 for ints. # returns x0 + x1 + ... in such case, or None if not changed, ret = False, [] - for u in split_uop(X, BinaryOps.ADD): + for u in split_uop(X, Ops.ADD): # assumed the const is the last src of MUL - if u.op is BinaryOps.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: + if u.op is Ops.MUL and u.src[1].op is Ops.CONST and u.src[1].arg > 0: changed = True u = u.src[0] if not (u.op in GroupOp.Irreducible and u.vmin >= 0): return None @@ -949,8 +954,8 @@ def canonicalize_simplex(X:UOp) -> Optional[UOp]: def is_increasing(f:UOp) -> bool: # is f a monotonically increasing function regards its input if f.op in GroupOp.Irreducible: return True - if f.op is BinaryOps.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) - if f.op in (BinaryOps.MUL, BinaryOps.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) + if f.op is Ops.ADD: return is_increasing(f.src[0]) and is_increasing(f.src[1]) + if f.op in (Ops.MUL, Ops.IDIV) and f.src[1].op is Ops.CONST and f.src[1].arg >= 0: return is_increasing(f.src[0]) return False # False if not sure def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: @@ -958,10 +963,10 @@ def parse_valid(valid:UOp) -> Tuple[UOp, bool, int]: # if it's X >= c, returns X, False, c # (X < c).ne(True) -> X >= c - if valid.op is BinaryOps.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ - (s0:=valid.src[0]).op is BinaryOps.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg + if valid.op is Ops.CMPNE and valid.src[1].op is Ops.CONST and valid.src[1].arg == 1 and \ + (s0:=valid.src[0]).op is Ops.CMPLT and s0.src[1].op is Ops.CONST: return s0.src[0], False, s0.src[1].arg # X < c -> X <= c-1 - if valid.op is BinaryOps.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 + if valid.op is Ops.CMPLT and valid.src[1].op is Ops.CONST: return valid.src[0], True, valid.src[1].arg-1 raise ValueError(f"not able to parse {valid=}") def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: @@ -969,7 +974,7 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # first, parse valid into {expr: (lower_bound, upper_bound)} bounds:DefaultDict[UOp, List[Optional[ConstType]]] = defaultdict(lambda: [None, None]) - for stmt in split_uop(valid, BinaryOps.AND): + for stmt in split_uop(valid, Ops.AND): try: expr, is_upper, c = parse_valid(stmt) except ValueError: return uop # give up if we cannot parse the valid bounds[expr][int(is_upper)] = c @@ -981,9 +986,9 @@ def uop_given_valid(valid:UOp, uop:UOp) -> Optional[UOp]: # every candidate is a set of contrained UOp based on valid, and if every item in a set simplifies the uop into a same output, we rewrite uop candidates = [] - if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, BinaryOps.ADD)): + if expr.op is Ops.ADD and v[0] == 1 and all(u.op in GroupOp.Irreducible for u in split_uop(expr, Ops.ADD)): # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output - candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, BinaryOps.ADD)]) + candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in split_uop(expr, Ops.ADD)]) # try checking the whole clause if expr in uop.sparents: candidates.append([(expr, UOp.variable("fake", expr.vmin if v[0] is None else v[0], expr.vmax if v[1] is None else v[1], expr.dtype))]) @@ -1033,6 +1038,7 @@ symbolic_simple = PatternMatcher([ ((UPat.var("x") & UPat.var("x")), lambda x: x), ((UPat.var("x") | UPat.var("x")), lambda x: x), (UPat.var("x", dtype=dtypes.bool).logical_not().logical_not(), lambda x: x), + (UPat.var("x", dtype=dtypes.bool).where(UPat.const(dtypes.bool, True), UPat.const(dtypes.bool, False)), lambda x: x), # ** zero folding ** (UPat.var("x") < UPat.var("x"), lambda x: UOp.const(dtypes.bool.vec(x.dtype.count), False)), # x < x -> False (UPat.var("x", dtype=dtypes.ints) != UPat.var("x", dtype=dtypes.ints), @@ -1124,8 +1130,8 @@ symbolic_flat = symbolic+PatternMatcher([ _substitute = PatternMatcher([(UPat(tuple(Ops), name="x"), lambda ctx,x: ctx.get(x,None))]) # for debug -syms = { BinaryOps.ADD: "+", BinaryOps.SUB: "-", BinaryOps.IDIV: "//", BinaryOps.MOD: "%", BinaryOps.SHL: "<<", BinaryOps.SHR: ">>", - BinaryOps.MUL: "*", BinaryOps.CMPLT: "<", BinaryOps.CMPNE: "!=", BinaryOps.AND: "&", BinaryOps.OR: "|", BinaryOps.XOR: "^"} +syms = { Ops.ADD: "+", Ops.SUB: "-", Ops.IDIV: "//", Ops.MOD: "%", Ops.SHL: "<<", Ops.SHR: ">>", + Ops.MUL: "*", Ops.CMPLT: "<", Ops.CMPNE: "!=", Ops.AND: "&", Ops.OR: "|", Ops.XOR: "^"} renderer = PatternMatcher([ (UPat((Ops.DEFINE_VAR, Ops.SPECIAL), name="x"), lambda x: UOp(Ops.NOOP, arg=x.arg[0])), (UPat(Ops.RANGE, name="x"), lambda x: UOp(Ops.NOOP, arg=f"ridx{x.arg[0]}")), diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 435a307d9c..790c59a861 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -2,7 +2,7 @@ from __future__ import annotations from typing import Dict, List, Optional, Tuple, Union, DefaultDict, Literal, Callable, cast import os, math from collections import defaultdict, Counter -from tinygrad.ops import GroupOp, UnaryOps, BinaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 +from tinygrad.ops import GroupOp, Ops, UOp, PatternMatcher, UPat, cast_float_to_bf16 from tinygrad.helpers import strip_parens, getenv, prod, dedup, AMX from tinygrad.dtype import ImageDType, dtypes, DType, PtrDType from tinygrad.renderer import Renderer, TensorCore @@ -19,36 +19,36 @@ base_rewrite = PatternMatcher([ (UPat(Ops.VECTORIZE, name="x"), lambda ctx,x: f"{ctx.float4.replace('float4', ctx.render_dtype(x.dtype))}" + \ (f"{{{','.join([ctx[y] for y in x.src])}}}" if ctx.device == "CLANG" else f"({','.join([ctx[y] for y in x.src])})")), - (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_dtype(x.dtype)})({ctx[x.src[0]]})"), + (UPat(Ops.CAST, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx[x.src[0]])})"), (UPat(Ops.BITCAST, name="x"), lambda ctx,x: f"(*(({ctx.buffer_prefix}{ctx.render_dtype(x.dtype)}*)&{ctx[x.src[0]]}))"), (UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"{ctx.smem_align}{ctx.smem_prefix}{ctx.render_dtype(x.dtype.base)} {ctx[x]}[{x.arg[1]}];"), (UPat(Ops.BARRIER), lambda ctx: ctx.barrier), (UPat(Ops.NOOP, name="x"), lambda ctx,x: ctx[x.src[0]]), (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"{ctx.code_for_workitem[x.arg[0][0]](x.arg[0][-1])}; /* {x.arg[1]} */"), # const - (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)}){ctx.infinity})"), - (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"(({ctx.render_dtype(x.dtype)})-{ctx.infinity})"), - (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){ctx.nan})" if math.isnan(x.arg) else None), + (UPat(Ops.CONST, arg=math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, ctx.infinity)})"), + (UPat(Ops.CONST, arg=-math.inf, name="x"), lambda ctx, x: f"({ctx.render_cast(x.dtype, f'-{ctx.infinity}')})"), + (UPat(Ops.CONST, dtype=dtypes.floats, name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, ctx.nan)})" if math.isnan(x.arg) else None), (UPat(Ops.CONST, dtype=dtypes.float, name="x"), lambda ctx,x: f"{x.arg}f"), (UPat(Ops.CONST, dtype=dtypes.int64, name="x"), lambda ctx,x: f"{x.arg}ll"), (UPat(Ops.CONST, dtype=dtypes.uint64, name="x"), lambda ctx,x: f"{x.arg}ull"), (UPat(Ops.CONST, dtype=dtypes.uint32, name="x"), lambda ctx,x: f"{x.arg}u"), (UPat(Ops.CONST, dtype=dtypes.bool, name="x"), lambda ctx,x: "1" if x.arg else "0"), # consts are rendered to larger type and casted - (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}f)"), - (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg}u)"), - (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"(({ctx.render_dtype(x.dtype)}){x.arg})"), + (UPat(Ops.CONST, (dtypes.bfloat16, dtypes.half), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}f')})"), + (UPat(Ops.CONST, (dtypes.uint8, dtypes.uint16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, f'{x.arg}u')})"), + (UPat(Ops.CONST, (dtypes.int8, dtypes.int16), name="x"), lambda ctx,x: f"({ctx.render_cast(x.dtype, x.arg)})"), # default const render (UPat(Ops.CONST, name="x"), lambda ctx,x: str(x.arg)), # new load/store (UPat(Ops.INDEX, src=(UPat.var("buf"), UPat.var('idx'))), - lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == BinaryOps.ADD else ctx[idx]})"), + lambda ctx,buf,idx: f"({ctx[buf]}+{strip_parens(ctx[idx]) if idx.arg == Ops.ADD else ctx[idx]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'), UPat.var("var"), UPat.var("gate"))), lambda ctx,bidx,var,gate: f"({ctx[gate]}?*{ctx[bidx]}:{ctx[var]})"), (UPat(Ops.LOAD, src=(UPat.var('bidx'),), allow_any_len=True), lambda ctx,bidx: f"*{ctx[bidx]}"), (UPat(Ops.STORE, src=(UPat.var('bidx'), UPat.var("var")), allow_any_len=True), lambda ctx,bidx,var: f"*{ctx[bidx]} = {ctx[var]};"), # alu/gep (UPat(GroupOp.ALU, name="x"), lambda ctx,x: ctx.code_for_op[x.op]( - *([strip_parens(ctx[v]) if v.op == x.op and x.op in {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.XOR} else ctx[v] for v in x.src]), x.dtype)), + *([strip_parens(ctx[v]) if v.op == x.op and x.op in {Ops.ADD, Ops.MUL, Ops.XOR} else ctx[v] for v in x.src]), x.dtype)), (UPat(Ops.GEP, name="x"), lambda ctx,x: ctx[x.src[0]] + \ (f"[{x.arg[0]}]" if x.src[0].dtype.count > (8 if ctx.device in {"CUDA", "NV"} else 4) or ctx.device == 'CLANG' else f".{'xyzwabcd'[x.arg[0]]}")), ]) @@ -82,13 +82,13 @@ class CStyleLanguage(Renderer): infinity: str = "INFINITY" nan: str = "NAN" code_for_op: Dict = { - UnaryOps.SQRT: lambda x,dtype: f"sqrt({x})", UnaryOps.RECIP: lambda x,dtype: f"(1/{x})", UnaryOps.NEG: lambda x,dtype: f"-{x}", - UnaryOps.EXP2: lambda x,dtype: f"exp2({x})", UnaryOps.LOG2: lambda x,dtype: f"log2({x})", UnaryOps.SIN: lambda x,dtype: f"sin({x})", - BinaryOps.AND: lambda a,b,dtype: f"({a}&{b})", BinaryOps.XOR: lambda a,b,dtype: f"({a}^{b})", BinaryOps.OR: lambda a,b,dtype: f"({a}|{b})", - BinaryOps.ADD: lambda a,b,dtype: f"({a}+{b})", BinaryOps.SUB: lambda a,b,dtype: f"({a}-{b})", BinaryOps.MUL: lambda a,b,dtype: f"({a}*{b})", - BinaryOps.MOD: lambda a,b,dtype: f"({a}%{b})", BinaryOps.IDIV: lambda a,b,dtype: f"({a}/{b})", BinaryOps.CMPNE: lambda a,b,dtype: f"({a}!={b})", - BinaryOps.SHR: lambda a,b,dtype: f"({a}>>{b})", BinaryOps.SHL: lambda a,b,dtype: f"({a}<<{b})", BinaryOps.CMPLT: lambda a,b,dtype: f"({a}<{b})", - TernaryOps.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" } + Ops.SQRT: lambda x,dtype: f"sqrt({x})", Ops.RECIP: lambda x,dtype: f"(1/{x})", Ops.NEG: lambda x,dtype: f"-{x}", + Ops.EXP2: lambda x,dtype: f"exp2({x})", Ops.LOG2: lambda x,dtype: f"log2({x})", Ops.SIN: lambda x,dtype: f"sin({x})", + Ops.AND: lambda a,b,dtype: f"({a}&{b})", Ops.XOR: lambda a,b,dtype: f"({a}^{b})", Ops.OR: lambda a,b,dtype: f"({a}|{b})", + Ops.ADD: lambda a,b,dtype: f"({a}+{b})", Ops.SUB: lambda a,b,dtype: f"({a}-{b})", Ops.MUL: lambda a,b,dtype: f"({a}*{b})", + Ops.MOD: lambda a,b,dtype: f"({a}%{b})", Ops.IDIV: lambda a,b,dtype: f"({a}/{b})", Ops.CMPNE: lambda a,b,dtype: f"({a}!={b})", + Ops.SHR: lambda a,b,dtype: f"({a}>>{b})", Ops.SHL: lambda a,b,dtype: f"({a}<<{b})", Ops.CMPLT: lambda a,b,dtype: f"({a}<{b})", + Ops.WHERE: lambda a,b,c,dtype: f"({a}?{b}:{c})" } string_rewrite = base_rewrite extra_matcher = extra_pm @@ -103,6 +103,7 @@ class CStyleLanguage(Renderer): [") {\n" + tmp] + ['\n'.join(kernel), "\n}"]) return prg if prefix is None else "\n".join(prefix)+f"\n{prg}" + def render_cast(self, dt:DType, val: str) -> str: return f"({self.render_dtype(dt)})({val})" def render_dtype(self, dt:DType, mutable=True) -> str: if isinstance(dt, ImageDType): return f"{'write_only' if mutable else 'read_only'} image2d_t" @@ -173,8 +174,8 @@ class ClangRenderer(CStyleLanguage): # language options buffer_suffix = " restrict" type_map = {dtypes.bool:"_Bool", dtypes.half:"__fp16"} - code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [UnaryOps.EXP2, UnaryOps.SIN, UnaryOps.LOG2]}), - UnaryOps.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"} + code_for_op = {**({k:v for k,v in CStyleLanguage.code_for_op.items() if k not in [Ops.EXP2, Ops.SIN, Ops.LOG2]}), + Ops.SQRT: lambda x,dtype: f"__builtin_sqrt({x})" if dtype == dtypes.float64 else f"__builtin_sqrtf({x})"} if AMX: tensor_cores = [TensorCore(dims=(sz,sz,1), threads=[], reduce_axes=[], upcast_axes=([(1,sz)],[(0,sz)],[(1,sz),(0,sz)]), dtype_in=dt, dtype_out=dt) @@ -264,12 +265,12 @@ class MetalRenderer(CStyleLanguage): type_map = {dtypes.bfloat16: "bfloat"} # precise::sin - code_for_op = {**CStyleLanguage.code_for_op, UnaryOps.SIN: lambda x,dtype: f"precise::sin({x})"} + code_for_op = {**CStyleLanguage.code_for_op, Ops.SIN: lambda x,dtype: f"precise::sin({x})"} # upcast to float32 all the ops that don't support bfloat16 extra_matcher = PatternMatcher([ # NOTE: this is copied from PTX - (UPat((UnaryOps.SQRT, UnaryOps.EXP2, UnaryOps.LOG2, UnaryOps.SIN), dtype=dtypes.bfloat16, name="x"), + (UPat((Ops.SQRT, Ops.EXP2, Ops.LOG2, Ops.SIN), dtype=dtypes.bfloat16, name="x"), lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16))), ]) + extra_pm @@ -279,10 +280,12 @@ class MetalRenderer(CStyleLanguage): def render_kernel(self, function_name, kernel, bufs, uops, prefix=None): prefix, wmma_args = ["#include ","using namespace metal;"], set([uop.arg for uop in uops if uop.op is Ops.WMMA]) - for arg in wmma_args: prefix.append(f"""{arg[3].name}2 __{arg[0]}({arg[2].name}2 m, {arg[2].name}2 n, {arg[3].name}2 o) {{ - simdgroup_{arg[3].name}8x8 a,b,c; a.thread_elements()[0] = m.x; a.thread_elements()[1] = m.y; b.thread_elements()[0] = n.x; - b.thread_elements()[1] = n.y; c.thread_elements()[0] = o.x; c.thread_elements()[1] = o.y; simdgroup_multiply_accumulate(c, a, b, c); - return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""") + for arg in wmma_args: prefix.append( + f"""{(dtype_out:=self.render_dtype(arg[3].vec(2)))} __{arg[0]}({(dtype_in:=self.render_dtype(arg[2].vec(2)))} a, {dtype_in} b, {dtype_out} c){{ + simdgroup_{self.render_dtype(arg[2])}8x8 mat_a, mat_b; simdgroup_{self.render_dtype(arg[3])}8x8 mat_c; + mat_a.thread_elements()[0] = a[0]; mat_b.thread_elements()[0] = b[0]; mat_c.thread_elements()[0] = c[0]; + mat_a.thread_elements()[1] = a[1]; mat_b.thread_elements()[1] = b[1]; mat_c.thread_elements()[1] = c[1]; + simdgroup_multiply_accumulate(mat_c, mat_a, mat_b, mat_c);\n return {dtype_out}(mat_c.thread_elements()[0], mat_c.thread_elements()[1]);\n}}""") return super().render_kernel(function_name, kernel, bufs, uops, prefix) _nms = "xyzwabcdefghijkl" @@ -309,11 +312,11 @@ class CUDARenderer(CStyleLanguage): code_for_workitem = {"g": lambda x: f"blockIdx.{chr(120+int(x))}", "l": lambda x: f"threadIdx.{chr(120+int(x))}", "i": lambda x: f"(blockIdx.{chr(120+int(x))}*blockDim.{chr(120+int(x))}+threadIdx.{chr(120+int(x))})"} code_for_op = { **CStyleLanguage.code_for_op, - UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})", - UnaryOps.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})", - UnaryOps.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", - UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", - UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } + Ops.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})", + Ops.LOG2: lambda x,dtype: f"hlog2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"log2({x})", + Ops.EXP2: lambda x,dtype: f"hexp2({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"exp2({x})", + Ops.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})", + Ops.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})" } type_map = {dtypes.bfloat16: "nv_bfloat16"} def render_vector_prefix(self, dt:DType) -> str: @@ -372,10 +375,10 @@ class AMDRenderer(CStyleLanguage): code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})", "i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"} code_for_op = { **CStyleLanguage.code_for_op, - UnaryOps.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", - UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" } + Ops.SIN: lambda x,dtype: f"__ocml_sin_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.LOG2: lambda x,dtype: f"__ocml_log2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.EXP2: lambda x,dtype: f"__ocml_exp2_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})", + Ops.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32)}({x})" } smem_prefix = "__attribute__((shared))" barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \ '__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");' @@ -430,9 +433,9 @@ class DSPRenderer(ClangRenderer): buffer_suffix = " restrict __attribute__((align_value(128)))" kernel_prefix = "__attribute__((noinline)) " type_map = { **ClangRenderer.type_map, dtypes.uint64: "unsigned long long", dtypes.int64: "long long" } - code_for_op = {**ClangRenderer.code_for_op, UnaryOps.SIN: lambda x,dtype: f"__builtin_sin({x})", - UnaryOps.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", - UnaryOps.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} + code_for_op = {**ClangRenderer.code_for_op, Ops.SIN: lambda x,dtype: f"__builtin_sin({x})", + Ops.LOG2: lambda x,dtype: f"__builtin_log2l({x})" if dtype == dtypes.float64 else f"__builtin_log2f({x})", + Ops.EXP2: lambda x,dtype: f"__builtin_exp2l({x})" if dtype == dtypes.float64 else f"__builtin_exp2f({x})"} def render_kernel(self, function_name:str, kernel:List[str], bufs:List[Tuple[str,Tuple[DType,bool]]], uops:List[UOp], prefix=None) -> str: ret = super().render_kernel(function_name, kernel, bufs, uops, prefix) diff --git a/tinygrad/renderer/ptx.py b/tinygrad/renderer/ptx.py index 4912c024de..46e77e74e0 100644 --- a/tinygrad/renderer/ptx.py +++ b/tinygrad/renderer/ptx.py @@ -1,11 +1,11 @@ from typing import DefaultDict, Dict, List, Union, Optional, cast, Callable, Tuple import struct from collections import defaultdict -from tinygrad.ops import BinaryOps, UnaryOps, TernaryOps, Ops, UOp, PatternMatcher, UPat, GroupOp -from tinygrad.dtype import dtypes, DType, PtrDType, ConstType +from tinygrad.ops import Ops, UOp, PatternMatcher, UPat, GroupOp +from tinygrad.dtype import dtypes, DType, PtrDType from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer -from tinygrad.helpers import prod +from tinygrad.helpers import prod, flatten def render_val(x, dtype): if dtypes.is_float(dtype): @@ -15,24 +15,24 @@ def render_val(x, dtype): return str(int(x)) + ("U" if dtypes.is_unsigned(dtype) else "") asm_for_op: Dict[Ops, Callable] = { - UnaryOps.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", - UnaryOps.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", UnaryOps.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", - UnaryOps.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", UnaryOps.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", - BinaryOps.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", BinaryOps.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};", - BinaryOps.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};", - BinaryOps.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};", - BinaryOps.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};", - BinaryOps.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};", - BinaryOps.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};", - BinaryOps.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", - BinaryOps.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", BinaryOps.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};", - BinaryOps.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", BinaryOps.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};", - TernaryOps.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};", - TernaryOps.WHERE: lambda d,a,b,c,dt,name: + Ops.RECIP: lambda d,a,dt,name: f"rcp{'.approx' if dtypes.is_float(dt) else ''}.{name} {d}, {a};", + Ops.EXP2: lambda d,a,dt,name: f"ex2.approx.{name} {d}, {a};", Ops.LOG2: lambda d,a,dt,name: f"lg2.approx.{name} {d}, {a};", + Ops.SIN: lambda d,a,dt,name: f"sin.approx.{name} {d}, {a};", Ops.SQRT: lambda d,a,dt,name: f"sqrt.approx.{name} {d}, {a};", + Ops.SHR: lambda d,a,b,dt,name: f"shr.{name} {d}, {a}, {b};", Ops.SHL: lambda d,a,b,dt,name: f"shl.b{name[1:]} {d}, {a}, {b};", + Ops.ADD: lambda d,a,b,dt,name: f"{'or' if name == 'pred' else 'add'}.{name} {d}, {a}, {b};", + Ops.MUL: lambda d,a,b,dt,name: ('and' if dt == dtypes.bool else 'mul') + f"{'.lo' if dtypes.is_int(dt) else ''}.{name} {d}, {a}, {b};", + Ops.XOR: lambda d,a,b,dt,name: f"xor.pred {d}, {a}, {b};" if name == "pred" else f"xor.b{name[1:]} {d}, {a}, {b};", + Ops.AND: lambda d,a,b,dt, name: f"and.pred {d}, {a}, {b};" if name == "pred" else f"and.b{name[1:]} {d}, {a}, {b};", + Ops.OR: lambda d,a,b,dt, name: f"or.pred {d}, {a}, {b};" if name == "pred" else f"or.b{name[1:]} {d}, {a}, {b};", + Ops.IDIV: lambda d,a,b,dt,name: f"div.{name} {d}, {a}, {b};", + Ops.MAX: lambda d,a,b,dt,name: f"max.{name} {d}, {a}, {b};", Ops.MOD: lambda d,a,b,dt,name: f"rem.{name} {d}, {a}, {b};", + Ops.CMPLT: lambda d,a,b,dt,name: f"setp.lt.{name} {d}, {a}, {b};", Ops.CMPNE: lambda d,a,b,dt,name: f"setp.ne.{name} {d}, {a}, {b};", + Ops.MULACC: lambda d,a,b,c,dt,name: f"{'fma.rn' if dtypes.is_float(dt) else 'mad.lo'}.{name} {d}, {a}, {b}, {c};", + Ops.WHERE: lambda d,a,b,c,dt,name: f"@{a} mov.{name} {d}, {b};\n@!{a} mov.{name} {d}, {c};" if name == "pred" else f"selp.{'b16' if name == 'f16' else name} {d}, {b}, {c}, {a};" } -supports_half: List[Ops] = [UnaryOps.EXP2, BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPLT, TernaryOps.WHERE] +supports_half: List[Ops] = [Ops.EXP2, Ops.ADD, Ops.MUL, Ops.MAX, Ops.CMPLT, Ops.WHERE] doesnt_support_half: Tuple[Ops, ...] = tuple(op for op in asm_for_op.keys() if op not in supports_half) ptx_matcher = PatternMatcher([ # bool CMPNE is XOR, bool CMPLT is XOR+AND (universal makes this slow, this is for renderer only) @@ -54,6 +54,84 @@ ptx_matcher = PatternMatcher([ (UPat.var("x") >> UPat.var("y"), lambda x,y: UOp(Ops.SHR, x.dtype, (x,y.cast(dtypes.uint))) if y.dtype != dtypes.uint else None), ]) +def mem_type(x: UOp): return 'shared' if x.src[0].op is Ops.DEFINE_LOCAL or any(_x.op is Ops.DEFINE_LOCAL for _x in x.src[0].parents) else 'global' + +def render_store(ctx: "PTXRenderer", x: UOp, bidx: UOp, var: UOp, pred: Optional[UOp]=None): + gate = f"@{ctx.r[pred]} " if pred is not None and pred.op is not Ops.IF else "" + return [f"{gate}st.{mem_type(bidx)}.v{var.dtype.count}.{ctx.mem_types[var.dtype.scalar()]} [{ctx.r[bidx]}+0], {{{', '.join(ctx.r[var])}}};"] \ + if var.dtype.count > 1 else [f"{gate}st.{mem_type(bidx)}.{ctx.mem_types[var.dtype]} [{ctx.r[bidx]}+0], {ctx.r[var]};"] + +def render_wmma(ctx: "PTXRenderer", x: UOp): + assert ctx.wmma_r, "registry values for wmma must be populated" + _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = x.arg + n_operands = tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2]) + dt_map = { dtypes.half: "f16" } + _i = 0 + for vv in x.src[:2]: + for i in range(0, len(ctx.r[vv]), 2): + yield f"mov.b32 {ctx.wmma_r[_i]}, {{{', '.join(ctx.r[vv][i:i+2])}}};" + _i += 1 + yield f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32{" "*12}' +\ + f'{{{", ".join(ctx.r[x])}}}, {{{", ".join(ctx.wmma_r[:n_operands[0]])}}}, {{{", ".join(ctx.wmma_r[-n_operands[1]:])}}}, ' + \ + f'{{{", ".join(ctx.r[x.src[2]])}}};' + +def modifier(a: DType, b: DType): return '.rzi' if dtypes.is_int(a) and dtypes.is_float(b) else '.rn' if dtypes.is_float(a) and \ + (a.itemsize < b.itemsize or dtypes.is_int(b) or b == dtypes.bool) else '' + +string_rewrite = PatternMatcher([ + (UPat(Ops.CONST, name="x", dtype=dtypes.bool), lambda ctx, x: f"setp.ne.s16 {ctx.r[x]}, {render_val(x.arg, x.dtype)}, 0;"), + (UPat(Ops.CONST, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(x.arg, x.dtype)};"), + (UPat(Ops.STORE, name="x", src=(UPat.var('bidx'), UPat.var("var"), UPat.var("pred")), allow_any_len=True), render_store), + (UPat(Ops.SPECIAL, name="x"), lambda ctx,x: f"mov.u32 %{x.arg[0]}, %{'ctaid' if x.arg[0][0] == 'g' else 'tid'}.{chr(120+int(x.arg[0][-1]))};"), + (UPat(Ops.DEFINE_GLOBAL, name="x"), lambda ctx, x: f"ld.param.{ctx.types[dtypes.ulong]} {ctx.r[x]}, [data{x.arg}+0];"), + (UPat((Ops.CMPLT, Ops.CMPNE), name="x"), + lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.src[0].dtype, ctx.types[x.src[0].dtype])), + (UPat(GroupOp.ALU, name="x"), lambda ctx, x: ctx.code_for_op[x.op](ctx.r[x], *[ctx.r[v] for v in x.src], x.dtype, ctx.types[x.dtype])), + (UPat(Ops.BITCAST, name="x", src=(UPat.var("a")), allow_any_len=True), lambda ctx, x, a: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {ctx.r[a]};"), + (UPat(Ops.CAST, name="x", src=(UPat(dtype=dtypes.bool, name="a"))), + lambda ctx, x, a: f"selp.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(1, x.dtype)}, {render_val(0, x.dtype)}, {ctx.r[a]};"), + (UPat(Ops.CAST, name="x", dtype=dtypes.bool), + lambda ctx, x: f"setp.ne.b{ctx.types[x.src[0].dtype][1:]} {ctx.r[x]}, {ctx.r[x.src[0]]}, {render_val(0, x.src[0].dtype)};"), + (UPat(Ops.CAST, name="x", src=(UPat.var("a"))), + lambda ctx, x, a: f"cvt{modifier(x.dtype, a.dtype)}.{ctx.types[x.dtype]}.{ctx.types[x.src[0].dtype]} {ctx.r[x]}, {ctx.r[x.src[0]]};"), + (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'), UPat(name='alt'), UPat(name="gate", op=GroupOp.ALU))), lambda ctx, x, loc, alt, gate: flatten([ + [f"mov.{ctx.mem_types[x.dtype.scalar()]} {v}, {render_val(0, x.dtype.scalar())};" for v in ctx.r[x]], + [f"@{ctx.r[gate]} ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];"] + ]) if alt.dtype.count > 1 else [ + f"@{ctx.r[gate]} ld.{mem_type(x)}.{ctx.mem_types[x.dtype.scalar()]} {ctx.r[x]}, [{ctx.r[loc]}+0];", + f"@!{ctx.r[gate]} mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[x]}, {ctx.r[alt]};"]), + (UPat(Ops.LOAD, name="x", src=(UPat.var('loc'),), allow_any_len=True), + lambda ctx, x, loc: f" ld.{mem_type(x)}.v{x.dtype.count}.{ctx.mem_types[x.dtype.scalar()]} {{{', '.join(ctx.r[x])}}}, [{ctx.r[loc]}+0];" \ + if x.dtype.count > 1 else f"ld.{mem_type(x)}.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{ctx.r[loc]}+0];"), + (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.bool),), allow_any_len=True), + lambda ctx, x, pred: flatten([ + [f"setp.ne.s16 {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())}, 0;", + f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])), + (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE, dtype=dtypes.half),), allow_any_len=True), + lambda ctx, x, pred: flatten([[f"mov.b{ctx.types[x.dtype.scalar()][1:]} {ctx.r[pred][i]}, {render_val(pred.src[0].arg, x.dtype.scalar())};", + f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {ctx.r[pred][i]};"] for i, uu in enumerate(ctx.r[x])])), + (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.VECTORIZE),), allow_any_len=True), lambda ctx, x, pred: [ + f"mov.b{ctx.types[x.dtype.scalar()][1:]} {uu}, {render_val(pred.src[0].arg, x.dtype.scalar())};" for i, uu in enumerate(ctx.r[x])]), + (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST, dtype=dtypes.bool), ), allow_any_len=True), lambda ctx, x, pred: [ + f"setp.ne.s16 {ctx.r[pred]}, {render_val(pred.arg, pred.dtype)}, 0;", f"mov.pred {ctx.r[x]}, {ctx.r[pred]};"]), + (UPat(Ops.DEFINE_ACC, name="x", src=(UPat(name="pred", op=Ops.CONST), ), allow_any_len=True), + lambda ctx, x, pred: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x]}, {render_val(pred.arg, x.dtype)};"), + (UPat(Ops.RANGE, name="x"), lambda ctx, x: [f"mov.u32 {ctx.r[x]}, {ctx.r[x.src[0]]};", "LOOP_" + f"{ctx.r[x][1:]}:"]), + (UPat(Ops.ASSIGN, name="x", dtype=dtypes.bool), lambda ctx, x: [f"mov.pred {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"]), + (UPat(Ops.ASSIGN, name="x"), lambda ctx, x: f"mov.b{ctx.types[x.dtype][1:]} {ctx.r[x.src[0]]}, {ctx.r[x.src[1]]};"), + (UPat(Ops.ENDRANGE, name="x", src=(UPat.var("src0"),)), lambda ctx, x, src0: [ + ctx.code_for_op[Ops.ADD](ctx.r[src0], ctx.r[src0], "1", dtypes.int, ctx.types[dtypes.int]), + ctx.code_for_op[Ops.CMPLT](ctx.r[x], ctx.r[x.src[0]], ctx.r[src0.src[1]], dtypes.int, ctx.types[dtypes.int]), + f"@{ctx.r[x]} bra LOOP_{ctx.r[src0][1:]};"]), + (UPat(Ops.DEFINE_LOCAL, name="x"), + lambda ctx, x: [f".shared .align 4 .b8 {x.arg[0]}[{x.arg[1]*x.dtype.itemsize}];", f"mov.u64 {ctx.r[x]}, {x.arg[0]}[0];"]), + (UPat(Ops.IF, name="x"), lambda ctx, x: f"@!{ctx.r[x.src[0]]} bra IF_{ctx.r[x.src[0]][1:]}_{ctx.uops.index(x)};"), + (UPat(Ops.ENDIF, name="x"), lambda ctx, x: f"IF_{ctx.r[x.src[0].src[0]][1:]}_{ctx.uops.index(x.src[0])}:"), + (UPat(Ops.WMMA, name="x"), lambda ctx, x: list(render_wmma(ctx, x))), + (UPat(Ops.BARRIER, name="x"), lambda ctx, x: ctx.barrier), + (UPat(Ops.DEFINE_VAR, name="x"), lambda ctx, x: f"ld.param.{ctx.mem_types[x.dtype]} {ctx.r[x]}, [{x.arg[0]}+0];"), +]) + class PTXRenderer(Renderer): device = "CUDA" suffix = "PTX" @@ -80,34 +158,6 @@ class PTXRenderer(Renderer): mem_types: Dict[DType, str] = types.copy() mem_types.update({dtypes.int8: "s8", dtypes.uint8: "u8", dtypes.bool: "u8", dtypes.float16: "b16"}) - const_requires_mov: List[DType] = [dtypes.half, dtypes.bool] - - def render_const(self, x:ConstType, dtype:DType, mov=None) -> Union[List[str], str]: - val = render_val(x, dtype) - if dtype == dtypes.bool: return [f"setp.ne.s16 {mov}, {val}, 0;"] - return [f"mov.b{self.types[dtype][1:]} {mov}, {val};"] if mov else val - - def render_local(self, dest, name, size, dtype) -> List[str]: - return [f".shared .align 4 .b8 {name}[{size*dtype.itemsize}];", f"mov.u64 {dest}, {name}[0];"] - - def render_loop(self, idx, start, label, acc=None) -> List[str]: return [f"mov.u32 {idx}, {start};", f"{label}:"] - - def render_bra(self, b1, pred=None, invert=False) -> List[str]: - return [f"@{'!' if invert else ''}{pred} bra {b1};"] if pred else [f"bra {b1};"] - - def render_load(self, loc, dest, dtype, gate=None, alt=None, ss="", offset=0) -> List[str]: - assert dtype != dtypes.bool - if gate: return [f"@{gate} ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];", f"@!{gate} mov.b{self.types[dtype][1:]} {dest}, {alt};"] - return [f"ld{ss}.{self.mem_types[dtype]} {dest}, [{loc}+{offset}];"] - - def render_cast(self, d:str, a:str, dtype:DType, atype:DType, bitcast=False, pred=False) -> List[str]: - if bitcast: return [f"mov.b{self.types[dtype][1:]} {d}, {a};"] - if atype == dtypes.bool: return [f"selp.b{self.types[dtype][1:]} {d}, {render_val(1, dtype)}, {render_val(0, dtype)}, {a};"] - if dtype == dtypes.bool: return [f"setp.ne.b{self.types[atype][1:]} {d}, {a}, {self.render_const(0, atype)};"] - rnd = ('.rzi' if dtypes.is_int(dtype) and dtypes.is_float(atype) else - '.rn' if dtypes.is_float(dtype) and (dtype.itemsize < atype.itemsize or dtypes.is_int(atype) or atype == dtypes.bool) else '') - return [f"cvt{rnd}.{self.types[dtype]}.{self.types[atype]} {d}, {a};"] - def render_kernel(self, kernel, function_name, bufs, regs) -> str: kernel = [f".reg .{reg.split('_')[-2]} %{reg}<{cnt}>;" for reg,cnt in regs] + kernel + ["ret;"] def fmt(line): return line if line[0]=="$" else "\t" + line.replace(" ", "\t" if len(line.split(" ")[0]) > 7 else "\t\t", 1) @@ -120,116 +170,56 @@ class PTXRenderer(Renderer): kernel:List[str] = [] bufs = [] - def kk(*s: str): kernel.append("\n".join(s)) - c: DefaultDict[str, int] = defaultdict(int) r: Dict[UOp, Union[List[str], str]] = {} + self.r = r + self.uops = uops + def ssa(prefix:str, u:Optional[UOp]=None, dtype:Optional[str]=None) -> str: nonlocal c, r prefix += f"_{dtype if dtype is not None else self.types[cast(UOp, u).dtype]}_" c[prefix] += 1 - if u is not None: r[u] = f"%{prefix}{c[prefix]-1}" return f"%{prefix}{c[prefix]-1}" - def const(x:ConstType, dtype:DType, mov=False): - if mov or dtype in self.const_requires_mov: - kk(*self.render_const(x, dtype, mov=(out:=ssa('const', dtype=self.types[dtype])))) - return out - return self.render_const(x, dtype) - - def _cast(a, dtype:DType, atype:DType, bitcast=False, u=None, pred=False): - if atype == dtype or isinstance(atype, PtrDType): - if u is not None: r[u] = a - return a - kk(*self.render_cast((ret:=ssa('cast', u, self.types[dtype])), a, dtype, atype, bitcast)) - return ret - for u in uops: - uop,dtype,src,args = u.op,u.dtype,u.src,u.arg - if uop is Ops.IF: - pred_reg = _cast(r[src[0]], dtypes.bool, src[0].dtype, u=u, pred=True) - kk(*self.render_bra(f"IF_{r[src[0]][1:]}_{uops.index(u)}", pred_reg, invert=True)) - elif uop is Ops.BARRIER and self.barrier: kk(self.barrier) - elif uop is Ops.ENDRANGE: - kk(self.code_for_op[BinaryOps.ADD](r[src[0]], r[src[0]], "1", dtypes.int, self.types[dtypes.int]), - self.code_for_op[BinaryOps.CMPLT](pred:=ssa("pred", dtype="pred"), r[src[0]], r[src[0].src[1]], dtypes.int, self.types[dtypes.int])) - kk(*self.render_bra(f"LOOP_{r[src[0]][1:]}", pred)) - elif uop is Ops.ENDIF: - kk(f"IF_{r[src[0].src[0]][1:]}_{uops.index(src[0])}:") - elif uop is Ops.STORE: - assert src[0].dtype == dtypes.int64, "store isn't int64" - mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global' - gate = f"@{r[src[2]]} " if len(src)>2 and src[2].op is not Ops.IF else "" - if src[1].dtype.count > 1: - kk(gate + f"st{mem_type}.v{src[1].dtype.count}.{self.mem_types[src[1].dtype.scalar()]} [{r[src[0]]}+0], {{{', '.join(r[src[1]])}}};") - else: - kk(gate + f"st{mem_type}.{self.mem_types[src[1].dtype]} [{r[src[0]]}+0], {r[src[1]]};") - else: - if uop is Ops.RANGE: kk(*self.render_loop(loop:=ssa('ridx', u), r[src[0]], "LOOP_"+loop[1:])) - elif uop in GroupOp.ALU: - src_dtype = src[0].dtype if uop in {BinaryOps.CMPLT, BinaryOps.CMPNE} else dtype - kk(self.code_for_op[uop](ssa("alu", u), *[r[x] for x in src], src_dtype, self.types[src_dtype])) - elif uop is Ops.DEFINE_ACC: - if dtype.count > 1: - r[u] = [ssa('acc', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - for uu in r[u]: kk(f"mov.b{self.types[dtype.scalar()][1:]} {uu}, {const(src[0].src[0].arg, dtype.scalar())};") - else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {ssa('acc', u)}, {const(src[0].arg, dtype)};") - elif uop is Ops.SPECIAL: - assert args[0][0] != "i", "idx not supported" - kk(f"mov.u32 %{args[0]}, %{'ctaid' if args[0][0] == 'g' else 'tid'}.{chr(120+int(args[0][-1]))};") - r[u] = "%" + args[0] - kernel = [f".reg .u32 %{args[0]};"] + kernel - elif uop is Ops.DEFINE_VAR: - bufs.append((args[0], dtype)) - r[u] = f"%{args[0]}" - kk(*self.render_load(args[0], ssa('dat', u, self.types[dtype]), dtype, ss=".param")) - elif uop is Ops.CONST: r[u] = const(args, dtype, mov=True) - elif uop is Ops.GEP: - assert len(u.arg) == 1 - r[u] = r[src[0]][u.arg[0]] - elif uop is Ops.LOAD: - assert src[0].dtype == dtypes.int64, "load isn't int64" - mem_type = '.shared' if src[0].op is Ops.DEFINE_LOCAL or any(x.op is Ops.DEFINE_LOCAL for x in src[0].parents) else '.global' - has_gate = len(src) > 2 and src[2].op in GroupOp.ALU - if dtype.count > 1: - r[u] = [ssa('val', dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - if has_gate: - for v in r[u]: kk(f"mov.{self.mem_types[dtype.scalar()]} {v}, {render_val(0, dtype.scalar())};") - kk((f"@{r[src[2]]}" if has_gate else "") - + f" ld{mem_type}.v{dtype.count}.{self.mem_types[dtype.scalar()]} {{{', '.join(r[u])}}}, [{r[src[0]]}+0];") - else: - kk(*self.render_load(r[src[0]], ssa('val', u), dtype, gate=r[src[2]] if has_gate else None, - alt=r[src[1]] if has_gate else None, ss=mem_type, offset=0)) - elif uop is Ops.ASSIGN: - if dtype.count > 1: - for x0, x1 in zip(r[src[0]], r[src[1]]): kk(f"mov.b{self.types[dtype.scalar()][1:]} {x0}, {x1};") - else: kk(f"mov.{f'b{self.types[dtype][1:]}' if dtype != dtypes.bool else 'pred'} {r[src[0]]}, {r[src[1]]};") - r[u] = r[src[0]] - # NOTE: casting to str is fine because you can't vectorize a vectorize - elif uop is Ops.VECTORIZE: r[u] = [cast(str,r[x]) for x in src] - elif uop in {Ops.CAST, Ops.BITCAST}: - _cast(r[src[0]], dtype, src[0].dtype, bitcast=uop is Ops.BITCAST, u=u) - elif uop is Ops.DEFINE_LOCAL: - # TODO: we should sum these, and fetch 0xC000 from somewhere - assert args[1]*dtype.itemsize <= 0xC000, "too large local" - kk(*self.render_local(ssa('local', u, self.types[dtypes.ulong]), args[0], args[1], dtype)) - elif uop is Ops.DEFINE_GLOBAL: - bufs.append((nm:=f"data{args}", dtype)) - r[u] = f"%{nm}" - dt = dtypes.ulong if dtype.__class__ == PtrDType else dtype - kk(*self.render_load(nm, ssa('dat', u, self.types[dt]), dt, ss=".param")) - elif uop is Ops.WMMA: - _, (N, M, K), dtype_in, _, _, _, upcast_axes, _ = args - wmma, n_operands = [], tuple(prod(sz for _, sz in upc)*dtype_in.itemsize//4 for upc in upcast_axes[:2]) - dt_map = { dtypes.half: "f16" } - for vv in src[:2]: - for i in range(0, len(r[vv]), 2): - wmma.append(ssa("wmma", dtype="b32")) - kk(f'mov.b32 {wmma[-1]}, {{{", ".join(r[vv][i:i+2])}}};') - r[u] = [ssa("wmma", dtype=self.types[dtype.scalar()]) for _ in range(dtype.count)] - kk(f'mma.sync.aligned.m{M}n{N}k{K}.row.col.f32.{dt_map[dtype_in]}.{dt_map[dtype_in]}.f32\ - {{{", ".join(r[u])}}}, {{{", ".join(wmma[:n_operands[0]])}}}, {{{", ".join(wmma[-n_operands[1]:])}}}, {{{", ".join(r[src[2]])}}};') - else: raise NotImplementedError(f"no code for {uop}") + if u.op is Ops.VECTORIZE: + r[u] = [cast(str,r[x]) for x in u.src] + continue + if u.op is Ops.GEP: + assert len(u.arg) == 1 + r[u] = r[u.src[0]][u.arg[0]] + continue + if u.op in {Ops.CAST, Ops.BITCAST}: + if u.src[0].dtype == u.dtype or isinstance(u.src[0].dtype, PtrDType): + r[u] = r[u.src[0]] + continue + r[u] = ssa('cast', u, self.types[u.dtype]) + elif u.op is Ops.ENDRANGE: r[u] = ssa("pred", u, dtype="pred") + elif u.op is Ops.RANGE: r[u] = ssa("ridx", u) + elif u.op in GroupOp.ALU: r[u] = ssa("alu", u) + elif u.op is Ops.DEFINE_ACC: + if u.dtype.scalar() in [dtypes.half, dtypes.bool]: + r[u.src[0]] = [ssa("const", u.src[0].src[0]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("const", u.src[0]) + r[u] = [ssa('acc', u, dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa("acc", u) + elif u.op is Ops.SPECIAL: r[u] = "%" + u.arg[0] + elif u.op is Ops.DEFINE_VAR: + bufs.append((u.arg[0], u.dtype)) + r[u] = ssa("dat", u, self.types[u.dtype]) + elif u.op is Ops.CONST: r[u] = ssa("const", u, dtype=self.types[u.dtype]) + elif u.op is Ops.LOAD: + assert u.src[0].dtype == dtypes.int64, "load isn't int64" + r[u] = [ssa('val', dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] if u.dtype.count > 1 else ssa('val', u) + elif u.op is Ops.DEFINE_LOCAL: r[u] = ssa('local', u, self.types[dtypes.ulong]) + elif u.op is Ops.DEFINE_GLOBAL: + bufs.append((f"data{u.arg}", u.dtype)) + r[u] = ssa('dat', u, self.types[dtypes.ulong if u.dtype.__class__ == PtrDType else u.dtype]) + elif u.op is Ops.WMMA: + self.wmma_r = [ssa("wmma", dtype="b32") for vv in u.src[:2] for i in range(0, len(r[vv]), 2)] + r[u] = [ssa("wmma", dtype=self.types[u.dtype.scalar()]) for _ in range(u.dtype.count)] + if (l:=cast(Union[str, List[str]], string_rewrite.rewrite(u, ctx=self))) is None: + raise RuntimeError(f"failed to render {u.op} with {u.dtype} srcs {[x.dtype for x in u.u.src]}") + kernel.extend([l] if isinstance(l, str) else l) + if u.op is Ops.ASSIGN: r[u] = r[u.src[0]] + elif u.op is Ops.SPECIAL: kernel = [f".reg .u32 %{u.arg[0]};"] + kernel return self.render_kernel(kernel, name, bufs, c.items()) - diff --git a/tinygrad/runtime/ops_clang.py b/tinygrad/runtime/ops_clang.py index eab8fdfe87..d7ce3dc376 100644 --- a/tinygrad/runtime/ops_clang.py +++ b/tinygrad/runtime/ops_clang.py @@ -1,12 +1,13 @@ from typing import Optional, List import ctypes, subprocess, pathlib, tempfile from tinygrad.device import Compiled, Compiler, MallocAllocator -from tinygrad.helpers import cpu_time_execution, DEBUG, cpu_objdump +from tinygrad.helpers import cpu_time_execution, cpu_objdump from tinygrad.renderer.cstyle import ClangRenderer class ClangCompiler(Compiler): - def __init__(self, cachekey="compile_clang", args:Optional[List[str]]=None): + def __init__(self, cachekey="compile_clang", args:Optional[List[str]]=None, objdump_tool='objdump'): self.args = ['-march=native'] if args is None else args + self.objdump_tool = objdump_tool super().__init__(cachekey) def compile(self, src:str) -> bytes: @@ -16,9 +17,10 @@ class ClangCompiler(Compiler): '-', '-o', str(output_file.name)], input=src.encode('utf-8')) return pathlib.Path(output_file.name).read_bytes() + def disassemble(self, lib:bytes): return cpu_objdump(lib, self.objdump_tool) + class ClangProgram: def __init__(self, name:str, lib:bytes): - if DEBUG >= 6: cpu_objdump(lib) self.name, self.lib = name, lib # write to disk so we can load it with tempfile.NamedTemporaryFile(delete=True) as cached_file_path: diff --git a/tinygrad/runtime/ops_dsp.py b/tinygrad/runtime/ops_dsp.py index 0865b4b870..4156f93422 100644 --- a/tinygrad/runtime/ops_dsp.py +++ b/tinygrad/runtime/ops_dsp.py @@ -3,7 +3,7 @@ from typing import Tuple, Any import ctypes, os, mmap, tempfile, pathlib, array, functools, threading, contextlib, sys assert sys.platform != 'win32' from tinygrad.device import BufferOptions, Compiled, Allocator -from tinygrad.helpers import from_mv, getenv, DEBUG, round_up, mv_address, to_mv, cpu_objdump +from tinygrad.helpers import from_mv, getenv, round_up, mv_address, to_mv from tinygrad.runtime.ops_clang import ClangCompiler from tinygrad.renderer.cstyle import DSPRenderer from tinygrad.runtime.autogen import libc, qcom_dsp @@ -23,7 +23,6 @@ def rpc_prep_args(ins=None, outs=None, in_fds=None): class DSPProgram: def __init__(self, device:DSPDevice, name:str, lib:bytes): self.device, self.lib = device, lib - if DEBUG >= 6: cpu_objdump(lib, objdump_tool='llvm-objdump') def __call__(self, *bufs, vals:Tuple[int, ...]=(), wait=False): if len(bufs) >= 16: raise RuntimeError(f"Too many buffers to execute: {len(bufs)}") @@ -72,7 +71,8 @@ class DSPDevice(Compiled): self.link_ld.flush() compiler_args = ["--target=hexagon", "-mcpu=hexagonv65", "-fuse-ld=lld", "-nostdlib", "-mhvx=v65", "-mhvx-length=128b", f"-T{self.link_ld.name}"] - super().__init__(device, DSPAllocator(self), DSPRenderer(), ClangCompiler("compile_dsp", args=compiler_args), functools.partial(DSPProgram, self)) + super().__init__(device, DSPAllocator(self), DSPRenderer(), + ClangCompiler("compile_dsp", args=compiler_args, objdump_tool='llvm-objdump'), functools.partial(DSPProgram, self)) fastrpc_shell = memoryview(bytearray(pathlib.Path('/dsp/cdsp/fastrpc_shell_3').read_bytes())) self.shell_buf = self.allocator.alloc(round_up(fastrpc_shell.nbytes, 0x1000), BufferOptions(nolru=True)) diff --git a/tinygrad/runtime/ops_llvm.py b/tinygrad/runtime/ops_llvm.py index 0078c74fec..d34e6fa5a0 100644 --- a/tinygrad/runtime/ops_llvm.py +++ b/tinygrad/runtime/ops_llvm.py @@ -2,7 +2,7 @@ from __future__ import annotations import ctypes, functools from typing import Tuple from tinygrad.device import Compiled, Compiler, MallocAllocator -from tinygrad.helpers import DEBUG, cpu_time_execution, cpu_objdump, getenv +from tinygrad.helpers import cpu_time_execution, getenv, cpu_objdump from tinygrad.renderer.llvmir import LLVMRenderer import llvmlite.binding as llvm @@ -16,16 +16,17 @@ class LLVMCompiler(Compiler): builder.opt_level = 3; builder.size_level = 0; builder.loop_vectorize = True; builder.slp_vectorize = True # noqa: E702 builder.populate(self.optimizer) super().__init__("compile_llvm_opt" if opt else "compile_llvm") + def compile(self, src:str) -> bytes: mod = llvm.parse_assembly(src) mod.verify() self.optimizer.run(mod) - if DEBUG >= 5: print(self.device.target_machine.emit_assembly(mod)) return self.device.target_machine.emit_object(mod) + def disassemble(self, lib:bytes): cpu_objdump(lib) + class LLVMProgram: def __init__(self, device:LLVMDevice, name:str, lib:bytes): - if DEBUG >= 6: cpu_objdump(lib) self.name, self.lib = name, lib device.engine.add_object_file(llvm.object_file.ObjectFileRef.from_data(lib)) self.fxn = device.engine.get_function_address(name) diff --git a/tinygrad/runtime/ops_python.py b/tinygrad/runtime/ops_python.py index 452618f813..32a38672cd 100644 --- a/tinygrad/runtime/ops_python.py +++ b/tinygrad/runtime/ops_python.py @@ -7,7 +7,7 @@ import pickle, base64, itertools, time, struct from tinygrad.dtype import DType, dtypes, ImageDType, PtrDType, truncate from tinygrad.helpers import all_same, getenv, flatten from tinygrad.device import Compiled, Compiler, Allocator -from tinygrad.ops import BinaryOps, TernaryOps, exec_alu, Ops, UOp, GroupOp +from tinygrad.ops import exec_alu, Ops, UOp, GroupOp from tinygrad.renderer import Renderer from tinygrad.renderer.cstyle import CUDARenderer, MetalRenderer, AMDRenderer, IntelRenderer, ClangRenderer @@ -175,7 +175,7 @@ class PythonProgram: else: raise NotImplementedError(f"unimplemented tensor core {arg}") elif uop in GroupOp.ALU: assert all_same([len(x) for x in inp]), f"{[len(x) for x in inp]} doesn't match on {uop}" - assert all_same([dtype] + dtp) or uop in {BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE}, f"dtype mismatch on {uop}" + assert all_same([dtype] + dtp) or uop in {Ops.CMPNE, Ops.CMPLT, Ops.WHERE}, f"dtype mismatch on {uop}" ul[i] = [exec_alu(uop, dtype, p) for p in zip(*inp)] assert i in ul, (uop, dtype, idp, arg) i += 1 diff --git a/tinygrad/shape/shapetracker.py b/tinygrad/shape/shapetracker.py index a0d93a7bc9..5204b27f1f 100644 --- a/tinygrad/shape/shapetracker.py +++ b/tinygrad/shape/shapetracker.py @@ -5,7 +5,7 @@ from typing import Tuple, List, Optional, Dict, Set from tinygrad.helpers import merge_dicts, getenv from tinygrad.shape.view import View, strides_for_shape from tinygrad.dtype import dtypes -from tinygrad.ops import UOp, Ops, BinaryOps, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid +from tinygrad.ops import UOp, Ops, graph_rewrite, split_uop, symbolic_flat, Variable, sint, uop_given_valid, simplify_valid @dataclass(frozen=True, order=True) class ShapeTracker: @@ -77,7 +77,7 @@ class ShapeTracker: # TODO: always apply these in to_indexed_uops? if (newvalid:=simplify_valid(valid)) is not None: valid = newvalid if (newidx:=uop_given_valid(valid, idx)) is not None: idx = graph_rewrite(newidx, symbolic_flat) - for c in split_uop(idx, BinaryOps.ADD): + for c in split_uop(idx, Ops.ADD): if c.op is Ops.RANGE: ret[c.arg[0]] = 1 if c.op is Ops.MUL and c.src[0].op is Ops.RANGE and c.src[1].op is Ops.CONST: ret[c.src[0].arg[0]] = c.src[1].arg if c.op is Ops.MUL and c.src[1].op is Ops.RANGE and c.src[0].op is Ops.CONST: ret[c.src[1].arg[0]] = c.src[0].arg diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a305cd3fc4..c54e54c22b 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -9,7 +9,7 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas from tinygrad.helpers import argfix, make_tuple, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv, fetch, polyN from tinygrad.multi import MultiLazyBuffer -from tinygrad.ops import MetaOps, smax, smin, resolve, UOp, Ops, BinaryOps, sint, Variable, SimpleMathTrait +from tinygrad.ops import smax, smin, resolve, UOp, Ops, sint, Variable, SimpleMathTrait from tinygrad.device import Device, Buffer, BufferOptions from tinygrad.engine.lazy import LazyBuffer from tinygrad.engine.realize import run_schedule @@ -51,7 +51,7 @@ def _to_np_dtype(dtype:DType) -> Optional[type]: return np.dtype(dtype.fmt).type if dtype.fmt is not None else None def _fromnp(x: 'np.ndarray') -> LazyBuffer: # type: ignore [name-defined] # noqa: F821 - ret = LazyBuffer.metaop(MetaOps.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") + ret = LazyBuffer.metaop(Ops.EMPTY, x.shape, _from_np_dtype(x.dtype), "NPY") # fake realize ret.buffer.allocate(x) del ret.srcs @@ -64,9 +64,9 @@ def get_shape(x) -> Tuple[int, ...]: return (len(subs),) + (subs[0] if subs else ()) def _frompy(x:Union[List, Tuple, bytes], dtype:DType) -> LazyBuffer: - if isinstance(x, bytes): ret, data = LazyBuffer.metaop(MetaOps.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x + if isinstance(x, bytes): ret, data = LazyBuffer.metaop(Ops.EMPTY, (len(x)//dtype.itemsize,), dtype, "PYTHON"), x else: - ret = LazyBuffer.metaop(MetaOps.EMPTY, get_shape(x), dtype, "PYTHON") + ret = LazyBuffer.metaop(Ops.EMPTY, get_shape(x), dtype, "PYTHON") assert dtype.fmt is not None, f"{dtype=} has None fmt" truncate_function = truncate[dtype] data = struct.pack(f"@{ret.size}{dtype.fmt}", *[truncate_function(xi) for xi in fully_flatten(x)]) @@ -115,7 +115,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method training: ClassVar[bool] = False no_grad: ClassVar[bool] = False - def __init__(self, data:Union[None, ConstType, List, Tuple, LazyBuffer, 'np.ndarray', bytes, MultiLazyBuffer, UOp, pathlib.Path], # type: ignore [name-defined] # noqa: F821 + def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, LazyBuffer, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821 device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None): if dtype is not None: dtype = to_dtype(dtype) assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}" @@ -133,11 +133,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method self._ctx: Optional[Function] = None # create a LazyBuffer from the different types of inputs - if isinstance(data, LazyBuffer): assert dtype is None or dtype == data.dtype, "dtype doesn't match, and casting isn't supported" - elif isinstance(data, get_args(ConstType)): data = _metaop(MetaOps.CONST, tuple(), dtype or dtypes.from_py(data), device, data) + if isinstance(data, (LazyBuffer, MultiLazyBuffer)): assert dtype is None or dtype==data.dtype, "dtype doesn't match, and casting isn't supported" + elif data is None: data = _metaop(Ops.EMPTY, (0,), dtype or dtypes.default_float, device) + elif isinstance(data, get_args(ConstType)): data = _metaop(Ops.CONST, tuple(), dtype or dtypes.from_py(data), device, data) elif isinstance(data, UOp): assert data.op is Ops.BIND and data.src[0].op is Ops.DEFINE_VAR and data.src[1].op is Ops.CONST, f"can't create tensor from UOp {data}" - data = _metaop(MetaOps.CONST, tuple(), dtype or data.dtype, device, data) + data = _metaop(Ops.CONST, tuple(), dtype or data.dtype, device, data) elif isinstance(data, bytes): data = _frompy(data, dtypes.uint8 if dtype is None else dtype) elif isinstance(data, (list, tuple)): if dtype is None: @@ -145,30 +146,25 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata else: data = _frompy(data, dtype) - elif data is None: data = _metaop(MetaOps.EMPTY, (0,), dtype or dtypes.default_float, device) elif str(type(data)) == "": import numpy as np assert isinstance(data, np.ndarray), f"expected np.ndarray, got {data}" - if data.shape == (): data = _metaop(MetaOps.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) + if data.shape == (): data = _metaop(Ops.CONST, tuple(), dtype or _from_np_dtype(data.dtype), device, data.item()) else: data = _fromnp(data.astype(npdtype) if dtype is not None and (npdtype:=_to_np_dtype(dtype)) is not None else data) # type: ignore [name-defined] elif isinstance(data, pathlib.Path): dtype = dtype or dtypes.uint8 - data = _metaop(MetaOps.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") + data = _metaop(Ops.EMPTY, (data.stat().st_size // dtype.itemsize,), dtype, f"DISK:{data.resolve()}") # by this point, it has to be a LazyBuffer - if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): - raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") + if not isinstance(data, (LazyBuffer, MultiLazyBuffer)): raise RuntimeError(f"can't create Tensor from {data!r} with type {type(data)}") - # data is a LazyBuffer, but it might be on the wrong device - if isinstance(device, tuple): - # if device is a tuple, we should have/construct a MultiLazyBuffer - if isinstance(data, MultiLazyBuffer): - assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" - self.lazydata: Union[LazyBuffer, MultiLazyBuffer] = data - else: - self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None) + # data might be on a different device + if isinstance(device, str): self.lazydata:Union[LazyBuffer, MultiLazyBuffer] = data if data.device == device else data.copy_to_device(device) + # if device is a tuple, we should have/construct a MultiLazyBuffer + elif isinstance(data, LazyBuffer): self.lazydata = MultiLazyBuffer.from_sharded(data, device, None, None) else: - self.lazydata = data if data.device == device else data.copy_to_device(device) + assert data.device == device, f"MultiLazyBuffer device mismatch, {data.device} != {device}" + self.lazydata = data class train(ContextDecorator): def __init__(self, mode:bool = True): self.mode = mode @@ -248,7 +244,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method assert self.dtype == x.dtype, f"assign dtype mismatch {self.dtype} != {x.dtype}" assert not isinstance(self.lazydata, MultiLazyBuffer) or self.lazydata.axis == x.lazydata.axis, "axis must match on MultiLazyBuffer" assert not x.requires_grad # self requires_grad is okay? - if not self.lazydata.is_realized(): return self.replace(x) + if not self.lazydata.is_realized: return self.replace(x) self.lazydata = self.lazydata.assign(x.lazydata) return self @@ -384,9 +380,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method def from_uop(y:UOp, **kwargs) -> Tensor: if y.op is Ops.BIND: return Tensor(y, **kwargs, requires_grad=False) # this is the only UOp allowed in Tensor if y.op is Ops.CONST: return Tensor(y.arg, **kwargs, requires_grad=False) - if y.op is BinaryOps.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) - if y.op is BinaryOps.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) - if y.op is BinaryOps.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) + if y.op is Ops.MUL: return Tensor.from_uop(y.src[0]) * Tensor.from_uop(y.src[1]) + if y.op is Ops.ADD: return Tensor.from_uop(y.src[0]) + Tensor.from_uop(y.src[1]) + if y.op is Ops.MAX: return Tensor.from_uop(y.src[0]).maximum(Tensor.from_uop(y.src[1])) raise RuntimeError(f"unhandled UOp {y}") # ***** creation entrypoint ***** @@ -412,7 +408,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.shape) ``` """ - return Tensor._metaop(MetaOps.EMPTY, argfix(*shape), **kwargs) + return Tensor._metaop(Ops.EMPTY, argfix(*shape), **kwargs) @staticmethod def from_blob(ptr:int, shape:Tuple[int, ...], **kwargs) -> Tensor: @@ -424,7 +420,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method Additionally, all other keyword arguments are passed to the constructor of the tensor. """ - r = Tensor._metaop(MetaOps.EMPTY, shape, **kwargs) + r = Tensor._metaop(Ops.EMPTY, shape, **kwargs) r.lazydata.buffer.allocate(external_ptr=ptr) del r.lazydata.srcs # fake realize return r @@ -1001,30 +997,49 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.shrink((((0, 2), (0, 2)))).numpy()) ``` """ - if all(x is None or x == (0,s) for x,s in zip(arg, self.shape)): return self - return F.Shrink.apply(self, arg=tuple(x if x is not None else (0,s) for x,s in zip(arg, self.shape))) + if (shrink_arg:=[x if x is not None else (0,s) for x,s in zip(arg, self.shape)]) == [(0,s) for s in self.shape]: return self + return F.Shrink.apply(self, arg=tuple(shrink_arg)) - def pad(self, arg:Tuple[Optional[Tuple[sint, sint]], ...], value:float=0.0) -> Tensor: + def pad(self, padding:Union[Sequence[sint], Sequence[Optional[Tuple[sint, sint]]]], value:float=0.0) -> Tensor: """ - Returns a tensor that pads the each axis based on input arg. - `arg` must have the same length as `self.ndim`. - For each axis, it can be `None`, which means no pad, or a tuple `(pad_before, pad_after)`. - If `value` is specified, the tensor is padded with `value` instead of `0.0`. + Returns a tensor with padding applied based on the input `padding`. + `padding` supports two padding structures: + + 1. Flat padding: (padding_left, padding_right, padding_top, padding_bottom, ...) + - This structure matches PyTorch's pad. + - `padding` length must be even. + + 2. Group padding: (..., (padding_top, padding_bottom), (padding_left, padding_right)) + - This structure matches pad for jax, numpy, tensorflow and others. + - For each axis, padding can be `None`, meaning no padding, or a tuple `(start, end)`. + - `padding` must have the same length as `self.ndim`. + + Padding values can be negative, resulting in dimension shrinks that work similarly to Python negative slices. ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(6).reshape(2, 3) + t = Tensor.arange(9).reshape(1, 1, 3, 3) print(t.numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.pad(((None, (1, 2)))).numpy()) + print(t.pad((1, 2, 0, -1)).numpy()) ``` ```python exec="true" source="above" session="tensor" result="python" - print(t.pad(((None, (1, 2))), -2).numpy()) + print(t.pad(((None, None, (0, -1), (1, 2)))).numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + print(t.pad((1, 2, 0, -1), value=-float('inf')).numpy()) ``` """ - if all(x is None or x == (0,0) for x in arg): return self - ret = F.Pad.apply(self, arg=(narg:=tuple(x if x is not None else (0,0) for x in arg))) - return ret if 0 == value else ret + F.Pad.apply(Tensor.ones_like(self), arg=narg).where(0, value) + if (flat:=all(isinstance(p, (int,UOp)) for p in padding)) and len(padding)%2 != 0: raise ValueError("Flat padding must have even number of pads") + # turn flat padding into group padding + pX = ((0,0),)*(self.ndim - len(padding)//2) + tuple(zip(padding[-2::-2], padding[::-2])) if flat else padding + if len(pX) != self.ndim: raise ValueError(f"padding length is improper, {padding=} {self.ndim=}") + X, pX = self, cast(Tuple[Tuple[sint, sint]], tuple((0,0) if p is None else p for p in pX)) + def _constant(x,px,v): return F.Pad.apply(x, arg=px) if v == 0 else F.Pad.apply(x, arg=px) + F.Pad.apply(Tensor.ones_like(x), arg=px).where(0, v) + # early return for symbolic with positive pads (no need to max) + if all(resolve(p >= 0) for p in flatten(pX)): return _constant(X, pX, value) + pads, shrinks = tuple((smax(pB,0), smax(pA,0)) for pB,pA in pX), tuple((-smin(pB,0),smin(pA+s,s)) for (pB,pA),s in zip(pX, self.shape)) + return _constant(X.shrink(shrinks), pads, value) # ***** movement high level ops ***** @@ -1230,13 +1245,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ dim = self._resolve_dim(dim) - assert all(len(y.shape) == len(self.shape) and all(y.shape[i] == s for i,s in enumerate(self.shape) if i != dim) for y in args) - catargs = [self, *args] - cat_dims = [s.shape[dim] for s in catargs] - cat_dim_cumsum = [0, *itertools.accumulate(cat_dims)] - slc:List[List[Optional[Tuple[sint, sint]]]] = [[None for _ in self.shape] for _ in catargs] - for d,k,s in zip(cat_dims, cat_dim_cumsum[:-1], slc): s[dim] = (k, cat_dim_cumsum[-1] - k - d) - return functools.reduce(Tensor.add, [arg.pad(tuple(s)) for arg,s in zip(catargs, slc)]) + for arg in args: assert arg.ndim==self.ndim and all(ti==ai for i,(ti,ai) in enumerate(zip(self.shape, arg.shape)) if i!=dim) + tensors = [self, *args] + dim_cumsum = list(itertools.accumulate([t.shape[dim] for t in tensors], initial=0)) + for i,t in enumerate(tensors): tensors[i] = t.pad([(dim_cumsum[i], dim_cumsum[-1]-dim_cumsum[i+1]) if j==dim else None for j in range(t.ndim)]) + return functools.reduce(Tensor.add, tensors) def stack(self:Tensor, *args:Tensor, dim:int=0) -> Tensor: """ @@ -1251,7 +1264,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ # checks for shapes and number of dimensions delegated to cat - return self.unsqueeze(dim).cat(*[t.unsqueeze(dim) for t in args], dim=dim) + return Tensor.cat(*[t.unsqueeze(dim) for t in [self, *args]], dim=dim) def repeat_interleave(self, repeats:int, dim:Optional[int]=None) -> Tensor: """ @@ -1262,7 +1275,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.repeat_interleave(2).numpy()) ``` """ - x, dim = (self.flatten(), 0) if dim is None else (self, dim) + x, dim = (self.flatten(), 0) if dim is None else (self, self._resolve_dim(dim)) shp = x.shape return x.reshape(*shp[:dim+1], 1, *shp[dim+1:]).expand(*shp[:dim+1], repeats, *shp[dim+1:]).reshape(*shp[:dim], shp[dim]*repeats, *shp[dim+1:]) @@ -1280,16 +1293,16 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ repeats = argfix(repeats, *args) - base_shape = (1,) * (len(repeats) - self.ndim) + self.shape - new_shape = [x for b in base_shape for x in [1, b]] - expand_shape = [x for rs in zip(repeats, base_shape) for x in rs] + base_shape = _pad_left(self.shape, repeats)[0] + unsqueezed_shape = flatten([[1, s] for s in base_shape]) + expanded_shape = flatten([[r, s] for r,s in zip(repeats, base_shape)]) final_shape = [r*s for r,s in zip(repeats, base_shape)] - return self.reshape(new_shape).expand(expand_shape).reshape(final_shape) + return self.reshape(unsqueezed_shape).expand(expanded_shape).reshape(final_shape) - def _resolve_dim(self, dim:int, *, outer:bool=False) -> int: - if not -max(1, self.ndim+outer) <= dim < max(1, self.ndim+outer): - raise IndexError(f"{dim=} out of range {[-max(1, self.ndim+outer), max(1, self.ndim+outer)-1]}") - return dim + self.ndim+outer if dim < 0 else dim + def _resolve_dim(self, dim:int, *, extra:bool=False) -> int: + total = self.ndim + int(extra) + if not -max(1, total) <= dim <= max(1, total)-1: raise IndexError(f"{dim=} out of range {[-max(1, total), max(1, total)-1]}") + return dim + total if dim < 0 else dim def split(self, sizes:Union[int, List[int]], dim:int=0) -> Tuple[Tensor, ...]: """ @@ -1340,6 +1353,33 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method dim = self._resolve_dim(dim) return list(self.split(ceildiv(self.shape[dim], chunks) if self.shape[dim] else [0]*chunks, dim=dim)) + def meshgrid(self:Tensor, *args:Tensor, indexing:Union[Literal["ij"], Literal["xy"]]="ij") -> Tuple[Tensor, ...]: + """ + Generates coordinate matrices from coordinate vectors. + Input tensors can be scalars or 1D tensors. + + `indexing` determines how the output grids are aligned. + `ij` indexing follows matrix-style indexing and `xy` indexing follows Cartesian-style indexing. + + ```python exec="true" source="above" session="tensor" result="python" + x, y = Tensor([1, 2, 3]), Tensor([4, 5, 6]) + grid_x, grid_y = x.meshgrid(y) + print(grid_x.numpy()) + print(grid_y.numpy()) + ``` + ```python exec="true" source="above" session="tensor" result="python" + grid_x, grid_y = x.meshgrid(y, indexing="xy") + print(grid_x.numpy()) + print(grid_y.numpy()) + ``` + """ + if indexing not in ("ij", "xy"): raise RuntimeError(f'indexing must be in ("ij", "xy"), got {indexing}') + if len(tensors:=(self, *args)) == 1: return tensors + basis = tuple(range(len(tensors))) if indexing == "ij" else (1, 0) + tuple(range(2, len(tensors))) + tensors = tuple(t.reshape((-1,) + (1,)*(len(args) - i)) for i,t in zip(basis, tensors)) + output_shape = _broadcast_shape(*(t.shape for t in tensors)) + return tuple(t._broadcast_to(output_shape) for t in tensors) + def squeeze(self, dim:Optional[int]=None) -> Tensor: """ Returns a tensor with specified dimensions of input of size 1 removed. @@ -1372,27 +1412,9 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.unsqueeze(1).numpy()) ``` """ - dim = self._resolve_dim(dim, outer=True) + dim = self._resolve_dim(dim, extra=True) return self.reshape(self.shape[:dim] + (1,) + self.shape[dim:]) - def pad2d(self, padding:Sequence[int], value:float=0.0) -> Tensor: - """ - Returns a tensor that pads the last two axes specified by `padding` (padding_left, padding_right, padding_top, padding_bottom). - If `value` is specified, the tensor is padded with `value` instead of `0.0`. - - ```python exec="true" source="above" session="tensor" result="python" - t = Tensor.arange(9).reshape(1, 1, 3, 3) - print(t.numpy()) - ``` - ```python exec="true" source="above" session="tensor" result="python" - print(t.pad2d((1, 1, 2, 0), value=-float("inf")).numpy()) - ``` - """ - pads = tuple((smax(p0, 0), smax(p1, 0)) for p0, p1 in zip(padding[::2], padding[1::2]))[::-1] - padded = self.pad((None,) * (self.ndim - len(padding) // 2) + tuple(pads), value=value) - shrink = tuple((-smin(p0, 0), smin(p1 + s, s)) for p0, p1, s in zip(padding[::2], padding[1::2], padded.shape[::-1]))[::-1] - return padded.shrink((None,) * (self.ndim - len(padding) // 2) + shrink) - @property def T(self) -> Tensor: """`.T` is an alias for `.transpose()`.""" @@ -1471,10 +1493,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # ***** reduce ops ***** def _reduce(self, fxn:Type[Function], axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False) -> Tensor: - if self.ndim == 0: - if axis is not None and any(a not in [-1, 0] for a in fully_flatten([axis])): raise IndexError(f"{axis=} out of range of [-1, 0]") - axis = () axis = tuple(self._resolve_dim(x) for x in (range(self.ndim) if axis is None else make_tuple(axis, 1))) + if self.ndim == 0: axis = () ret = fxn.apply(self, axis=axis) return ret if keepdim else ret.reshape(tuple(s for i,s in enumerate(self.shape) if i not in axis)) @@ -1502,7 +1522,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.sum(axis=1).numpy()) ``` """ - ret = self.cast(acc_dtype or sum_acc_dtype(self.dtype))._reduce(F.Sum, axis, keepdim) + ret = self.cast(sum_acc_dtype(self.dtype) if acc_dtype is None else acc_dtype)._reduce(F.Sum, axis, keepdim) return ret.cast(self.dtype) if acc_dtype is None and self.dtype in (dtypes.float16, dtypes.bfloat16) else ret def prod(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False, acc_dtype:Optional[DTypeLike]=None): @@ -1529,7 +1549,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(t.prod(axis=1).numpy()) ``` """ - return self.cast(acc_dtype or self.dtype)._reduce(F.Prod, axis, keepdim) + return self.cast(acc_dtype if acc_dtype is not None else self.dtype)._reduce(F.Prod, axis, keepdim) def max(self, axis:Optional[Union[int, Sequence[int]]]=None, keepdim=False): """ @@ -1845,8 +1865,8 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method if axis is None: return self.flatten().argmax(0) axis = self._resolve_dim(axis) m = self == self.max(axis=axis, keepdim=True) - idx = m * Tensor.arange(self.shape[axis]-1,-1,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) - return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)-1).cast(dtypes.int32) + idx = m * Tensor.arange(self.shape[axis],0,-1, requires_grad=False, device=self.device).reshape(self.shape[axis], *[1]*(self.ndim-axis-1)) + return (self.shape[axis]-idx.max(axis=axis, keepdim=keepdim)).cast(dtypes.int32) def argmin(self, axis=None, keepdim=False): """ @@ -1912,7 +1932,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method return functools.reduce(lambda x, dims: x.flatten(dims[0], dims[1] - 1) if dims[0] Tensor: + def einsum(formula:str, *operands:Tensor|Sequence[Tensor], acc_dtype:Optional[DTypeLike]=None) -> Tensor: """ Sums the product of the elements of the input tensors according to a formula based on the Einstein summation convention. @@ -1924,19 +1944,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method print(Tensor.einsum("ij,ij->", x, y).numpy()) ``` """ - def parse_formula(formula: str, *operands: Tensor): - if "." in formula: + def parse_formula(formula:str, *operands:Tensor): + if "..." in (formula := formula.replace(" ", "")): ell_chars, ell_longest = "".join(set(string.ascii_letters) - set(formula)), 0 for i, inp in enumerate(filter(lambda x: "..." in x, inputs := formula.split("->")[0].split(","))): - if (ell_count := max(operands[i].ndim, 1) - (len(inp) - 3)) > ell_longest: ell_longest = ell_count - inputs[i] = inp.replace("...", "" if ell_count == 0 else ell_chars[-ell_count:]) - inputs_str, out_ellipse = ",".join(inputs), "" if ell_longest == 0 else ell_chars[-ell_longest:] - return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else (inputs_str, \ - out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))) + if (ell_count := max(operands[i].ndim, 1) - (len(inp) - len("..."))) > ell_longest: ell_longest = ell_count + inputs[i] = inp.replace("...", ell_chars[-ell_count:]) + inputs_str, out_ellipse = ",".join(inputs), ell_chars[-ell_longest:] + return (inputs_str, formula.split("->")[1].replace("...", out_ellipse)) if "->" in formula else \ + (inputs_str, out_ellipse + ''.join(sorted(c for c in inputs_str if inputs_str.count(c) == 1 and c.isalpha() and c not in out_ellipse))) return formula.split("->") if "->" in formula else (formula, ''.join(c for c in sorted(formula) if formula.count(c) == 1 and c.isalpha())) - xs:Tuple[Tensor, ...] = argfix(*raw_xs) - inputs_str, output = parse_formula(formula.replace(" ", ""), *xs) + xs:Tuple[Tensor, ...] = argfix(*operands) + inputs_str, output = parse_formula(formula, *xs) inputs = inputs_str.split(",") assert len(xs) == len(inputs), f"number of inputs doesn't match number of operands in formula, expected {len(inputs)}, got {len(xs)}" @@ -1949,13 +1969,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # permute to the sorted letter order, then reshape/expand to create dimensions for the missing letters xs_.append(x.permute(order).reshape([val if letter in letters else 1 for letter,val in letter_val]).expand([val for _,val in letter_val])) - # determine the inverse permutation to revert back to original order - rhs_letter_order = argsort(list(output)) - rhs_order = argsort(rhs_letter_order) + # ordinal encode the output alphabet + rhs_order = argsort(argsort(list(output))) # sum over all axes that's not in the output, then permute to the output order return functools.reduce(lambda a,b:a*b, xs_) \ - .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output],acc_dtype=acc_dtype).permute(rhs_order) + .sum(axis=[axis for axis,(letter,_) in enumerate(letter_val) if letter not in output], acc_dtype=acc_dtype).permute(rhs_order) # ***** processing ops ***** @@ -1963,25 +1982,24 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method assert len(self.shape) >= len(k_), f"can't pool {self.shape} with {k_}" s_, d_ = make_tuple(stride, len(k_)), make_tuple(dilation, len(k_)) assert len(k_) == len(s_) == len(d_), f"stride/dilation mismatch kernel:{k_} stride:{s_} dilation:{d_}" - noop_, i_ = [None] * len(self.shape[:-len(k_)]), self.shape[-len(k_):] - assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_, d_, i_)), "kernel size cannot be greater than actual input size" - o_ = [ceildiv(i - d * (k-1), s) for i,d,k,s in zip(i_, d_, k_, s_)] - if any(resolve(k > s) for k,s in zip(k_, s_)) or any(d != 1 for d in d_): + noop, i_ = [None] * (self.ndim-len(k_)), self.shape[-len(k_):] + assert all(resolve(d*(k-1)+1 <= i) for k,d,i in zip(k_,d_,i_)), "kernel size cannot be greater than actual input size" + o_ = [ceildiv(i-d*(k-1), s) for i,d,k,s in zip(i_,d_,k_,s_)] + if any(resolve(k > s) for k,s in zip(k_,s_)) or any(d != 1 for d in d_): # repeats such that we don't need padding - xup = self.repeat([1]*len(noop_) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_, i_, d_)]) + x = self.repeat([1]*len(noop) + [ceildiv(k*(i+d), i) for k,i,d in zip(k_,i_,d_)]) # handle dilation - xup = xup.shrink(tuple(noop_ + [(0,k*(i+d)) for k,i,d in zip(k_, i_, d_)])).reshape(noop_ + flatten((k,i+d) for k,i,d in zip(k_, i_, d_))) + x = x.shrink(tuple(noop + [(0,k*(i+d)) for k,i,d in zip(k_,i_,d_)])).reshape(noop + flatten((k,i+d) for k,i,d in zip(k_,i_,d_))) # handle stride - xup = xup.shrink( - tuple(noop_ + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_, o_, s_)))).reshape(noop_ + flatten((k,o,s) for k,o,s in zip(k_, o_, s_))) - xup = xup.shrink(tuple(noop_ + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_, o_)))).reshape(noop_ + flatten((k,o) for k,o in zip(k_, o_))) + x = x.shrink(tuple(noop + flatten(((0,k), (0,o*s)) for k,o,s in zip(k_,o_,s_)))).reshape(noop + flatten((k,o,s) for k,o,s in zip(k_,o_,s_))) + x = x.shrink(tuple(noop + flatten(((0,k), (0,o), (0,1)) for k,o in zip(k_,o_)))).reshape(noop + flatten((k,o) for k,o in zip(k_,o_))) # permute to move reduce to the end - return xup.permute(*range(len(noop_)), *[len(noop_)+i*2+1 for i in range(len(i_))], *[len(noop_)+i*2 for i in range(len(i_))]) + return x.permute(*range(len(noop)), *[len(noop)+i*2+1 for i in range(len(i_))], *[len(noop)+i*2 for i in range(len(i_))]) # TODO: once the shapetracker can optimize well, remove this alternative implementation - xup = self.pad(tuple(noop_ + [(0, max(0,o*s-i)) for i,o,s in zip(i_, o_, s_)])).shrink(tuple(noop_ + [(0,o*s) for o,s in zip(o_, s_)])) - xup = xup.reshape(noop_ + flatten(((o,s) for o,s in zip(o_, s_)))) - xup = xup.shrink(tuple(noop_ + flatten(((0,o), (0,k)) for o,k in zip(o_, k_)))) - return xup.permute(*range(len(noop_)), *[len(noop_)+i*2 for i in range(len(i_))], *[len(noop_)+i*2+1 for i in range(len(i_))]) + x = self.pad(tuple(noop + [(0, max(0,o*s-i)) for i,o,s in zip(i_,o_,s_)])).shrink(tuple(noop + [(0,o*s) for o,s in zip(o_,s_)])) + x = x.reshape(noop + flatten(((o,s) for o,s in zip(o_,s_)))) + x = x.shrink(tuple(noop + flatten(((0,o), (0,k)) for o,k in zip(o_,k_)))) + return x.permute(*range(len(noop)), *[len(noop)+i*2 for i in range(len(i_))], *[len(noop)+i*2+1 for i in range(len(i_))]) def _padding2d(self, padding:Union[int, Sequence[int]], dims:int) -> Sequence[int]: return [padding]*2*dims if isinstance(padding, int) else (padding if len(padding) == 2*dims else [p for p in padding for _ in range(2)][::-1]) @@ -2004,7 +2022,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ padding_, axis = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))), tuple(range(-len(k_), 0)) - def pool(x:Tensor) -> Tensor: return x.pad2d(padding_)._pool(k_, stride if stride is not None else k_, dilation) + def pool(x:Tensor) -> Tensor: return x.pad(padding_)._pool(k_, stride if stride is not None else k_, dilation) return pool(self).mean(axis=axis) if count_include_pad else pool(self).sum(axis=axis) / pool(self.ones_like()).sum(axis=axis) def max_pool2d(self, kernel_size=(2,2), stride=None, dilation=1, padding=0): @@ -2024,7 +2042,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ padding_ = self._padding2d(padding, len(k_ := make_tuple(kernel_size, 2))) - return self.pad2d(padding_, value=float('-inf'))._pool(k_, stride if stride is not None else k_, dilation).max(axis=tuple(range(-len(k_), 0))) + return self.pad(padding_, value=dtypes.min(self.dtype))._pool(k_, stride if stride is not None else k_, dilation).max(tuple(range(-len(k_), 0))) def conv2d(self, weight:Tensor, bias:Optional[Tensor]=None, groups=1, stride=1, dilation=1, padding:int|Tuple[int, ...]=0, acc_dtype:Optional[DTypeLike]=None) -> Tensor: @@ -2048,7 +2066,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method padding_ = self._padding2d(padding, len(HW)) # conv2d is a pooling op (with padding) - x = self.pad2d(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) + x = self.pad(padding_)._pool(HW, stride, dilation) # (bs, groups*cin, oy, ox, H, W) rcout, oyx = cout//groups, x.shape[2:-len(HW)] if not all(x == 3 for x in HW) or stride != 1 or dilation != 1 or not WINO: # normal conv @@ -2066,7 +2084,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # todo: stride == dilation # use padding to round up to 4x4 output tiles # (bs, cin_, tyx, HWI) - d = self.pad2d(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501 + d = self.pad(sum([[padding_[i*2], padding_[i*2+1] + (-(dim + sum(padding_[i * 2:(i + 1) * 2]) - 2) % 4)] for i, dim in enumerate(self.shape[-len(HW):])], []))._pool(HWI, HWO) # noqa: E501 # move HW to the front: # (HWI, bs, cin_, tyx) d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))) tyx = d.shape[-len(HWI):] # dim of tiling @@ -2116,11 +2134,19 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method return x.conv2d(w.flatten(end_dim=1), groups=groups, bias=bias, dilation=dilation, padding=padding) def dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor: + """ Performs dot product between two tensors. + If `w` is 1-D, it's a sum product over the last axis of `self` and `w`. + If `w` is N-D with N>=2, it's a sum product over the last axis of `self` and the second-to-last axis of `w`. You can pass in the optional `acc_dtype` keyword argument to control the data type of the accumulation. + ```python exec="true" source="above" session="tensor" result="python" + a = Tensor([1, 2, 3]) + b = Tensor([1, 1, 0]) + print(a.dot(b).numpy()) + ``` ```python exec="true" source="above" session="tensor" result="python" a = Tensor([[1, 2], [3, 4]]) b = Tensor([[5, 6], [7, 8]]) @@ -2128,11 +2154,11 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method ``` """ if IMAGE: return self.image_dot(w, acc_dtype) - n1, n2 = len(self.shape), len(w.shape) - assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - if (L:=self.shape[-1]) != (R:=w.shape[-min(n2, 2)]): raise AssertionError(f"shapes {self.shape} and {w.shape} cannot be multiplied ({L} != {R})") - x = self.reshape(*self.shape[0:-1], *[1]*min(n1-1, n2-1, 1), self.shape[-1]) - w = w.reshape(*w.shape[0:-2], *[1]*min(n1-1, n2-1, 1), *w.shape[-min(n2, 2):]).transpose(-1, -min(n2, 2)) + x, dx, dw = self, self.ndim, w.ndim + if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D") + if x.shape[-1] != w.shape[axis_w:=-min(w.ndim,2)]: raise RuntimeError(f"cannot dot {x.shape} and {w.shape}") + x = x.reshape(*x.shape[0:-1], *[1]*min(dx-1, dw-1, 1), x.shape[-1]) + w = w.reshape(*w.shape[0:-2], *[1]*min(dx-1, dw-1, 1), *w.shape[axis_w:]).transpose(-1, axis_w) return (x*w).sum(-1, acc_dtype=acc_dtype).cast(least_upper_dtype(x.dtype, w.dtype) if acc_dtype is None else acc_dtype) def matmul(self, x:Tensor, reverse=False, acc_dtype:Optional[DTypeLike]=None) -> Tensor: @@ -2153,7 +2179,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method def _cumsum(self, axis:int=0, _first_zero=False) -> Tensor: assert self.shape[axis] != 0 pl_sz = self.shape[axis] - int(not _first_zero) - return self.transpose(axis,-1).pad2d((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) + return self.transpose(axis,-1).pad((pl_sz,-int(_first_zero)))._pool((self.shape[axis],)).sum(-1).transpose(axis,-1) def cumsum(self, axis:int=0) -> Tensor: """ Computes the cumulative sum of the tensor along the specified axis. @@ -2174,7 +2200,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # for now this is a two stage cumsum SPLIT = 256 if not isinstance(s:=self.shape[axis], int) or s <= SPLIT*2: return self._cumsum(axis) - ret = self.transpose(axis,-1).pad2d((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1) + ret = self.transpose(axis,-1).pad((round_up(s, SPLIT)-s, 0)).unflatten(-1, (-1, SPLIT))._cumsum(-1) base_add = ret[..., -1]._cumsum(-1, _first_zero=True) base_add = base_add.unsqueeze(-1).expand(*base_add.shape, ret.shape[-1]) def fix(x:Tensor): return x.flatten(start_dim=-2)[..., -s:].transpose(axis,-1) @@ -3557,11 +3583,12 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method # *** image Tensor function replacements *** - def image_dot(self, w:Tensor, acc_dtype=None) -> Tensor: + def image_dot(self, w:Tensor, acc_dtype:Optional[DTypeLike]=None) -> Tensor: # NOTE: we use a 1x1 conv2d to do the matmul. mxk @ kxn = (1,k,m,1).conv2d(n,k,1,1) - n1, n2 = len(self.shape), len(w.shape) - assert n1 != 0 and n2 != 0, f"both arguments to matmul need to be at least 1D, but they are {n1}D and {n2}D" - assert self.shape[-1] == w.shape[-min(n2, 2)], f"Input Tensor shapes {self.shape} and {w.shape} cannot be multiplied ({self.shape[-1]} != {w.shape[-min(n2, 2)]})" # noqa: E501 + x, dx, dw = self, self.ndim, w.ndim + if not (dx > 0 and dw > 0): raise RuntimeError(f"both tensors need to be at least 1D, got {dx}D and {dw}D") + if x.shape[-1] != w.shape[-min(w.ndim, 2)]: raise RuntimeError(f"cannot image_dot {x.shape} and {w.shape}") + bs, groups, cin, cout = prod(self.shape[0:-2]), prod(w.shape[0:-2]), w.shape[-2], w.shape[-1] out_shape_t = self.shape[0:-2] + (cout,-1) if len(self.shape) > 1 else (cout, ) @@ -3614,7 +3641,7 @@ class Tensor(SimpleMathTrait): # pylint: disable=abstract-method else: w = w.reshape(cout//4, H, rcin_hi, W, rcin_lo, 4).permute(0,1,2,3,5,4) # prepare input - x = x.permute(0,3,4,5,1,2).pad2d(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) + x = x.permute(0,3,4,5,1,2).pad(self._padding2d(padding, 2))._pool((H, W), stride, dilation) # -> (bs, groups, rcin_hi, rcin_lo, oy, ox, H, W) x = x.permute(0,4,5,1,2,3,6,7).reshape(bs, (oy := x.shape[4]), (ox := x.shape[5]), *cout_expand[0:2], 1, 1, rcin_hi, rcin_lo, H, W) # prepare weights diff --git a/tinygrad/viz/serve.py b/tinygrad/viz/serve.py index cf619d83f7..54303bf887 100755 --- a/tinygrad/viz/serve.py +++ b/tinygrad/viz/serve.py @@ -11,7 +11,7 @@ from tinygrad.codegen.kernel import Kernel uops_colors = {Ops.LOAD: "#ffc0c0", Ops.PRELOAD: "#ffc0c0", Ops.STORE: "#87CEEB", Ops.CONST: "#e0e0e0", Ops.VCONST: "#e0e0e0", Ops.DEFINE_GLOBAL: "#ffe0b0", Ops.DEFINE_LOCAL: "#ffe0d0", Ops.DEFINE_ACC: "#f0ffe0", Ops.REDUCE_AXIS: "#FF6B6B", Ops.RANGE: "#c8a0e0", Ops.ASSIGN: "#e0ffc0", Ops.BARRIER: "#ff8080", Ops.IF: "#c8b0c0", Ops.SPECIAL: "#c0c0ff", - Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}} + Ops.INDEX: "#e8ffa0", Ops.WMMA: "#efefc0", Ops.VIEW: "#C8F9D4", **{x:"#ffffc0" for x in GroupOp.ALU}, Ops.BUFFER: "#B0BDFF",} # ** API spec