# Files: tinygrad/tinygrad/codegen/linearizer.py
# Snapshot: 2023-06-17 16:47:55 -07:00
# 542 lines, 27 KiB, Python

from typing import List, Tuple, Any, Optional, cast, DefaultDict, NamedTuple, TypeVar, Dict
import itertools, math
from collections import defaultdict
from enum import Enum, auto
from tinygrad.helpers import dedup, colored, ImageDType, DEBUG, prod, dtypes, mnum, DType, all_same
from tinygrad.ops import LazyOp, get_lazyops, get_buffers, FlopCounter, get_lazyop_info, map_buffers, UnaryOps
from tinygrad.lazy import LazyBuffer
from tinygrad.ops import MovementOps, ReduceOps, BinaryOps, FusedOps
from tinygrad.shape.shapetracker import ShapeTracker, strides_for_shape
from tinygrad.shape.symbolic import Variable
# NOTE: the bottom group of members (SPECIAL, DEFINE_REGISTER, LABEL, COND_BRANCH) is asm only
class UOps(Enum):
  """Kinds of micro-op the Linearizer emits; auto() numbers them in declaration order starting at 1."""
  LOOP = auto()
  DEFINE_LOCAL = auto()
  LOAD = auto()
  ALU = auto()
  CONST = auto()
  ENDLOOP = auto()
  STORE = auto()
  CAST = auto()
  # asm only
  SPECIAL = auto()
  DEFINE_REGISTER = auto()
  LABEL = auto()
  COND_BRANCH = auto()
class LocalBuffer(NamedTuple):
  """Stand-in entry for `Linearizer.bufs` that represents the scratch buffer of a multistage (grouped) reduce.

  It never has backing storage, so `realized` is always None. That keeps code which reads `.realized`
  on every buffer (e.g. `Linearizer.printbufs`) working.
  """
  dtype: DType = dtypes.float32  # element type of the scratch buffer
  realized: None = None          # always None: this buffer is never realized
class Token(NamedTuple):
  """A named value referenced by UOps.

  `offset` selects one lane of a float4 token (rendered as .x/.y/.z/.w). None means the whole value.
  """
  name: str
  dtype: DType
  offset: Optional[int] = None

  def render(self, with_type=False):
    # declaration form "<type> <name>" is only valid for whole (lane-less) tokens
    if with_type:
      assert self.offset is None
      return f"{self.dtype.name} {self.name}"
    if self.offset is not None:
      # a single-lane reference only makes sense on a float4 token
      assert self.dtype == dtypes._float4
      return f"{self.name}.{'xyzw'[int(self.offset)]}"
    return self.name

  def __repr__(self):
    plain_float = self.offset is None and self.dtype == dtypes.float32
    if plain_float: return f"<{self.name}>"
    return f"<{self.name}:{self.dtype.name}:{self.offset}>"
# TODO: the next three functions are poorly written
def get_grouped_float4_idxs(acc:List[Token]) -> Optional[List[int]]:
  """Try to partition `acc` into float4 quadruples.

  Returns a flat index list in which every consecutive run of four is (lane 0, lane 1, lane 2, lane 3) of
  one multi-element (dtype.sz > 1) token name. Returns None as soon as some token cannot be placed in
  such a group. An empty `acc` yields [].
  """
  grouped: List[int] = []
  for i, head in enumerate(acc):
    if i in grouped: continue
    # every token not already claimed must be the lane-0 head of a vector token
    if not (head.dtype.sz > 1 and head.offset == 0): return None
    grouped.append(i)
    lanes: List[int] = []
    for j, cand in enumerate(acc):
      if len(lanes) == 3: break
      # lanes must be found in increasing lane order (1, 2, 3) while scanning
      if j not in grouped and head.name == cand.name and cand.dtype.sz > 1 and cand.offset == len(lanes)+1:
        lanes.append(j)
    if len(lanes) != 3: return None
    grouped += lanes
  return grouped
def to_float4(x:List[Token]) -> Optional[Token]:
  """Collapse a list of per-lane tokens into a single token, if possible.

  Returns x[0] when all entries are identical. Returns a lane-less float4 Token when every entry is lane
  `k` (at position k) of the same float4 name. Otherwise returns None.
  """
  if all_same(x): return x[0]
  if not all_same([y.name for y in x]): return None
  for lane, y in enumerate(x):
    if not (y.dtype == dtypes._float4 and y.offset == lane): return None
  return Token(x[0].name, dtypes._float4)
