diff --git a/test/test_uops.py b/test/test_uops.py index 0ce664eaa7..9dedc99198 100644 --- a/test/test_uops.py +++ b/test/test_uops.py @@ -4,7 +4,7 @@ import numpy as np from tinygrad.helpers import dtypes, getenv, DType from tinygrad.tensor import Device from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ASTRunner, Compiled -from tinygrad.codegen.linearizer import UOps, ConstOp, MemOp, UOp +from tinygrad.codegen.linearizer import UOps, MemOp, UOp from tinygrad.shape.symbolic import Variable def _uops_to_prg(uops): @@ -32,7 +32,7 @@ def _test_single_value(vals, op, dtype): def _test_single_value_const(vals, op, dtype): uops = [] uop(uops, UOps.DEFINE_GLOBAL, None, (), ('data0', dtype)) - loads = (uop(uops, UOps.LOAD, dtype, [], ConstOp(a, Variable.ands([]))) for a in vals) + loads = (uop(uops, UOps.CONST, dtype, [], a) for a in vals) alu = uop(uops, UOps.ALU, dtype, loads, op) uop(uops, UOps.STORE, None, (alu, ), MemOp('data0', Variable.num(0), False, dtype, Variable.ands([]))) buf = Device[Device.DEFAULT].buffer(1, dtype) diff --git a/tinygrad/codegen/assembly.py b/tinygrad/codegen/assembly.py index 3dd446b527..645ceb81bc 100644 --- a/tinygrad/codegen/assembly.py +++ b/tinygrad/codegen/assembly.py @@ -1,5 +1,5 @@ from typing import Tuple, List, NamedTuple, Any, Dict, Optional, Union, DefaultDict, cast -from tinygrad.codegen.linearizer import UOps, ConstOp, MemOp, UOp +from tinygrad.codegen.linearizer import UOps, MemOp, UOp from tinygrad.ops import BinaryOps, UnaryOps from tinygrad.helpers import DType, dtypes, DEBUG from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode @@ -158,32 +158,24 @@ def uops_to_asmstyle(lang, function_name:str, uops:List[UOp]): elif uop == UOps.DEFINE_ACC: reg = lang.newreg(u, dtype=dtype) lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args)) + elif uop == UOps.SPECIAL: + lang.tor[u] = lang.tor[args] + elif uop == UOps.CONST: + lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args)) elif uop == UOps.LOAD: - if isinstance(args, ConstOp): - if args.valid.min == 0 and args.valid.max == 1: - reg = lang.newreg(u, dtype=dtype) - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.invalid_value)) + idx, treg, off = lang.addr_w_offset(args) + reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) + if args.valid.min == 0: + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0)) + if args.valid.max == 1: pred = args.valid.render(lang.render_ops, lang) lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], args.value)) - lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) - skipload_branch += 1 - else: - lang.ins.append(AssemblyInstruction(UOps.LOAD, lang.newreg(u, dtype=dtype), [], args.value if args.valid.min == 1 else args.invalid_value)) - else: - idx, treg, off = lang.addr_w_offset(args) - reg = lang.newreg(u, dtype=dtype, scalar=(idx.scalar and (not isinstance(treg, Register) or treg.scalar))) - if args.valid.min == 0: - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [], 0)) - if args.valid.max == 1: - pred = args.valid.render(lang.render_ops, lang) - lang.ins.append(AssemblyInstruction(UOps.COND_BRANCH, None, [pred], (f"$skipload_{skipload_branch}", False))) - if args.valid.max == 1: - # NOTE: you can't compute the index in here, because it assumes it's all available later - lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) - if args.valid.min == 0 and args.valid.max == 1: - lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) - skipload_branch += 1 + if args.valid.max == 1: + # NOTE: you can't compute the index in here, because it assumes it's all available later + lang.ins.append(AssemblyInstruction(UOps.LOAD, reg, [idx] + ([treg] if treg is not None else []), (off, 'global' if not args.local else 'shared', args.memory_dtype if args.memory_dtype != dtypes.float else None))) + if args.valid.min == 0 and args.valid.max == 1: + lang.ins.append(AssemblyInstruction(UOps.LABEL, None, [], f"$skipload_{skipload_branch}")) + skipload_branch += 1 elif uop == UOps.STORE: if args is None: lang.ins.append(AssemblyInstruction(UOps.ALU, lang.tor[vin[0]], [lang.tor[vin[1]]], UnaryOps.NOOP)) diff --git a/tinygrad/codegen/linearizer.py b/tinygrad/codegen/linearizer.py index 8edfc8e5e3..816ab3a30d 100644 --- a/tinygrad/codegen/linearizer.py +++ b/tinygrad/codegen/linearizer.py @@ -1,15 +1,15 @@ from __future__ import annotations from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, Dict, Iterator, Union, Sequence, Final -import itertools, math +import itertools, math, functools from collections import defaultdict from enum import Enum, auto -from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, partition, prod +from tinygrad.helpers import colored, ImageDType, DEBUG, dtypes, DType, partition, prod, PtrDType from tinygrad.ops import LazyOp, UnaryOps from tinygrad.ops import ReduceOps, BinaryOps, TernaryOps from tinygrad.runtime.lib import RawConst from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, sym_rename +from tinygrad.shape.symbolic import Variable, NumNode, Node, SumNode, MulNode, DivNode, ModNode, LtNode, AndNode, sym_rename from tinygrad.codegen.optimizer import OptimizedKernel from tinygrad.codegen.kernel import LocalBuffer VariableOrNum = Union[Variable, NumNode, Node] @@ -18,7 +18,7 @@ VariableOrNum = Union[Variable, NumNode, Node] class UOps(Enum): LOOP = auto(); ENDLOOP = auto() # loops can be global, local, or other # noqa: E702 DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702 - LOAD = auto(); STORE = auto(); BARRIER = auto() # noqa: E702 + LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702 ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702 # TODO: add CONST. use ALU WHERE for gated load # *** assembly only UOps *** @@ -58,13 +58,6 @@ class MemOp(NamedTuple): valid: Node invalid_value: Union[float, int] = 0.0 -class ConstOp(NamedTuple): - value: Union[float, int] - - # shared - valid: Node - invalid_value: Union[float, int] = 0.0 - class UOp(NamedTuple): uop: UOps dtype: Optional[DType] @@ -96,6 +89,19 @@ class Linearizer(OptimizedKernel): assert self.bufs[i].realized.__class__ is not RawConst # constants shouldn't be loaded with memops return self.arg_bufs[self.bufs[i].realized] + def uop_alu_idx(self, a:UOp, b, ops, ctx:Linearizer, op, dtype=dtypes.int32): + render_b:UOp = cast(UOp, (NumNode(b) if not isinstance(b, Node) else b).render(ops, ctx)) + return self.uop(UOps.ALU, dtype, (a, render_b), op, cachable=True) + + render_ops: Any = { Variable: lambda self, ops, ctx: ctx.uop(UOps.SPECIAL, dtypes.int32, tuple(), self), + NumNode: lambda self, ops, ctx: ctx.uop(UOps.CONST, dtypes.int32, tuple(), self.b), + MulNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MUL), + DivNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.DIV), + ModNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.MOD), + LtNode: lambda self, ops, ctx: ctx.uop_alu_idx(self.a.render(ops, ctx), self.b, ops, ctx, BinaryOps.CMPLT, dtype=dtypes.bool), + SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.ADD), self.nodes[1:], self.nodes[0].render(ops,ctx)), + AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.uop_alu_idx(a, b, ops, ctx, BinaryOps.MUL, dtype=dtypes.bool), self.nodes[1:], self.nodes[0].render(ops,ctx)) } + def global_load(self, i:int, idxs:Sequence[VariableOrNum], acc=None) -> List[UOp]: const = self.bufs[i].realized._buf if isinstance(self.bufs[i].realized, RawConst) else acc @@ -126,9 +132,14 @@ class Linearizer(OptimizedKernel): if acc is not None: assert valid.min == 1 self.load_cache[key] = self.uop(UOps.DEFINE_ACC, localtype, [], this_const) + elif this_const is not None: + self.load_cache[key] = self.uop(UOps.CONST, localtype, [], this_const) + if valid.min == 0 and valid.max == 1: + valid_rendered = valid.render(self.render_ops, self) + alt = self.uop(UOps.CONST, localtype, [], invalid_value) + self.load_cache[key] = self.uop(UOps.ALU, localtype, [valid_rendered, self.load_cache[key], alt], TernaryOps.WHERE) else: - self.load_cache[key] = self.uop(UOps.LOAD, localtype, [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value)) if this_const is None else \ - self.uop(UOps.LOAD, localtype, [], ConstOp(this_const, valid)) + self.load_cache[key] = self.uop(UOps.LOAD, localtype, [], MemOp(self.get_buffer_name(i), idx, self.bufs[i].__class__ is LocalBuffer, self.bufs[i].dtype, valid, invalid_value)) ret.append(self.uop(UOps.GEP, dtypes.float32, [self.load_cache[key]], expanded_nodes[dim].index(_idx[dim])) if localtype != dtypes.float else self.load_cache[key]) return ret @@ -174,19 +185,19 @@ class Linearizer(OptimizedKernel): # add global buffers for buf,name in self.arg_bufs.items(): - self.uop(UOps.DEFINE_GLOBAL, None, [], (name, buf.dtype)) + self.uop(UOps.DEFINE_GLOBAL, PtrDType(buf.dtype), [], (name, buf.dtype)) # add variables from symbolic shapes for var in sorted(set(v for buf in self.ast.buffers for v in buf.st.var_vals), key=lambda k: k.key): - self.uop(UOps.DEFINE_GLOBAL, None, [], (var.expr, dtypes._arg_int32)) + self.uop(UOps.DEFINE_GLOBAL, dtypes.int32, [], (var.expr, dtypes._arg_int32)) # define local buffers for lb in self.local_alias.values(): - self.uop(UOps.DEFINE_LOCAL, None, [], (lb.name, self.sts[self.bufs.index(lb)].size())) + self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], (lb.name, self.sts[self.bufs.index(lb)].size())) # add a local buffer for multistage reduce. # TODO: use local alias if self.group_for_reduce: # TODO: the strides of this can be controlled self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)]))) self.bufs.append(LocalBuffer("temp", self.sts[-1].size())) - self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size())) + self.uop(UOps.DEFINE_LOCAL, PtrDType(dtypes.float32), [], ("temp", self.sts[-1].size())) # print if DEBUG >= 3: self.printbufs() @@ -350,7 +361,7 @@ class Linearizer(OptimizedKernel): return self - def uop(self, uop:UOps, dtype:Optional[DType], vin:List[UOp], arg:Any=None, cachable=False) -> UOp: + def uop(self, uop:UOps, dtype:Optional[DType], vin:Union[Tuple[UOp, ...], List[UOp]], arg:Any=None, cachable=False) -> UOp: key = (uop, dtype, tuple(vin), arg) if cachable and key in self.saved_exprs: return self.saved_exprs[key] self.uops.append(UOp(uop, dtype, tuple(vin), arg, len(self.uops))) diff --git a/tinygrad/renderer/assembly_arm64.py b/tinygrad/renderer/assembly_arm64.py index 9ecfd5a926..43a0a92acf 100644 --- a/tinygrad/renderer/assembly_arm64.py +++ b/tinygrad/renderer/assembly_arm64.py @@ -90,9 +90,14 @@ def specialize_to_arm64(fn_nm, asm): ins.append(f"loop_{arg}:") elif uop == UOps.CAST: if arg == BinaryOps.CMPLT: - mov_imm(0.0, 's0') - mov_imm(1.0, 's1') - ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt") + if rtor[out.nm][0] == 's': + mov_imm(0.0, 's0') + mov_imm(1.0, 's1') + ins.append(f"fcsel {rtor[out.nm]}, s1, s0, lt") + if rtor[out.nm][0] == 'x': + mov_imm(0, 'x14') + mov_imm(1, 'x15') + ins.append(f"csel {rtor[out.nm]}, x15, x14, lt") else: ins.append(f"sxtw {rtor[out.nm]}, w{rtor[vin[0].nm][1:]}") elif uop == UOps.ALU: @@ -100,7 +105,7 @@ def specialize_to_arm64(fn_nm, asm): if arg == BinaryOps.MUL and out.dtype == dtypes.bool: ins.append(f"ands {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") elif arg == TernaryOps.WHERE: - ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0") + ins.append(f"fcmp {rtor[vin[0].nm]}, #0.0" if rtor[vin[0].nm][0] == 's' else f"cmp {rtor[vin[0].nm]}, #0") ins.append(f"{alu[arg]} {rtor[out.nm]}, {rtor[vin[1].nm]}, {rtor[vin[2].nm]}, ne") elif arg in [UnaryOps.LOG2, UnaryOps.SIN, UnaryOps.EXP2, UnaryOps.SQRT]: #NOTE: Not a real instruction, use to emulate a ext call in unicorn @@ -124,8 +129,9 @@ def specialize_to_arm64(fn_nm, asm): elif arg == BinaryOps.CMPLT: ins.append(f"{alu[arg]} {','.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}" if not dtypes.is_float(vin[0][1]) else f"fcmp {rtor[vin[0].nm]}, {rtor[vin[1].nm]}") elif arg == BinaryOps.MOD: - ins.append(f"udiv x14, {rtor[vin[0].nm]}, x15") - ins.append(f"msub {rtor[out.nm]}, x14, x15, {rtor[vin[0].nm]}") + rhs = 'x15' if vin[1].__class__ is int else rtor[vin[1].nm] + ins.append(f"udiv x14, {rtor[vin[0].nm]}, {rhs}") + ins.append(f"msub {rtor[out.nm]}, x14, {rhs}, {rtor[vin[0].nm]}") else: ins.append(f"{'f' if dtypes.is_float(vin[0][1]) else 's' if arg == BinaryOps.DIV else ''}{alu[arg]} {', '.join('x15' if v.__class__ is int else rtor[v.nm] for v in [out] + vin)}") elif uop == UOps.LOAD: diff --git a/tinygrad/renderer/assembly_ptx.py b/tinygrad/renderer/assembly_ptx.py index b620175684..cb78d8d961 100644 --- a/tinygrad/renderer/assembly_ptx.py +++ b/tinygrad/renderer/assembly_ptx.py @@ -15,7 +15,10 @@ def render_cast(ins, inp, out): if inp.dtype == dtypes.bool and (dtypes.is_float(out.dtype) or dtypes.is_int(out.dtype)): ins.append(f"selp.{dtype_to_nvtype[out.dtype]} {out}, {'0f3F800000, 0f00000000' if dtypes.is_float(out.dtype) else '1, 0'}, {inp};") elif out.dtype == dtypes.bool: - ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};") + if inp.dtype == dtypes.bool: + ins.append(f"mov.pred {out}, {inp};") + else: + ins.append(f"setp.ne.{dtype_to_nvtype[inp.dtype]} {out}, {'0f00000000' if dtypes.is_float(inp.dtype) else '0'}, {inp};") else: round_mod = ".rzi" if dtypes.is_int(out.dtype) and dtypes.is_float(inp.dtype) else '.rz' if dtypes.is_float(out.dtype) and (dtypes.is_int(inp.dtype) or dtypes.is_float(inp.dtype) and inp.dtype.itemsize > out.dtype.itemsize) else '' ins.append(f"cvt{round_mod}.{dtype_to_nvtype[out.dtype]}.{dtype_to_nvtype[inp.dtype]} {out}, {inp};") @@ -53,8 +56,11 @@ def specialize_to_ptx(lang, function_name): else: otype = vin[0].dtype if arg in [BinaryOps.CMPLT] else out.dtype if arg == TernaryOps.WHERE: - reg = lang.newreg((vin[0], 'bool'), dtypes.bool) - ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};") + if vin[0].dtype == dtypes.bool: + reg = vin[0] + else: + reg = lang.newreg((vin[0], 'bool'), dtypes.bool) + ins.append(f"setp.ne.{dtype_to_nvtype[vin[0].dtype]} {reg}, {'0f00000000' if dtypes.is_float(vin[0].dtype) else '0'}, {vin[0]};") vin = vin[1:] + [reg] ins.append(f"{alu[arg]}{'.lo' if arg == BinaryOps.MUL and out.dtype != dtypes.float32 else ''}{'.rn' if arg == BinaryOps.DIV and out.dtype == dtypes.float32 else ''}.{dtype_to_nvtype[otype]} {out}, {', '.join(str(x) for x in vin)};") elif uop == UOps.LOAD: diff --git a/tinygrad/renderer/cstyle.py b/tinygrad/renderer/cstyle.py index 2213bf662f..051158877f 100644 --- a/tinygrad/renderer/cstyle.py +++ b/tinygrad/renderer/cstyle.py @@ -1,7 +1,7 @@ from typing import Dict, List, Optional, NamedTuple, Tuple, Union, DefaultDict import math from collections import defaultdict -from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp +from tinygrad.codegen.linearizer import UOps, UOp, MemOp from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps from tinygrad.helpers import ImageDType, dtypes, getenv, prod, DType from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, sym_render @@ -167,22 +167,22 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T elif uop == UOps.ALU: assert dtype is not None r[u] = ssa('alu') - kk(f"{dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};") - #kk(f"{lang.generic_var_prefix if newvar not in vin else ''}{newvar.render(newvar not in vin and lang.generic_var_prefix == '')} = {lang.code_for_op[args](*[x.render() for x in vin])};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.code_for_op[args](*[r[x] for x in vin])};") elif uop == UOps.DEFINE_ACC: assert dtype is not None r[u] = ssa('acc') - kk(f"{dtype.name} {r[u]} = {lang.render_const(args, dtype)};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_const(args, dtype)};") + elif uop == UOps.SPECIAL: + r[u] = args.expr + elif uop == UOps.CONST: + r[u] = lang.render_const(args, dtype) if args >= 0 else f"({lang.render_const(args, dtype)})" elif uop == UOps.LOAD: assert dtype is not None r[u] = ssa('val') # valids are handled here - if isinstance(args, ConstOp): - val = lang.render_const(args.value, dtype) - else: - val = lang.render_load(dtype, args.name, args.memory_dtype, args.idx, args.local) + val = lang.render_load(dtype, args.name, args.memory_dtype, args.idx, args.local) if args.valid.min == 0 and args.valid.max == 1: val = lang.render_conditional(args.valid.render(render_cl), val, lang.render_const(args.invalid_value, dtype)) - kk(f"{dtype.name} {r[u]} = {val};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {val};") elif uop == UOps.STORE: if args is None: kk(f"{r[vin[0]]} = {r[vin[1]]};") @@ -193,7 +193,7 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> T kk(lang.render_store(args.name, args.memory_dtype, r[vin[0]], vin[0].dtype, args.idx, args.local)) elif uop == UOps.CAST and dtype is not None and dtype.sz > 1: r[u] = ssa('cast') - kk(f"{dtype.name} {r[u]} = {lang.render_cast([r[x] for x in vin], dtype)};") + kk(f"{lang.generic_var_prefix if lang.generic_var_prefix else dtype.name} {r[u]} = {lang.render_cast([r[x] for x in vin], dtype)};") elif uop == UOps.DEFINE_LOCAL: if lang.external_local_bufs: prekernel.append(lang.render_local(args[0], args[1])) diff --git a/tinygrad/renderer/llvmir.py b/tinygrad/renderer/llvmir.py index b1456ad1b2..af6867b768 100644 --- a/tinygrad/renderer/llvmir.py +++ b/tinygrad/renderer/llvmir.py @@ -1,12 +1,12 @@ from typing import Final, Dict, Callable, Any, List, Optional, Tuple import functools from llvmlite import ir # type: ignore -from tinygrad.codegen.linearizer import UOps, UOp, MemOp, ConstOp +from tinygrad.codegen.linearizer import UOps, UOp, MemOp from tinygrad.helpers import dtypes from tinygrad.ops import Op, UnaryOps, BinaryOps, TernaryOps from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode -def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(64), a) if isinstance(a, int) else a.render(ops, ctx) +def sym_render(a, ops=None, ctx=None): return ir.Constant(ir.IntType(32), a) if isinstance(a, int) else a.render(ops, ctx) render_llvm = { NumNode: lambda self,ops,ctx: sym_render(self.b,ops,ctx), MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), sym_render(self.b,ops,ctx)), @@ -30,7 +30,7 @@ code_for_op: Final[Dict[Op, Callable]] = { BinaryOps.MAX: lambda builder,x,y: builder.select(builder.fcmp_unordered(">", x, y, flags=('fast',)), x, y, flags=('fast',)), BinaryOps.MOD: lambda builder,x,y: builder.srem(x,y), TernaryOps.MULACC: lambda builder,x,y,z: builder.fadd(builder.fmul(x,y, flags=('fast',)), z, flags=('fast',)), - TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)), y, z, flags=('fast',)), + TernaryOps.WHERE: lambda builder,x,y,z: builder.select(builder.fcmp_unordered("!=", x, ir.Constant(ir.FloatType(), 0), flags=('fast',)) if isinstance(x.type, ir.FloatType) else builder.trunc(x, ir.IntType(1)), y, z, flags=('fast',)), } dtype_to_llvm_dtype = {dtypes.float64:ir.DoubleType(), dtypes.float16:ir.HalfType(), dtypes.bfloat16:ir.IntType(16), dtypes.float32:ir.FloatType(), dtypes.int8:ir.IntType(8), dtypes.uint8:ir.IntType(8), dtypes.bool: ir.IntType(1), dtypes.int64: ir.IntType(64), dtypes.int32: ir.IntType(32), dtypes._arg_int32: ir.IntType(32)} @@ -92,7 +92,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li render_llvm[Variable] = lambda self,ops,ctx: lvars[self.expr] for bufname,dtype in buf_to_dtype.items(): - if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(64)) + if dtype == dtypes._arg_int32: lvars[bufname] = bb[-1].sext(func.args[buf_index[bufname]], ir.IntType(32)) for u in uops: uop,dtype,vin,args,_ = u @@ -110,7 +110,7 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li phis.append((rp, lvars[rp])) loop_blocks.append((bb[-1], phis)) - lvars[var.expr] = bb[-1].phi(ir.IntType(64), name=var.expr) + lvars[var.expr] = bb[-1].phi(ir.IntType(32), name=var.expr) lvars[var.expr].add_incoming(sym_render(var.min), bb[-2]._block) if uop == UOps.ENDLOOP: for var in args[0][::-1]: @@ -124,23 +124,21 @@ def uops_to_llvm_ir(function_name:str, uops:List[UOp]) -> Tuple[str, Optional[Li if uop == UOps.DEFINE_ACC: lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], args) reduce_phis.append(u) + if uop == UOps.SPECIAL: + lvars[u] = lvars[args.expr] + if uop == UOps.CONST: + value = int(args) if dtypes.is_int(dtype) else bool(args) if dtype == dtypes.bool else args + lvars[u] = ir.Constant(dtype_to_llvm_dtype[dtype], value) if uop == UOps.LOAD: - assert dtype is not None and isinstance(args, (MemOp, ConstOp)) + assert dtype is not None valid = args.valid.render(render_llvm, bb[-1]) - if isinstance(args, ConstOp): - value, invalid_value = [int(args.value), int(args.invalid_value)] if dtypes.is_int(dtype) else ([bool(args.value), bool(args.invalid_value)] if dtype == dtypes.bool else [args.value, args.invalid_value]) # type: ignore - if args.valid.min == 0 and args.valid.max == 1: - val = bb[-1].select(valid, ir.Constant(dtype_to_llvm_dtype[dtype], value), ir.Constant(dtype_to_llvm_dtype[dtype], invalid_value)) - else: - val = ir.Constant(dtype_to_llvm_dtype[dtype], value if args.valid.min == 1 else invalid_value) + idx = args.idx.render(render_llvm, bb[-1]) + if args.valid.min == 0: + aug_idx = bb[-1].select(valid, idx, sym_render(0)) + val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value)) else: - idx = args.idx.render(render_llvm, bb[-1]) - if args.valid.min == 0: - aug_idx = bb[-1].select(valid, idx, sym_render(0)) - val = bb[-1].select(valid, bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [aug_idx], inbounds=True)), ir.Constant(dtype_to_llvm_dtype[args.memory_dtype], args.invalid_value)) - else: - val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True)) - val = cast(bb, val, args.memory_dtype, dtype) + val = bb[-1].load(bb[-1].gep(func.args[buf_index[args.name]], [idx], inbounds=True)) + val = cast(bb, val, args.memory_dtype, dtype) lvars[u] = val if uop == UOps.STORE: if args is None: diff --git a/tinygrad/renderer/wgsl.py b/tinygrad/renderer/wgsl.py index c5d7380bf6..a1203ae877 100644 --- a/tinygrad/renderer/wgsl.py +++ b/tinygrad/renderer/wgsl.py @@ -16,7 +16,8 @@ class WGSLLanguage(CStyleLanguage): external_local_bufs = True code_for_op = { UnaryOps.EXP2: lambda x: f"exp2({x})", UnaryOps.LOG2: lambda x: f"log2({x})", UnaryOps.SIN: lambda x: f"sin({x})", UnaryOps.SQRT: lambda x: f"sqrt({x})", - BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", BinaryOps.DIV: lambda x,y: f"({x}/{y})", + BinaryOps.ADD: lambda x,y: f"({x}+{y})", BinaryOps.SUB: lambda x,y: f"({x}-{y})", BinaryOps.MUL: lambda x,y: f"({x}*{y})", + BinaryOps.DIV: lambda x,y: f"({x}/{y})", BinaryOps.MOD: lambda x,y: f"({x}%{y})", BinaryOps.MAX: lambda x,y: f"max({x},{y})", BinaryOps.CMPLT: lambda x,y: f"f32({x}<{y})", TernaryOps.MULACC: lambda x,y,z: f"fma({x},{y},{z})", TernaryOps.WHERE: lambda a,b,c: f"select({c},{b},{a}!=0.)" }