mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-04-29 03:00:14 -04:00)
small changes from define_reg (#11327)
* small changes from define_reg
* fix webgpu
This commit is contained in:
8
.github/workflows/test.yml
vendored
8
.github/workflows/test.yml
vendored
@@ -239,8 +239,8 @@ jobs:
|
||||
run: |
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded_amd TestLinearizer.test_tensor_cores_padded_uops
|
||||
- name: Test emulated AMD MFMA tensor cores
|
||||
@@ -252,8 +252,8 @@ jobs:
|
||||
run: |
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=0 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=16 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 N=64 HALF=1 ACC_HALF=1 ATOL=1e-3 python3 ./extra/gemm/simple_matmul.py
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores
|
||||
PYTHONPATH=. DEBUG=2 EMULATE_AMD_RDNA4=1 FORWARD_ONLY=1 PYTHON=1 python3 test/test_linearizer.py TestLinearizer.test_tensor_cores_padded TestLinearizer.test_tensor_cores_padded_uops
|
||||
- name: Test emulated CUDA tensor cores
|
||||
|
||||
@@ -704,7 +704,7 @@ class TestLinearizer(unittest.TestCase):
|
||||
r = (x@y).relu()
|
||||
k = helper_linearizer_opt(r)[-1]
|
||||
uops = get_program(k.get_optimized_ast(), k.opts).uops
|
||||
stores = [u for u in uops if u.op is Ops.STORE and u.src[0].op is not Ops.DEFINE_REG]
|
||||
stores = [u for u in uops if u.op is Ops.STORE and u.dtype.addrspace != AddrSpace.REG]
|
||||
|
||||
# the float4 value stores directly in lds and we skip upcast
|
||||
self.assertEqual(stores[0].src[-1].dtype, dtypes.float.vec(4))
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition, all_same
|
||||
from tinygrad.uop.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
@@ -109,11 +109,15 @@ def expand_index(buf:UOp, vec:UOp, mask:UOp|None=None):
|
||||
def cat_after_store(cat:UOp, data:UOp):
|
||||
# TODO: this is written in many places
|
||||
offset = 0
|
||||
ret = []
|
||||
ret: list[UOp] = []
|
||||
for s in cat.src:
|
||||
ret.append(s.store(data.gep(tuple(range(offset, offset+s.dtype.count)))))
|
||||
offset += s.dtype.count
|
||||
return UOp.sink(ret[0], *ret[1:])
|
||||
# dtype CAT
|
||||
dtypes: list[PtrDType] = [x.dtype for x in ret if isinstance(x.dtype, PtrDType)]
|
||||
assert len(dtypes) == len(ret) and all_same([(x.size, x.addrspace) for x in dtypes])
|
||||
out_dtype = dtypes[0].base.scalar().vec(sum([x.count for x in dtypes])).ptr(dtypes[0].size, dtypes[0].addrspace)
|
||||
return UOp(Ops.PTRCAT, dtype=out_dtype, src=tuple(ret))
|
||||
|
||||
def gep_on_store(gep:UOp, st:UOp):
|
||||
# NOTE: we need to invert the gep here, but it may be an expanding gep
|
||||
|
||||
@@ -155,12 +155,11 @@ spec = PatternMatcher([
|
||||
|
||||
# INDEX is used in new style load/store
|
||||
# INDEX takes a <buf, alu, gate?>
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat())), lambda: True),
|
||||
(UPat(Ops.INDEX, src=(UPat((Ops.DEFINE_GLOBAL, Ops.DEFINE_LOCAL, Ops.DEFINE_REG)), UPat(), UPat(dtype=dtypes.bool))), lambda: True),
|
||||
|
||||
# LOAD/STORE reg
|
||||
(UPat(Ops.LOAD, src=(UPat((Ops.STORE, Ops.DEFINE_REG)),)), lambda: True),
|
||||
(UPat(Ops.STORE, src=(UPat(Ops.DEFINE_REG), UPat())), lambda: True),
|
||||
# LOAD on STORE
|
||||
(UPat(Ops.LOAD, src=(UPat(Ops.STORE),)), lambda: True),
|
||||
|
||||
# LOAD takes a <bufidx, alt?, barrier?>
|
||||
(UPat(Ops.LOAD, src=(index_pat,)), validate_index),
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any, Literal, cast
|
||||
import math, operator, struct, functools
|
||||
from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType, AddrSpace
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
|
||||
from tinygrad.uop.transcendental import xpow
|
||||
|
||||
@@ -406,7 +406,11 @@ def reduce_mul_chain(r:UOp):
|
||||
|
||||
# this is symbolic 2.0
|
||||
REMOVE_FROM_SINK = {Ops.SINK, Ops.UNROLL, Ops.PTRCAT, Ops.CAT}
|
||||
REMOVE_FROM_BARRIER = {Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT}
|
||||
sym = symbolic_flat+PatternMatcher([
|
||||
# LOAD/STORE -> NOOP
|
||||
(UPat.var('x').store(UPat.var('x').load()), lambda x: None if x.dtype.addrspace != AddrSpace.REG else x.src[0].src[0]),
|
||||
(UPat(Ops.LOAD, src=(UPat.cvar('c'))), lambda c: c),
|
||||
# self ASSIGN is just self
|
||||
(UPat(Ops.ASSIGN, src=(UPat.var('x'), UPat.var('x'))), lambda x: x),
|
||||
# VECTORIZE/CONST, VECTORIZE/GEP
|
||||
@@ -459,8 +463,10 @@ sym = symbolic_flat+PatternMatcher([
|
||||
# remove NOOPs from SINK
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, a, root.arg) if len(a:=tuple(x for x in root.src if x.op is not Ops.NOOP)) != len(root.src) else None),
|
||||
# remove VECTORIZE from SINK/BARRIER
|
||||
(UPat(Ops.BARRIER, src=(UPat((Ops.VECTORIZE, Ops.SINK, Ops.CAT, Ops.PTRCAT), name='sink'),)), lambda sink: UOp(Ops.BARRIER, dtypes.void, sink.src)),
|
||||
# remove VECTORIZE from SINK/BARRIER. TODO: SINK/BARRIER are really the same thing at GLOBAL/LOCAL levels
|
||||
(UPat(Ops.BARRIER, name="root"),
|
||||
lambda root: UOp(Ops.BARRIER, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_BARRIER else (x,) for x in root.src)), root.arg)
|
||||
if any(x.op in REMOVE_FROM_BARRIER for x in root.src) else None),
|
||||
(UPat(Ops.SINK, name="root"),
|
||||
lambda root: UOp(Ops.SINK, root.dtype, tuple(flatten(x.src if x.op in REMOVE_FROM_SINK else (x,) for x in root.src)), root.arg)
|
||||
if any(x.op in REMOVE_FROM_SINK for x in root.src) else None),
|
||||
|
||||
Reference in New Issue
Block a user