def get_grouped_maybe_float4(*values:List[Token], grouping_allowed=True):
  """Pair up element indices with the per-operand tokens for one elementwise op.

  `values` holds equal-length operand token lists. The last one is the accumulator for reduces, and
  grouping is only attempted on it. If grouping is allowed, the last operand groups into float4
  quadruples, and every operand collapses with `to_float4` for every quadruple, then the returned zip
  yields (four indices, list of float4 tokens) pairs. Otherwise it yields ([index], tuple of tokens)
  pairs, one per element.
  """
  assert all_same([len(x) for x in values]), f"all values are not the same length {values}"
  # these use accumulators, we can only fold if the acc is a float4
  idxs = get_grouped_float4_idxs(values[-1]) if grouping_allowed else None
  if idxs is not None:
    chunks = [idxs[start:start+4] for start in range(0, len(idxs), 4)]
    folded = []
    for chunk in chunks:
      per_operand = [to_float4([v[j] for j in chunk]) for v in values]
      if any(tok is None for tok in per_operand): break
      folded.append(per_operand)
    # only use the float4 path if every quadruple folded
    if len(folded) == len(idxs)//4:
      return zip(chunks[:len(folded)], folded)
  return zip([[i] for i in range(len(values[0]))], zip(*values))
class MemOp(NamedTuple):
  """Argument of LOAD/STORE uops: which buffer, at what symbolic index, under what symbolic validity condition."""
  i: int           # index into Linearizer.bufs (negative allowed, e.g. -1 for the local reduce buffer)
  idx: Variable    # symbolic element index (from ShapeTracker.expr_idxs)
  valid: Variable  # symbolic validity expression (from ShapeTracker.expr_idxs)
class UOp(NamedTuple):
  """One linearized micro-op: its kind, output token (None for e.g. LOOP/STORE), input tokens, and an op-specific argument."""
  uop: UOps
  out: Optional[Token]
  vin: List[Token]
  arg: Any

  def __repr__(self):
    out_text = '' if self.out is None else str(self.out)
    return f"{str(self.uop):20s}: {out_text:25s} {str(self.vin):32s} {self.arg}"
