mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 10:31:41 -05:00
good changes from tensor_cores branch (#1005)
* good changes from tensor_cores branch
* touchups
* real_strides fixup
* refactor merge_views
This commit is contained in:
@@ -143,16 +143,17 @@ class Linearizer:
|
||||
def shape_offsets(self, i):
  """Enumerate index tuples over the trailing dimensions of buffer `i`.

  The dimensions are `self.sts[i].shape` from position
  `self.shape_len - self.upcasted` onward, taken in reversed order, so the
  first element of each tuple indexes the last dimension.

  Returns an `itertools.product` iterator when `self.upcasted > 0`, otherwise
  a list holding one empty tuple.
  """
  if self.upcasted > 0:
    trailing_dims = self.sts[i].shape[self.shape_len-self.upcasted:]
    per_dim_indices = [list(range(dim)) for dim in trailing_dims[::-1]]
    return itertools.product(*per_dim_indices)
  return [tuple()]
|
||||
def float4_axis(self, i):
  """List the axes of buffer `i` that qualify for 4-wide grouping.

  An axis qualifies when it is reported by `unit_stride_axes()`, lies at or
  after position `self.shape_len - self.upcasted`, and has a size divisible
  by 4. Returned positions are relative to that starting position.
  """
  st = self.sts[i]
  first_upcast = self.shape_len - self.upcasted
  return [axis - first_upcast
          for axis in st.unit_stride_axes()
          if axis >= first_upcast and st.shape[axis] % 4 == 0]
|
||||
|
||||
# TODO: this stride is only on the last view, and may not be real
|
||||
def upcasted_axis(self, i):
|
||||
return list(zip(self.sts[i].shape[self.shape_len-self.upcasted:],
|
||||
self.sts[i].views[-1].strides[self.shape_len-self.upcasted:], # WRONG
|
||||
self.sts[i].real_strides()[self.shape_len-self.upcasted:],
|
||||
[x!=y for x,y in zip(self.sts[0].shape[self.shape_len-self.upcasted:], self.full_shape[self.shape_len-self.upcasted:])]))
|
||||
|
||||
# TODO: is there a better way to write this?
|
||||
def acc_offsets(self, i):
|
||||
if self.upcasted == 0: return [0]
|
||||
acc_strides = [x*(1-self.upcasted_axis(i)[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in self.upcasted_axis(i)[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(self.upcasted_axis(i)[::-1])])]
|
||||
upcasted_i = self.upcasted_axis(i)
|
||||
acc_strides = [x*(1-upcasted_i[::-1][i][2]) for i,x in enumerate(strides_for_shape(tuple(1 if r else s for s,_,r in upcasted_i[::-1])))]
|
||||
return [sum(t) for t in itertools.product(*[[y*acc_strides[i] for y in range(x[0])] for i,x in enumerate(upcasted_i[::-1])])]
|
||||
|
||||
def _group_float4(self, i, store_offset):
|
||||
store_offset_float4 = {}
|
||||
@@ -506,12 +507,12 @@ class Linearizer:
|
||||
# **** below this line need to be optional and benchmarked ****
|
||||
|
||||
# potentially do more upcasts of non reduce axes based on a heuristic
|
||||
upcasted_axis = set()
|
||||
while prod(self.sts[0].shape[:self.first_reduce]) >= 1024:
|
||||
xb_choices = []
|
||||
for axis, upcast_amount in itertools.product(range(self.first_reduce), [3,4]): # consider all the non reduce axes, and a 3 or 4 reduce
|
||||
# if it mods, and some buffer has stride 0 on axis while having no stride 0 in the buftoken
|
||||
# NOTE: this is using views[-1]
|
||||
if self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
|
||||
# if we haven't upcasted it, it mods, and some buffer has stride 0 on axis while having no stride 0 in the upcasted axis already
|
||||
if axis not in upcasted_axis and self.full_shape[axis]%upcast_amount == 0 and any(self.sts[buf_index].views[-1].strides[axis] == 0 and not any(x[1] == 0 for x in self.upcasted_axis(buf_index)) for buf_index in range(len(self.sts))):
|
||||
xb_choices.append((sum(st.views[-1].strides[axis]>0 for st in self.sts), sum(st.views[-1].strides[axis] for st in self.sts), axis, upcast_amount))
|
||||
if len(xb_choices):
|
||||
xb_choices = sorted(xb_choices)
|
||||
@@ -519,6 +520,7 @@ class Linearizer:
|
||||
self.shift_to(xb_choices[0][2], amount=xb_choices[0][3])
|
||||
self.upcast()
|
||||
self.simplify_ones()
|
||||
upcasted_axis.add(xb_choices[0][2])
|
||||
else:
|
||||
break
|
||||
|
||||
|
||||
@@ -28,7 +28,6 @@ class _CL:
|
||||
for q in self.cl_queue: q.finish()
|
||||
CL = _CL()
|
||||
|
||||
# TODO: merge CLImage in here
|
||||
class CLBuffer(RawBufferCopyInOut):
|
||||
def __init__(self, size, dtype, device='0'):
|
||||
assert not OSX or dtype != dtypes.float64, "OpenCL on Mac doesn't support float64"
|
||||
|
||||
@@ -4,17 +4,11 @@ from enum import Enum, auto
|
||||
import functools
|
||||
from typing import Dict, Tuple, Union, List, Optional, Callable, cast
|
||||
from tinygrad.helpers import prod, DEBUG
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node, SumNode, ModNode
|
||||
from tinygrad.shape.symbolic import Variable, MulNode, NumNode, Node
|
||||
|
||||
# these ops live here
|
||||
class MovementOps(Enum):
  # one member per movement op; values are assigned by auto() in declaration order
  RESHAPE = auto()
  PERMUTE = auto()
  EXPAND = auto()
  PAD = auto()
  SHRINK = auto()
  STRIDE = auto()
|
||||
|
||||
def check_no_mul(test, var):
  """Report whether `var` occurs in `test` without being multiplied.

  True when `test` equals `var`, when any term of a SumNode satisfies the
  check, or when `test` is a ModNode whose `b` is a multiple of 4 and its
  `a` operand satisfies the check. Every other node yields False.
  """
  if test == var: return True
  node_kind = test.__class__
  if node_kind is SumNode:
    # being one term of a sum is okay
    return any(check_no_mul(term, var) for term in test.nodes)
  if node_kind is ModNode and test.b%4 == 0:
    # a mod by a multiple of 4 can be looked through
    return check_no_mul(test.a, var)
  return False
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def to_shape_strides(shape:Tuple[int, ...], strides:Tuple[int, ...]) -> List[Tuple[int, int]]:
|
||||
assert len(shape) == len(strides)
|
||||
@@ -97,22 +91,10 @@ def view_from_shape(shape:Tuple[int, ...]) -> View:
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
if vm2.mask: return None # this isn't supported yet
|
||||
new_strides, new_offset = [], vm2.expr_node(Variable.num(vm1.offset))
|
||||
assert isinstance(new_offset, NumNode), "new_offset wasn't a number?!?"
|
||||
for s,st in zip(vm1.shape, vm1.strides):
|
||||
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
|
||||
if s == 1:
|
||||
new_strides.append(0) # all shape 1 can have stride 0
|
||||
elif this_dim.__class__ is NumNode and this_dim.b == 0:
|
||||
new_strides.append(0)
|
||||
elif this_dim.__class__ is Variable:
|
||||
new_strides.append(1)
|
||||
elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable:
|
||||
new_strides.append(this_dim.b)
|
||||
else:
|
||||
if DEBUG >= 4: print("can't simplify", s, this_dim.render())
|
||||
break
|
||||
return View(vm1.shape, tuple(new_strides), new_offset.b, vm1.mask) if len(new_strides) == len(vm1.strides) else None
|
||||
mst = ShapeTracker(vm1.shape, [vm2, vm1])
|
||||
strides = mst.real_strides()
|
||||
if None in strides: return None
|
||||
return View(vm1.shape, cast(Tuple[int, ...], strides), mst.real_offset(), vm1.mask)
|
||||
|
||||
@functools.lru_cache(maxsize=None)
|
||||
def _reshape(view: View, new_shape: Tuple[int, ...]) -> Tuple[View, bool]:
|
||||
@@ -168,15 +150,33 @@ class ShapeTracker:
|
||||
# this is the real size (ish)
|
||||
def size(self):
  """Multiply together the dimensions of the last view whose stride is non-zero."""
  last_view = self.views[-1]
  counted_dims = [dim for dim, stride in zip(last_view.shape, last_view.strides) if stride != 0]
  return prod(counted_dims)
|
||||
|
||||
def unit_stride_axes(self) -> List[int]:
|
||||
ret, acc = [], 1
|
||||
for j,s in reversed(list(enumerate(self.shape))):
|
||||
if s == 1: continue
|
||||
# these are multiview strides, value is None if it's not a simple strided dimension
|
||||
# TODO: this can be shared code between simplify and merge_views
|
||||
def real_offset(self) -> int:
  """Return the constant offset of this tracker's index expression.

  Evaluates `self.expr_node` on a Variable named 'zero' built with arguments
  (0, 0) (presumably its min and max, so the index is pinned to 0; confirm
  against Variable) and returns the `b` field of the resulting NumNode.
  Asserts if the expression does not collapse to a NumNode.
  """
  zero_idx = Variable('zero', 0, 0)
  real_offset, mask = self.expr_node(zero_idx)
  assert real_offset.__class__ is NumNode, f"how is the offset not a number? {real_offset} {mask}"
  return real_offset.b
|
||||
|
||||
def real_strides(self) -> Tuple[Optional[int], ...]:
|
||||
if len(self.views) == 1: return self.views[-1].strides
|
||||
ret: List[Optional[int]] = []
|
||||
acc, real_offset = 1, self.real_offset()
|
||||
for s in reversed(self.shape):
|
||||
if s == 1: # fast path, all shape 1 have stride 0
|
||||
ret.append(0)
|
||||
continue
|
||||
var = Variable('idx', 0, s-1)
|
||||
this_dim = self.expr_node(var*acc)
|
||||
this_dim, _ = self.expr_node(var*acc)
|
||||
this_dim -= real_offset
|
||||
acc *= s
|
||||
if check_no_mul(this_dim[0], var): ret.append(j)
|
||||
return ret
|
||||
# TODO: sometimes a mod here is okay if you are say, reading a float4, since you only care %4
|
||||
# if test.__class__ is ModNode and test.b%4 == 0: return check_no_mul(test.a, var) # removing a mod is okay
|
||||
if this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable: ret.append(this_dim.b)
|
||||
elif this_dim.__class__ is NumNode and this_dim.b == 0: ret.append(0)
|
||||
elif this_dim.__class__ is Variable: ret.append(1)
|
||||
else: ret.append(None)
|
||||
return tuple(ret[::-1])
|
||||
def unit_stride_axes(self) -> List[int]:
  """Return the positions in `self.real_strides()` whose stride equals 1."""
  unit_axes: List[int] = []
  for axis, stride in enumerate(self.real_strides()):
    if stride == 1:
      unit_axes.append(axis)
  return unit_axes
|
||||
|
||||
def _expr_idx(self, idx, valid):
|
||||
for v in reversed(self.views[0:-1]):
|
||||
|
||||
Reference in New Issue
Block a user