mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
compact the global dimensions using the shapetracker (#897)
This commit is contained in:
@@ -2,7 +2,7 @@ from typing import Final, Dict, Callable, ClassVar, List, Optional, NamedTuple,
|
||||
import math, collections
|
||||
from tinygrad.codegen.linearizer import Linearizer, UOps, UOp, LocalBuffer, LocalTypes
|
||||
from tinygrad.ops import ASTRunner, Op, UnaryOps, BinaryOps, FusedOps
|
||||
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored
|
||||
from tinygrad.helpers import getenv, partition, ImageDType, DEBUG, dtypes, colored, prod
|
||||
from tinygrad.runtime.lib import RawConst
|
||||
from tinygrad.shape.symbolic import DivNode, AndNode, render_python, NumNode, Variable, Node, SumNode, MulNode
|
||||
from tinygrad.lazy import LazyBuffer
|
||||
@@ -71,7 +71,6 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
|
||||
for uop,newvar,vin,args in uops:
|
||||
if uop == UOps.LOOP:
|
||||
root = None
|
||||
for i,var in enumerate(args[0]):
|
||||
if isinstance(var, NumNode):
|
||||
if args[1] == "global" and lang.gid: global_size.append(1)
|
||||
@@ -80,19 +79,9 @@ def uops_to_cstyle(uops:List[UOp], bufs:List[Union[LocalBuffer,LazyBuffer]], lan
|
||||
kk("{")
|
||||
else:
|
||||
if args[1] == "global" and lang.gid:
|
||||
if len(args[0]) >= 4 and len(args[0])-i > 2:
|
||||
# sometimes, there are more dimensions. compact all the dimensions into the last CL dimension
|
||||
# TODO: these compactions should be searchable (they sort of are with reshapes and permutes)
|
||||
if i == 0:
|
||||
kk(f"{{ int {var.expr} = {lang.gid[-1]}; /* {var.max+1} */")
|
||||
root = var.expr
|
||||
global_size.append(var.max+1)
|
||||
else:
|
||||
kk(f"{{ int {var.expr} = {root} % {var.max+1}; {root} /= {var.max+1};")
|
||||
global_size[-1] *= var.max+1
|
||||
else:
|
||||
kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
|
||||
global_size.append(var.max+1)
|
||||
assert len(args[0]) <= len(lang.gid), f"too many global dimensions, has {len(args[0])} and {len(lang.gid)} are supported"
|
||||
kk(f"{{ int {var.expr} = {lang.gid[len(args[0])-1-i]}; /* {var.max+1} */")
|
||||
global_size.append(var.max+1)
|
||||
elif args[1] == "local" and lang.lid:
|
||||
assert len(args[0]) <= len(lang.lid)
|
||||
kk(f"{{ int {var.expr} = {lang.lid[len(args[0])-1-i]}; /* {var.max+1} */")
|
||||
@@ -203,13 +192,10 @@ class CStyleCodegen(Linearizer):
|
||||
# sometimes, there are more dimensions than len(self.lang.gid).
|
||||
# compact all the dimensions into the first
|
||||
# NOTE: this might make multiview shapetrackers
|
||||
# TODO: this exposes bugs in the optimizers assuming the strides are on a single view
|
||||
"""
|
||||
if len(self.lang.gid) and self.first_reduce > len(self.lang.gid):
|
||||
num_to_merge = (self.first_reduce - len(self.lang.gid))+1
|
||||
self.reshape_and_permute(lambda x: (prod(x[0:num_to_merge]),)+x[num_to_merge:], None)
|
||||
if DEBUG >= 4: print("reshaped to", self.full_shape, "due to too many global dimensions")
|
||||
"""
|
||||
|
||||
self.linearize()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user