class Linearizer:
  """Lowers a LazyOp AST into a flat list of UOps (`self.uops`).

  Expected call order (inferred from which attributes each method reads): construct, `process()` to build one
  ShapeTracker per buffer, optionally reshape/upcast them (e.g. `hand_coded_optimizations()`), then `linearize()`.
  """
  # class-level switches (presumably overridden by backend subclasses -- confirm there):
  # allow global_load/global_store to group four contiguous elements into one float4 access
  supports_float4: bool = False
  # allow ALU uops to operate on whole float4 tokens (see ast_parse)
  supports_float4_alu: bool = False

  def __init__(self, ast:LazyOp, output_buffer:LazyBuffer):
    """Record the AST, the buffer list (output first, then deduplicated inputs) and a cache key."""
    # NOTE: if there's a RESHAPE, we skip it. the output shape is set from the reduce op or a latebuf
    self.ast = ast.src[0] if ast.op == MovementOps.RESHAPE else ast

    # get the output buffers
    self.bufs = [output_buffer] + dedup(get_buffers(ast))

    # key for lookup in cache (can change, str might not be right)
    # bufs are needed because kernels like f(x) = x + x and f(x, y) = x + y have the same str(ast), but are different kernels.
    # mapping the buffers to integers is required because a-b != b-a (and how would you tell a and b apart?)
    self.key = f"ASTKernelKey ast={str(map_buffers({x:i for i,x in enumerate(self.bufs)}, ast))} bufs={self.bufs}"

  def process(self) -> None:
    """Analyse the AST and build per-buffer ShapeTrackers with reduce axes permuted to the end.

    Idempotent: returns immediately once `self.sts` exists.
    """
    if hasattr(self, "sts"): return   # already processed

    # fetch lazyop info
    self.info: FlopCounter = get_lazyop_info(self.ast)
    self.mem_estimate: int = sum(x.dtype.itemsize*(x.realized.size if x.realized is not None else prod(x.shape)) for x in self.bufs if x is not None)

    # there's only allowed to be one reduceop
    reduceops = [x for x in get_lazyops(self.ast) if x.op in ReduceOps]
    assert len(dedup(reduceops)) <= 1, "max one reduce op in an ast"
    self.reduceop = reduceops[0] if reduceops else None

    # get earlybufs, before the one reduce op
    self.earlybufs = dedup(get_buffers(self.reduceop)) if self.reduceop else []

    # create new shapetrackers inside this kernel, we will permute them
    self.sts: List[ShapeTracker] = [x.st.copy() for x in self.bufs]
    for st in self.sts: st.simplify()

    # make the output buffer shape correct in here
    self.sts[0].reshape(self.info.shape)
    self.full_buf_index: int = self.bufs.index(self.earlybufs[0]) if len(self.earlybufs) > 0 else 0

    # move all reduce axes to the end
    reduce = list(enumerate(zip(self.full_shape, self.sts[0].shape)))
    permute = tuple([i for i,(s,n) in reduce if s == n] + [i for i,(s,n) in reduce if s != n])
    self.reshape_and_permute(None, permute)

    # parameters
    self.group_for_reduce: List[int] = []
    self.upcasted: int = 0

    # group simplifies
    self.simplify_ones()
    self.simplify_merge_adjacent()

    # print early
    if DEBUG >= 5: self.printbufs("early")

  def shape_offsets(self, i):
    """All index tuples over buffer i's upcasted (trailing) axes, with the axes taken in reverse order; [()] if nothing is upcasted."""
    return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]

  def float4_axis(self, i):
    """Upcasted axes of buffer i (counted from the first upcasted axis) that have unit stride and a size divisible by 4."""
    return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]

  # TODO: this stride is only on the last view, and may not be real
  def upcasted_axis(self, i):
    """(size, last-view stride, is_reduce) for each upcasted axis of buffer i.

    is_reduce is True where buffer 0's shape differs from full_shape, so it does not depend on i.
    """
    return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
                    self.sts[i].views[-1].strides[self.shape_len-self.upcasted:],  # WRONG
                    [x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))

  # TODO: is there a better way to write this?
  def acc_offsets(self, i):
    """Accumulator index for every element of buffer i's upcasted tile, in shape_offsets order.

    Reduce axes get stride 0, so elements that differ only along a reduce axis share one accumulator.
    """
    if self.upcasted == 0: return [0]
    # compute the (reversed) upcasted axes of buffer i once. The previous version called
    # self.upcasted_axis(i) inside the comprehension, where the comprehension's own `i` (the enumerate index)
    # shadowed the buffer index, so it read self.sts[0..upcasted-1] and raised IndexError when
    # upcasted > len(self.sts).
    rev_axes = self.upcasted_axis(i)[::-1]
    acc_strides = [x*(1-rev_axes[j][2]) for j,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in rev_axes)))]
    return [sum(t) for t in itertools.product(*[[y*acc_strides[j] for y in range(x[0])] for j,x in enumerate(rev_axes)])]

  def _group_float4(self, i, store_offset):
    """Bucket store_offset's values under the key whose float4-axis index is rounded down to a multiple of 4.

    Relies on dict order: the aligned (%4 == 0) key must be seen before the other three of its group,
    otherwise the append raises KeyError.
    """
    store_offset_float4 = {}
    float4_axis = (self.upcasted-1) - self.float4_axis(i)[0]
    for uidxs, var in store_offset.items():
      if uidxs[float4_axis]%4 == 0:
        store_offset_float4[uidxs] = [var]
      else:
        uidxs2 = list(uidxs)
        uidxs2[float4_axis] -= uidxs2[float4_axis]%4
        store_offset_float4[tuple(uidxs2)].append(var)
    return store_offset_float4

  def global_load(self, i, idxs:List[Variable], const=None) -> List[Token]:
    """Emit LOAD uops for buffer i at `idxs` (or CONST uops with value `const`), one Token per upcasted element.

    Identical (type, index, valid) accesses are emitted once. When float4 is supported and the buffer dtype
    allows it, four aligned, contiguous elements are loaded as one float4 token and returned as its lanes.
    """
    load_offset: Dict[Tuple[int, ...], Any] = {uidxs:(dtypes.float,uidxs)+self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]]) for uidxs in self.shape_offsets(i)}

    # float4 grouping (optional)
    should_upcast = self.supports_float4 and (self.bufs[i].dtype in [dtypes.float32, dtypes.float16] or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
    if should_upcast:
      load_offset_new = {}
      for k,out_tokens in self._group_float4(i, load_offset).items():
        idxs = [x[2]-out_tokens[0][2] for x in out_tokens]
        valids_okay = all_same([x[3] for x in out_tokens]) or (all_same([x[3]//4 for x in out_tokens]) and (out_tokens[0][3]//4)*4 == out_tokens[0][3])
        if any(idx.min != idx.max or idx.min != val for idx,val in zip(idxs, range(4))) or (out_tokens[0][2]//4)*4 != out_tokens[0][2] or not valids_okay:
          # idxs not in order, valids don't match, or idx doesn't evenly divide 4. use normal float
          for x in out_tokens: load_offset_new[x[1]] = x
        else:
          load_offset_new[k] = (dtypes._float4, [x[1] for x in out_tokens], out_tokens[0][2], out_tokens[0][3])
      load_offset = load_offset_new

    # do loads
    cache: Dict[str, Token] = {}
    loaded = {}
    for uidxs, (localtype, uidx_list, idx, valid) in load_offset.items():
      key = f"{localtype}{idx.render()}{valid.render()}"
      if key not in cache:
        cache[key] = self.uop(UOps.LOAD, Token(f"val{mnum(i)}_{len(cache)}", localtype), [], MemOp(i, idx, valid)) if const is None else self.uop(UOps.CONST, Token(f"acc{mnum(i)}_{len(cache)}", localtype), [], const)
      if localtype == dtypes._float4:
        # hand out each lane of the float4 to the element it belongs to
        for j,uidx in enumerate(uidx_list):
          loaded[uidx] = Token(cache[key].name, dtypes._float4, j)
      else:
        loaded[uidxs] = cache[key]
    return [loaded[uidxs] for uidxs in self.shape_offsets(i)]

  def global_store(self, i, idxs:List[Variable], store:List[Token], ssa) -> None:
    """Emit STORE uops writing `store` (one token per upcasted element) to buffer i at `idxs`.

    With float4 grouping, four lanes that already form one float4 token are stored directly. Otherwise they
    are packed with a CAST uop first.
    """
    store_offset: Dict[Tuple[int, ...], Token] = dict(zip(self.shape_offsets(i), store))

    # float4 grouping (optional)
    # TODO: why does this not work for float16?
    should_upcast = self.supports_float4 and (self.bufs[i].dtype == dtypes.float32 or isinstance(self.bufs[i].dtype, ImageDType)) and len(self.float4_axis(i)) == 1
    if should_upcast:
      store_offset_new = {}
      for k,out_tokens in self._group_float4(i, store_offset).items():
        if all_same([x.name for x in out_tokens]) and tuple(range(4)) == tuple(x.offset for x in out_tokens):
          store_offset_new[k] = Token(out_tokens[0].name, dtypes._float4)
        else:
          store_offset_new[k] = self.uop(UOps.CAST, ssa("alu", dtypes._float4), out_tokens)
      store_offset = store_offset_new

    # do stores
    for uidxs, var in store_offset.items():
      self.uop(UOps.STORE, None, [var], MemOp(i, *self.sts[i].expr_idxs(idxs+[Variable.num(x) for x in uidxs[::-1]])))

  def linearize(self):
    """Emit the kernel's uops into `self.uops`.

    Structure: global loop -> (local loop) -> reduce loop updating accumulators -> (store partial accumulators
    to a local buffer and run a late reduce over it) -> late AST -> store to buffer 0.
    NOTE: mutates self.bufs/self.sts (appends the local buffer) and, for grouped reduces, self.group_for_reduce
    and self.upcasted.
    """
    # uops
    self.uops: List[UOp] = []

    # add a local buffer for multistage reduce
    if len(self.group_for_reduce):
      self.bufs.append(LocalBuffer())
      # TODO: the strides of this can be controlled
      self.sts.append(ShapeTracker(tuple([1] * self.first_reduce + self.group_for_reduce + [1] * (self.shape_len - self.upcasted - len(self.group_for_reduce) - self.first_reduce) + [x[0] for x in self.upcasted_axis(0)])))
      self.uop(UOps.DEFINE_LOCAL, None, [], ("temp", self.sts[-1].size()))

    # print
    if DEBUG >= 3: self.printbufs()

    # kernel name (before late upcast)
    self.function_name = ("r_" if self.reduceop else "E_") + '_'.join([str(x) for x in self.full_shape])
    self.display_name = ("r_" if self.reduceop else "E_") + colored('_', 'black', bright=True).join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])

    # parse AST
    loaded_buffers = {}
    acc = []

    # ssa
    _ssa:DefaultDict[str,int] = defaultdict(int)
    def ssa(name, ltype=dtypes.float) -> Token:
      # fresh token name: name0, name1, ... per prefix
      _ssa[name] += 1
      return Token(f"{name}{_ssa[name]-1}", ltype)

    # global loop
    global_idxs = [Variable(f"gidx{i}", 0, self.full_shape[i]-1 if i < self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
    self.uop(UOps.LOOP, None, [], (global_idxs, "global"))

    # local loop
    if self.group_for_reduce:
      # NOTE: this is assuming the global size = the local size in these dims. in general, this doesn't have to be true
      local_idxs = [Variable(f"lidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
      self.uop(UOps.LOOP, None, [], (local_idxs, "local"))
      gl_idxs = [x*(y.max+1)+y for x,y in zip(global_idxs, local_idxs)]
    else:
      # without local idxs, it's just the global idxs
      gl_idxs = global_idxs

    # reduce op
    fake_reduce_idxs = []
    removed = len(global_idxs)
    if self.reduceop is not None:
      # define indexes
      reduce_idxs = [Variable(f"ridx{i}", 0, self.full_shape[i]-1) for i in range(self.first_reduce+len(self.group_for_reduce), self.shape_len-self.upcasted)]
      fake_reduce_idxs = [x*0 for x in reduce_idxs]

      # define accumulator
      acc = self.global_load(0, gl_idxs+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])

      # reduce loop
      self.uop(UOps.LOOP, None, [], (reduce_idxs, "reduce"))

      # load earlybufs
      loaded_buffers.update({b:self.global_load(i, gl_idxs+reduce_idxs) for i,b in enumerate(self.bufs) if b in self.earlybufs and i != 0})

      # run early AST (with reduce)
      self.ast_parse(self.reduceop, [acc[off] for off in self.acc_offsets(self.full_buf_index)], loaded_buffers, ssa, do_reduce=True)

      # end the reduce loop
      self.uop(UOps.ENDLOOP, None, [], (reduce_idxs, "reduce"))

      # end the local loop, do the local reduce
      if self.group_for_reduce:
        self.global_store(-1, local_idxs+fake_reduce_idxs, acc, ssa)  # store accumulators
        self.uop(UOps.ENDLOOP, None, [], (local_idxs, "local"))   # this is a barrier on GPUs

        # if any group_for_reduce items aren't reduces, upcast them here
        for j in self.upcast_in_mid_reduce_axes:
          self.reshape_and_permute(None, [i for i in range(self.shape_len) if i != j] + [j])
          self.upcast()
          self.group_for_reduce.pop()
          removed -= 1

        # NOTE: this structure is the same as the reduce op above

        # define late accumulator
        acc = self.global_load(-1, local_idxs[:removed]+fake_reduce_idxs, {ReduceOps.SUM: 0.0, ReduceOps.MAX: -math.inf}[cast(ReduceOps, self.reduceop.op)])

        # late reduce loop
        end_local_idxs = [Variable(f"tidx{i}", 0, self.full_shape[i]-1 if i >= self.first_reduce else 0) for i in range(0, self.first_reduce+len(self.group_for_reduce))]
        self.uop(UOps.LOOP, None, [], (end_local_idxs, "late_reduce"))

        # load localbufs
        loaded_buffers["LOCAL_BUFFER"] = self.global_load(-1, end_local_idxs+fake_reduce_idxs)

        # there's no AST here (and there's no shape for the reduce LazyOp)
        self.ast_parse(LazyOp(self.reduceop.op, ("LOCAL_BUFFER",)), [acc[off] for off in self.acc_offsets(-1)], loaded_buffers, ssa, do_reduce=True)

        # end the late reduce loop
        self.uop(UOps.ENDLOOP, None, [], (end_local_idxs, "late_reduce"))

    # load latebufs
    loaded_buffers.update({b:self.global_load(i, global_idxs[:removed]+fake_reduce_idxs) for i,b in enumerate(self.bufs) if b not in self.earlybufs and i != 0 and not isinstance(b, LocalBuffer)})

    # run late AST
    val = self.ast_parse(self.ast, acc, loaded_buffers, ssa)

    # store
    self.global_store(0, global_idxs[:removed]+fake_reduce_idxs, val, ssa)

    # end the global loop
    self.uop(UOps.ENDLOOP, None, [], (global_idxs, "global"))

  _OT = TypeVar("_OT")
  def uop(self, uop:UOps, out:_OT, vin:List[Token], arg:Any=None) -> _OT:
    """Append a UOp to self.uops and return `out` unchanged, so calls can be used as expressions."""
    self.uops.append(UOp(uop, cast(Optional[Token], out), vin, arg))
    if DEBUG >= 4: print(self.uops[-1])
    return out

  def ast_parse(self, x, acc, loaded_buffers, ssa, do_reduce=False) -> List[Token]:
    """Recursively lower `x` into one Token per upcasted element, emitting ALU uops.

    Leaves (non-LazyOp) are looked up in `loaded_buffers`. A reduce op returns `acc` unless `do_reduce` is set.
    SUM(MUL(a, b)) and SUM(CAST(MUL(a, b))) are fused into MULACC.
    """
    if not isinstance(x, LazyOp): return loaded_buffers[x]
    if x.op in [UnaryOps.NOOP, UnaryOps.CAST]: return self.ast_parse(x.src[0], acc, loaded_buffers, ssa)  # cast isn't an ALU op
    if x.op in ReduceOps and not do_reduce: return acc
    # MULACC fusion. TODO: this is copied from Interpreted
    if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == BinaryOps.MUL:
      x = LazyOp(FusedOps.MULACC, x.src[0].src, x.arg)
    if x.op == ReduceOps.SUM and isinstance(x.src[0], LazyOp) and x.src[0].op == UnaryOps.CAST and isinstance(x.src[0].src[0], LazyOp) and x.src[0].src[0].op == BinaryOps.MUL:
      x = LazyOp(FusedOps.MULACC, x.src[0].src[0].src, x.arg)
    values = [self.ast_parse(v, acc, loaded_buffers, ssa) for v in x.src]
    # TODO: fold float4 into a single uop when possible.
    if isinstance(x.op, (ReduceOps, FusedOps)):
      # reduces write back into the accumulator token (the last operand)
      ret = [(idx, self.uop(UOps.ALU, val[-1], list(val), {ReduceOps.SUM:BinaryOps.ADD, ReduceOps.MAX:BinaryOps.MAX, FusedOps.MULACC:FusedOps.MULACC}[x.op])) for idx, val in get_grouped_maybe_float4(*values, acc, grouping_allowed=self.supports_float4_alu)]
    else:
      ret = [(idx, self.uop(UOps.ALU, ssa('alu', dtypes._float4) if any(t.dtype == dtypes._float4 and t.offset is None for t in val) else ssa('alu'), list(val), x.op)) for idx, val in get_grouped_maybe_float4(*values, grouping_allowed=self.supports_float4_alu and x.op!=BinaryOps.CMPEQ)]
    ordered_ret: List[Optional[Token]] = [None]*len(values[0])
    # scatter
    for i,j in ret:
      for o,k in enumerate(i):
        ordered_ret[k] = Token(j.name, j.dtype, o) if j.dtype == dtypes._float4 else j
    assert all(isinstance(x, Token) for x in ordered_ret), "some tokens didn't get scattered?"
    return cast(List[Token], ordered_ret)

  # index of the first un-upcasted axis where buffer 0's shape differs from full_shape (the first reduce axis);
  # the (0,)/(1,) sentinel makes it shape_len-upcasted when there is none
  @property
  def first_reduce(self) -> int: return [x!=y for x,y in zip(self.sts[0].shape[:self.shape_len-self.upcasted]+(0,), self.full_shape[:self.shape_len-self.upcasted]+(1,))].index(True)

  # shape of the first early buffer (the reduce input) if there is a reduce, else of the output
  @property
  def full_shape(self) -> Tuple[int, ...]: return self.sts[self.full_buf_index].shape

  @property
  def full_unupcasted_shape(self) -> Tuple[int, ...]: return self.full_shape[:self.shape_len-self.upcasted]

  @property
  def shape_len(self) -> int: return len(self.sts[0].shape)

  # grouped (local) axes that are not actually reduced; linearize upcasts these after the local reduce
  @property
  def upcast_in_mid_reduce_axes(self) -> List[int]: return [j for j in range(self.first_reduce, self.first_reduce+len(self.group_for_reduce)) if self.full_shape[j] == self.sts[0].shape[j]]

  def colors(self) -> List[str]:
    """One display color per axis describing its role (global/local/late-upcast/reduce/upcast)."""
    # up to first_reduce, they are all global (blue)
    colors = ["blue"] * self.first_reduce
    # between first_reduce and first_reduce + group_for_reduce, they are either local (cyan), or late upcasted (green)
    colors += ["green" if i in self.upcast_in_mid_reduce_axes else "cyan" for i in range(self.first_reduce, self.first_reduce + len(self.group_for_reduce))]
    # between first_reduce + group_for_reduce and upcasted, they are reduce (red)
    colors += ["red"] * ((self.shape_len-self.upcasted) - (self.first_reduce + len(self.group_for_reduce)))
    # upcasted dimensions are reduce (magenta) or normal (yellow)
    colors += ["magenta" if self.full_shape[i] != self.sts[0].shape[i] else "yellow" for i in range(self.shape_len-self.upcasted, self.shape_len)]
    assert len(colors) == self.shape_len, "colors size mismatch"
    return colors

  def printbufs(self, prefix=""):
    """Debug print: each buffer's ShapeTracker views, then the colored full shape."""
    for i in range(len(self.sts)):
      print(prefix, f"{i:3d} {str(self.bufs[i].realized) if self.bufs[i] is not None else 'FAKE':47s}", self.sts[i].views)
    print(' '.join(colored(f"{s:4d}", color) for s,color in zip(self.full_shape, self.colors())))

  # ******************** base simplifiers ********************

  # apply reshape and permute to all shapetrackers
  def reshape_and_permute(self, new_shape_fxn, axis):
    for st in self.sts:
      if new_shape_fxn is not None: st.reshape(tuple(new_shape_fxn(st.shape)))
      if axis is not None: st.permute(tuple(axis))

  # treat one more trailing axis as upcasted (the upcasted axes are always the last `self.upcasted` axes).
  # NOTE: the assert only checks the very last axis of full_shape
  def upcast(self):
    assert self.full_shape[-1] != 1, "can't upcast a dimension with size 1"
    self.upcasted += 1

  # axis : the axis to pull from
  # amount : the amount to take
  # top : if you want to pull that amount from the top
  # insert_before : place to insert the new stuff
  def shift_to(self, axis, amount, top=False, insert_before=None):
    if insert_before is None: insert_before = self.shape_len
    move_axis = axis if top else axis+1
    if move_axis < insert_before: insert_before += 1
    self.reshape_and_permute(
      lambda x: list(x[0:axis]) + (([amount, x[axis]//amount] if top else [x[axis]//amount, amount]) if x[axis] > 1 else [1,1]) + list(x[axis+1:]),
      [i for i in range(insert_before) if i != move_axis] + [move_axis] + [i for i in range(insert_before, self.shape_len+1) if i != move_axis])

  # ******************** complex simplifiers ********************

  def simplify_ones(self):
    # remove places where the shape is all ones
    # TODO: this should be factored in to multi shape stride
    if self.shape_len == 0: return
    all_ones = [all(st.shape[i]==1 for st in self.sts) for i in range(self.shape_len)]
    # keep at least 1 one
    if all(all_ones): all_ones[-1] = False
    self.reshape_and_permute(lambda shape: [x for i,x in enumerate(shape) if not all_ones[i]], None)

  def simplify_merge_adjacent(self):
    """Merge adjacent axes that are contiguous (in the last view) for every buffer, never merging into first_reduce."""
    if self.shape_len == 0: return
    shapes, strides = [x.shape for x in self.sts], [x.views[-1].strides for x in self.sts]

    # merge dimensions if we can, multi get_shape_strides
    # TODO: does this always preserve the reduce dimension, NO
    # TODO: move this into shapetracker, with tests!
    rets = [[(shapes[j][0], strides[j][0])] for j in range(len(shapes))]
    for i in range(1, len(shapes[0])):
      can_merge = []
      for j in range(len(shapes)):
        # TODO: added the always mergeability of 1s, is this right? if so, add to shapetracker in the 1 case
        can_merge.append((strides[j][i] != 0 and rets[j][-1][1] == shapes[j][i]*strides[j][i]) or (strides[j][i] == 0 and rets[j][-1][1] == 0))
      # more can merge than this
      mergeable = all(can_merge) and i != self.first_reduce
      for j in range(len(shapes)):
        if mergeable: rets[j][-1] = (rets[j][-1][0] * shapes[j][i], strides[j][i])
        else: rets[j].append((shapes[j][i], strides[j][i]))

    # do the reshapes
    for i,x in enumerate(rets): self.sts[i].reshape(tuple(y[0] for y in x))

  # ******************** GPU simplifiers ********************

  def required_optimizations(self, early_only=False):
    """For image-dtype buffers (only earlybufs if early_only), split a unit-stride axis of size %4 == 0 by 4 and upcast it.

    Asserts such an axis exists. Presumably this is so image accesses can be float4 -- confirm against the image backends.
    """
    for buf_index,buf in enumerate(self.bufs):
      unit_stride_axes_mul_4 = [i for i in self.sts[buf_index].unit_stride_axes() if self.sts[buf_index].shape[i]%4 == 0]
      if (not early_only or buf in self.earlybufs) and isinstance(self.bufs[buf_index].dtype, ImageDType):
        assert len(unit_stride_axes_mul_4) >= 1, f"needs a unit stride axis in {self.bufs[buf_index]}"
        if all(x < (self.shape_len-self.upcasted) for x in unit_stride_axes_mul_4) and unit_stride_axes_mul_4[0] not in self.upcast_in_mid_reduce_axes:
          self.shift_to(unit_stride_axes_mul_4[0], 4)
          self.upcast()

  def limit_global_dims(self, limit):
    # sometimes, there's more dimensions than len(self.lang.gid).
    # compact all the dimensions into the first
    # NOTE: this might make multiview shapetrackers
    if limit and self.first_reduce > limit:
      num_to_merge = (self.first_reduce - limit)+1
      self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
      if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")

  def hand_coded_optimizations(self):
    """Heuristic schedule: required image upcasts, local grouping for reduces, then non-reduce and reduce upcasts."""
    # if there's images in the earlybufs, we have to make an axis the 4 loading one
    self.required_optimizations(early_only=True)

    # simplify (sets first_reduce)
    self.simplify_ones()

    # are we grouping? (requires local shape support)
    if not self.float4_axis(0) and self.first_reduce <= 2 and self.first_reduce + 1 <= self.shape_len and prod(self.sts[0].shape[:self.first_reduce]) <= 2048:
      # TODO: use 1024 if it's allowed in a smarter way
      for sz in (([256, 16]) if prod(self.sts[0].shape[:self.first_reduce]) <= 32 else [16]):
        if all([st.shape[self.first_reduce] % sz == 0 or st.shape[self.first_reduce] == 1 for st in self.sts]):
          self.shift_to(self.first_reduce, sz, top=True, insert_before=self.first_reduce + len(self.group_for_reduce))
          self.group_for_reduce.append(sz)
          break

    # are we upcasting in mid reduce? (only for images)
    if self.bufs[0].dtype.name.startswith('image') and not self.float4_axis(0) and self.group_for_reduce and self.first_reduce <= 2 and prod(self.sts[0].shape) > 1:
      axes = self.sts[0].unit_stride_axes()
      assert len(axes) == 1, f"wrong number of stride 1 axis : {axes}"
      if self.sts[0].shape[axes[0]]%4 == 0:
        self.shift_to(axes[0], 4, insert_before=self.first_reduce + len(self.group_for_reduce))   # insert at the end of the grouped axis
        self.group_for_reduce.append(4)

    # now do everything required
    self.required_optimizations()

    # simplify (sets first_reduce)
    self.simplify_ones()

    # use more opencl indexing if the output buffer is an image and we have room
    if self.bufs[0].dtype.name.startswith('image') and self.first_reduce+len(self.group_for_reduce) < 3:
      base_shape = self.bufs[0].dtype.shape
      if (base_shape[0]*base_shape[1]) % self.sts[0].shape[0] == 0 and self.sts[0].shape[0]//base_shape[0] != 0:
        if DEBUG >= 4: print("split opencl", base_shape, self.sts[0].shape)
        self.reshape_and_permute(lambda x: [base_shape[0], x[0]//base_shape[0]]+list(x[1:]), None)
        self.simplify_ones()

    # no more opt if we are grouping
    if self.group_for_reduce: return

    # **** below this line need to be optional and benchmarked ****

    # potentially do more upcasts of non reduce axes based on a heuristic
    while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
      xb_choices = []
      for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]):   # consider all the non reduce axes, and a 3 or 4 reduce
        # if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
        # NOTE: this is using views[-1]
        if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
          xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
      if len(xb_choices):
        xb_choices = sorted(xb_choices)
        if DEBUG >= 4: print(f"float4 merging axis : {xb_choices}")
        self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
        self.upcast()
        self.simplify_ones()
      else:
        break

    # if last dim <= 16 and it's a reduce dim, upcast the reduce (loop unrolling). no simplify needed since it's just an upcast. NOTE: careful, this has broken VALIDHACKS
    if self.first_reduce < (self.shape_len-self.upcasted) and (len(list(self.shape_offsets(self.full_buf_index))) <= 4 or not any(r for _,_,r in self.upcasted_axis(self.full_buf_index))):
      if self.full_unupcasted_shape[-1] <= 16:
        self.upcast()
      else:
        for splits in [4]:
          if self.full_unupcasted_shape[-1]%splits == 0:
            self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
            self.upcast()
            break

    # if nothing at all is upcasted and it's easy to, do an upcast
    # TODO: this is breaking the tests
    for splits in [4]:
      if self.upcasted == 0 and len(self.full_unupcasted_shape) > 0 and self.full_unupcasted_shape[-1] % splits == 0:
        self.shift_to(len(self.full_unupcasted_shape)-1, splits, insert_before=len(self.full_unupcasted_shape))
        self.upcast()