compact the global dimensions using the shapetracker (#897)

This commit is contained in:
George Hotz
2023-06-01 13:09:54 -07:00
committed by GitHub
parent ef129bcb85
commit dd41f3ee40

View File

@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
import math, collections
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored, prod
from tinygrad.runtime.lib import RawConst
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
from tinygrad.lazy import LazyBuffer
@@ -71,7 +71,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
for uop,newvar,vin,args in uops:
if uop == UOps.LOOP:
root = None
for i,var in enumerate(args[0]):
if isinstance(var, NumNode):
if args[1] == "global" and lang.gid: global_size.append(1)
@@ -80,19 +79,9 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
kk("{")
else:
if args[1] == "global" and lang.gid:
if len(args[0]) >= 4 and len(args[0])-i > 2:
# sometimes, there's more dimensions. compact all the dimensions into the last CL dimension
# TODO: these compactions should be searchable (they sort of are with reshapes and permutes)
if i == 0:
kk(f"{{ int {var.expr} = {lang.gid[-1]}; /* {var.max+1} */")
root = var.expr
global_size.append(var.max+1)
else:
kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};")
global_size[-1] *= var.max+1
else:
kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
global_size.append(var.max+1)
assert len(args[0]) <= len(lang.gid), f"too many global dimensions, has {len(args[0])} and {len(lang.gid)} are supported"
kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
global_size.append(var.max+1)
elif args[1] == "local" and lang.lid:
assert len(args[0]) <= len(lang.lid)
kk(f"{{ int {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
@@ -203,13 +192,10 @@ class CStyleCodegen(Linearizer):
# sometimes, there's more dimensions than len(self.lang.gid).
# compact all the dimensions into the first
# NOTE: this might make multiview shapetrackers
# TODO: this exposes bugs in the optimizers assuming the strides are on a single view
"""
if len(self.lang.gid) and self.first_reduce > len(self.lang.gid):
num_to_merge = (self.first_reduce - len(self.lang.gid))+1
self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
"""
self.linearize()