mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
if support (#2076)
* if support
* bugfix
* fix wgsl if
* more correct wgsl fix
This commit is contained in:
@@ -16,7 +16,7 @@ from tinygrad.features.image import to_image_idx
|
||||
|
||||
# bottom ones are asm only
|
||||
class UOps(Enum):
|
||||
LOOP = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
|
||||
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
|
||||
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
|
||||
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
|
||||
@@ -219,6 +219,7 @@ class Linearizer(OptimizedKernel):
|
||||
loaded_buffers = {}
|
||||
acc = []
|
||||
self.load_cache: Dict[str, UOp] = {}
|
||||
if_gate: Optional[UOp] = None
|
||||
|
||||
# reduce op
|
||||
fake_reduce_idxs = []
|
||||
@@ -314,7 +315,10 @@ class Linearizer(OptimizedKernel):
|
||||
fake_global_idxs = [x*0 for x in global_idxs]
|
||||
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
|
||||
self.uop(UOps.BARRIER, None, (), cachable=False)
|
||||
end_loop(loop_local_idxs)
|
||||
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
|
||||
if self.opts.has_local:
|
||||
if_cond: UOp = Variable.ands([x<1 for x in local_idxs[self.local_dims:]]).render(self.render_ops, self)
|
||||
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
|
||||
|
||||
# create new late reduce local loops and replace local_idxs that have been used
|
||||
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
|
||||
@@ -358,6 +362,7 @@ class Linearizer(OptimizedKernel):
|
||||
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
|
||||
|
||||
# end the global (and maybe local) loop
|
||||
if if_gate: self.uop(UOps.END, None, (if_gate,))
|
||||
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
|
||||
|
||||
# (recursively) remove childless uops
|
||||
|
||||
@@ -76,6 +76,9 @@ class CStyleLanguage(NamedTuple):
|
||||
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
|
||||
|
||||
def render_if(self, cond: str):
|
||||
return f"if ({cond}) {{"
|
||||
|
||||
def render_conditional(self, cond: str, x:str, y:str) -> str:
|
||||
return f"({cond})?({x}):{y}"
|
||||
|
||||
@@ -126,6 +129,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
|
||||
if uop == UOps.LOOP:
|
||||
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
|
||||
depth += 1
|
||||
elif uop == UOps.IF:
|
||||
kk(lang.render_if(r[vin[0]]))
|
||||
depth += 1
|
||||
elif uop == UOps.BARRIER:
|
||||
kk(lang.barrier)
|
||||
elif uop == UOps.END:
|
||||
|
||||
@@ -43,6 +43,9 @@ class WGSLLanguage(CStyleLanguage):
|
||||
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
|
||||
return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
|
||||
|
||||
def render_if(self, cond: str):
|
||||
return f"if (bool({cond})) {{"
|
||||
|
||||
def render_conditional(self, cond:str, x:str, y:str) -> str:
|
||||
return f"select(f32({y}), {x}, bool({cond}))"
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ class Node:
|
||||
b: Union[Node, int]
|
||||
min: int
|
||||
max: int
|
||||
def render(self, ops=None, ctx=None) -> str:
|
||||
def render(self, ops=None, ctx=None) -> Any:
|
||||
if ops is None: ops = render_python
|
||||
assert self.__class__ in (Variable, NumNode) or self.min != self.max
|
||||
return ops[type(self)](self, ops, ctx)
|
||||
|
||||
Reference in New Issue
Block a user