if support (#2076)

* if support

* bugfix

* fix wgsl if

* more correct wgsl fix
This commit is contained in:
George Hotz
2023-10-15 07:17:37 -07:00
committed by GitHub
parent cb9309bee6
commit 30933d5bd0
4 changed files with 17 additions and 3 deletions

View File

@@ -16,7 +16,7 @@ from tinygrad.features.image import to_image_idx
# bottom ones are asm only
class UOps(Enum):
LOOP = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
LOOP = auto(); IF = auto(); END = auto(); SPECIAL = auto() # loops can be global, local, or other # noqa: E702
DEFINE_GLOBAL = auto(); DEFINE_LOCAL = auto(); DEFINE_ACC = auto() # this defines buffers # noqa: E702
LOAD = auto(); STORE = auto(); CONST = auto(); BARRIER = auto() # noqa: E702
ALU = auto(); WMMA = auto(); CAST = auto(); GEP = auto() # noqa: E702
@@ -219,6 +219,7 @@ class Linearizer(OptimizedKernel):
loaded_buffers = {}
acc = []
self.load_cache: Dict[str, UOp] = {}
if_gate: Optional[UOp] = None
# reduce op
fake_reduce_idxs = []
@@ -314,7 +315,10 @@ class Linearizer(OptimizedKernel):
fake_global_idxs = [x*0 for x in global_idxs]
self.global_store(-1, fake_global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, acc) # store accumulators
self.uop(UOps.BARRIER, None, (), cachable=False)
end_loop(loop_local_idxs)
end_loop(loop_local_idxs) # TODO: this is ending too much, should only end what's in the if?
if self.opts.has_local:
if_cond: UOp = Variable.ands([x<1 for x in local_idxs[self.local_dims:]]).render(self.render_ops, self)
if_gate = self.uop(UOps.IF, None, (if_cond,), cachable=False)
# create new late reduce local loops and replace local_idxs that have been used
end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce and i not in self.upcast_in_mid_reduce_axes else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
@@ -358,6 +362,7 @@ class Linearizer(OptimizedKernel):
self.global_store(0, global_idxs+local_idxs+fake_reduce_idxs+upcast_idxs, val)
# end the global (and maybe local) loop
if if_gate: self.uop(UOps.END, None, (if_gate,))
end_loop(loop_global_idxs+loop_local_idxs if not self.group_for_reduce else loop_global_idxs)
# (recursively) remove childless uops

View File

@@ -76,6 +76,9 @@ class CStyleLanguage(NamedTuple):
def render_for(self, expr: str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for (int {expr} = {_min}; {expr} < {_max}; ++{expr}) {{"
def render_if(self, cond: str):
return f"if ({cond}) {{"
def render_conditional(self, cond: str, x:str, y:str) -> str:
return f"({cond})?({x}):{y}"
@@ -126,6 +129,9 @@ def uops_to_cstyle(lang:CStyleLanguage, function_name:str, uops:List[UOp]) -> st
if uop == UOps.LOOP:
kk(lang.render_for(ssa(u,'ridx'), r[vin[0]], r[vin[1]]))
depth += 1
elif uop == UOps.IF:
kk(lang.render_if(r[vin[0]]))
depth += 1
elif uop == UOps.BARRIER:
kk(lang.barrier)
elif uop == UOps.END:

View File

@@ -43,6 +43,9 @@ class WGSLLanguage(CStyleLanguage):
def render_for(self, expr:str, _min:Union[int,str], _max:Union[int,str]) -> str:
return f"for(var {expr} = {_min}; {expr} < {_max}; {expr}++) {{"
def render_if(self, cond: str):
return f"if (bool({cond})) {{"
def render_conditional(self, cond:str, x:str, y:str) -> str:
return f"select(f32({y}), {x}, bool({cond}))"

View File

@@ -15,7 +15,7 @@ class Node:
b: Union[Node, int]
min: int
max: int
def render(self, ops=None, ctx=None) -> str:
def render(self, ops=None, ctx=None) -> Any:
if ops is None: ops = render_python
assert self.__class__ in (Variable, NumNode) or self.min != self.max
return ops[type(self)](self, ops, ctx)