mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-12 08:28:29 -05:00
@@ -3,33 +3,35 @@ from tinygrad.uop.ops import UOp, Ops, PatternMatcher, UPat
|
||||
from tinygrad.renderer.cstyle import CStyleLanguage, base_rewrite, extra_pm
|
||||
from tinygrad.helpers import strip_parens
|
||||
|
||||
def _mask(dt:DType): return 0xFF if dt.itemsize == 1 else 0xFFFF
|
||||
|
||||
def sign_extend(val:UOp, sext_am:int):
|
||||
return (UOp.where((val >> (sext_am - 1)) > 0, UOp.const(dtypes.uint32, 0xffffffff) << sext_am, UOp.const(dtypes.uint32, 0)) \
|
||||
| val.bitcast(dtypes.uint32)).bitcast(dtypes.int)
|
||||
|
||||
# store for char: buf[idx/4] <- (var << (idx%4)*8))
|
||||
def packed_store(bidx:UOp, var:UOp):
|
||||
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//var.dtype.itemsize))*UOp.const(dtypes.uint32, 8*var.dtype.itemsize)
|
||||
new_v = (var & (0xFF if var.dtype.itemsize == 1 else 0xFFFF)).cast(dtypes.uint32) << shift_am
|
||||
mask = (((0xFF if var.dtype.itemsize == 1 else 0xFFFF) << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
||||
div_idx = bidx.src[1]//(4//var.dtype.itemsize)
|
||||
elems, mask = 4//var.dtype.itemsize, _mask(var.dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*var.dtype.itemsize), bidx.src[1] // elems
|
||||
new_v, wmask = (var & mask).cast(dtypes.uint32) << shift_am, ((mask << shift_am) ^ 0xFFFFFFFF).cast(dtypes.uint32)
|
||||
# preserve valid condition (bidx.src[2]) if it exists for gated stores
|
||||
idx_src = (bidx.src[0], div_idx) if len(bidx.src) == 2 else (bidx.src[0], div_idx, bidx.src[2])
|
||||
buf = UOp.load(UOp(Ops.INDEX, bidx.dtype, idx_src), dtype=dtypes.uint32)
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), ((buf & mask) | new_v.cast(dtypes.uint32)))
|
||||
return UOp.store(UOp(Ops.INDEX, bidx.dtype, idx_src), (buf & wmask) | new_v)
|
||||
|
||||
# load for char: sign_extend(buf[idx/4] >> ((idx%4)*8))
|
||||
def packed_load(root:UOp, bidx:UOp, dtype:DType, var:UOp|None=None):
|
||||
div_idx = bidx.src[1]//(4//dtype.itemsize)
|
||||
shift_am = (bidx.src[1].cast(dtypes.uint32)%UOp.const(dtypes.uint32, 4//dtype.itemsize))*UOp.const(dtypes.uint32, 8*dtype.itemsize)
|
||||
if var is not None: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2])), var, dtype=dtypes.uint32, arg=root.arg)
|
||||
else: load = UOp.load(UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx)), *root.src[1:], dtype=dtypes.uint32, arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & (0xFF if dtype.itemsize == 1 else 0xFFFF)
|
||||
elems, mask = 4//dtype.itemsize, _mask(dtype)
|
||||
shift_am, div_idx = (bidx.src[1].cast(dtypes.uint32) % elems) * (8*dtype.itemsize), bidx.src[1] // elems
|
||||
idx = UOp(Ops.INDEX, bidx.dtype, (bidx.src[0], div_idx, bidx.src[2]) if var is not None else (bidx.src[0], div_idx))
|
||||
load = UOp.load(idx, *([var] if var is not None else root.src[1:]), dtype=dtypes.uint32, arg=root.arg)
|
||||
val = (load.cast(dtypes.uint32) >> shift_am) & mask
|
||||
return sign_extend(val, 8*dtype.itemsize).cast(dtype) if dtype in [dtypes.char, dtypes.short] else val.cast(dtype)
|
||||
|
||||
def is_packed(dt:DType, odt:DType|None = None) -> bool:
|
||||
if odt is None: odt = dt
|
||||
return dt.itemsize < 4 and dt.base != dtypes.half and (not isinstance(odt, PtrDType) or odt.addrspace != AddrSpace.REG)
|
||||
def _packed_size(dt:PtrDType): return dt.size // (4//dt.itemsize) if is_packed(dt) else dt.size
|
||||
|
||||
wgsl_matcher = PatternMatcher([
|
||||
(UPat((Ops.CMPLT, Ops.XOR), src=(UPat(name="a", dtype=dtypes.bool), UPat.var("b")), name="c"),
|
||||
@@ -61,10 +63,8 @@ class WGSLRenderer(CStyleLanguage):
|
||||
(UPat.cvar("x", dtype=dtypes.bool), lambda x: "true" if x.arg else "false"),
|
||||
(UPat(Ops.CONST, dtype=(dtypes.uchar, dtypes.ushort, dtypes.uint32), name="x"),
|
||||
lambda x: f"bitcast<u32>({x.arg})" if x.arg < 0 else f"{x.arg&0xFFFFFFFF}u"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x:
|
||||
f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x:
|
||||
f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{x.dtype.size//(4//x.dtype.itemsize) if is_packed(x.dtype) else x.dtype.size}>;"),
|
||||
(UPat(Ops.DEFINE_LOCAL, name="x"), lambda ctx,x: f"var<workgroup> {ctx[x]}: array<{ctx.buf_map(x.dtype.base)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.DEFINE_REG, name="x"), lambda ctx,x: f"var {ctx[x]}: array<{ctx.buf_map(x.dtype)},{_packed_size(x.dtype)}>;"),
|
||||
(UPat(Ops.BITCAST, dtype=dtypes.half, name="x", src=(UPat(dtype=(dtypes.short, dtypes.ushort, dtypes.uint32),),)),
|
||||
lambda ctx,x: f"bitcast<vec2<f16>>({ctx[x.src[0]]})[0]"),
|
||||
(UPat(Ops.BITCAST, dtype=(dtypes.char, dtypes.uchar), name="x"), lambda ctx,x: f"bitcast<{ctx.type_map[x.dtype]}>({ctx[x.src[0]]}&0xFF)"),
|
||||
|
||||
Reference in New Issue
Block a user