good changes from tensor_cores branch (#1005)

* good changes from tensor_cores branch

* touchups

* real_strides fixup

* refactor merge_views
This commit is contained in:
George Hotz
2023-06-18 20:28:06 -07:00
committed by GitHub
parent ccb51ff5b0
commit 5428b5d774
6 changed files with 118 additions and 53 deletions

View File

@@ -143,16 +143,17 @@ class Linearizer:
def shape_offsets(self, i): return itertools.product(*[list(range(s)) for s in self.sts[i].shape[self.shape_len-self.upcasted:][::-1]]) if self.upcasted > 0 else [tuple()]
def float4_axis(self, i): return [x-(self.shape_len-self.upcasted) for x in self.sts[i].unit_stride_axes() if x >= self.shape_len-self.upcasted and self.sts[i].shape[x]%4 == 0]
# TODO: this stride is only on the last view, and may not be real
def upcasted_axis(self, i):
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
self.sts[i].views[-1].strides[self.shape_len-self.upcasted:], # WRONG
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
# TODO: is there a better way to write this?
def acc_offsets(self, i):
if self.upcasted == 0: return [0]
acc_strides = [x*(1-self.upcasted_axis(i)[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.upcasted_axis(i)[::-1])))]
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.upcasted_axis(i)[::-1])])]
upcasted_i = self.upcasted_axis(i)
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
def _group_float4(self, i, store_offset):
store_offset_float4 = {}
@@ -506,12 +507,12 @@ class Linearizer:
# **** below this line need to be optional and benchmarked ****
# potentially do more upcasts of non reduce axes based on a heuristic
upcasted_axis = set()
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
xb_choices = []
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
# if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
# NOTE: this is using views[-1]
if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
if axis not in upcasted_axis and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
if len(xb_choices):
xb_choices = sorted(xb_choices)
@@ -519,6 +520,7 @@ class Linearizer:
self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
self.upcast()
self.simplify_ones()
upcasted_axis.add(xb_choices[0][2])
else:
break

View File

@@ -28,7 +28,6 @@ class _CL:
for q in self.cl_queue: q.finish()
CL = _CL()
# TODO: merge CLImage in here
class CLBuffer(RawBufferCopyInOut):
def __init__(self, size, dtype, device='0'):
assert not OSX or dtype != dtypes.float64, "OpenCL on Mac doesn't support float64"

View File

@@ -4,17 +4,11 @@ from enum import Enum, auto
import functools
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
from tinygrad.helpers import prod, DEBUG
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
# these ops live here
class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto() # noqa: E702
def check_no_mul(test, var):
if test == var: return True
if test.__class__ is SumNode: return any(check_no_mul(x, var) for x in test.nodes) # in a sum is okay
if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
return False
@functools.lru_cache(maxsize=None)
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
assert len(shape) == len(strides)
@@ -97,22 +91,10 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
@functools.lru_cache(maxsize=None)
def merge_views(vm2:View, vm1:View) -> Optional[View]:
if vm2.mask: return None # this isn't supported yet
new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
for s,st in zip(vm1.shape, vm1.strides):
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
if s == 1:
new_strides.append(0) # all shape 1 can have stride 0
elif this_dim.__class__ is NumNode and this_dim.b == 0:
new_strides.append(0)
elif this_dim.__class__ is Variable:
new_strides.append(1)
elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable:
new_strides.append(this_dim.b)
else:
if DEBUG >= 4: print("can't simplify", s, this_dim.render())
break
return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None
mst = ShapeTracker(vm1.shape, [vm2, vm1])
strides = mst.real_strides()
if None in strides: return None
return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
@functools.lru_cache(maxsize=None)
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
@@ -168,15 +150,33 @@ class ShapeTracker:
# this is the real size (ish)
def size(self): return prod([s for s,st in zip(self.views[-1].shape, self.views[-1].strides) if st != 0])
def unit_stride_axes(self) -> List[int]:
ret, acc = [], 1
for j,s in reversed(list(enumerate(self.shape))):
if s == 1: continue
# these are multiview strides, value is None if it's not a simple strided dimension
# TODO: this can be shared code between simplify and merge_views
def real_offset(self) -> int:
real_offset, mask = self.expr_node(Variable('zero', 0, 0))
assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
return real_offset.b
def real_strides(self) -> Tuple[Optional[int], ...]:
if len(self.views) == 1: return self.views[-1].strides
ret: List[Optional[int]] = []
acc, real_offset = 1, self.real_offset()
for s in reversed(self.shape):
if s == 1: # fast path, all shape 1 have stride 0
ret.append(0)
continue
var = Variable('idx', 0, s-1)
this_dim = self.expr_node(var*acc)
this_dim, _ = self.expr_node(var*acc)
this_dim -= real_offset
acc *= s
if check_no_mul(this_dim[0], var): ret.append(j)
return ret
# TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
# if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
elif this_dim.__class__ is Variable: ret.append(1)
else: ret.append(None)
return tuple(ret[::-1])
def unit_stride_axes(self) -> List[int]: return [i for i,st in enumerate(self.real_strides()) if st == 1]
def _expr_idx(self, idx, valid):
for v in reversed(self.views[0:-1]):