LLVM backend uses shapetracker

This commit is contained in:
George Hotz
2023-02-10 13:53:33 -06:00
parent c3cf17c6d0
commit 87a7717222
3 changed files with 36 additions and 68 deletions

View File

@@ -9,7 +9,7 @@ from tinygrad.shape import ShapeTracker
from tinygrad.shape.symbolic import ModNode, DivNode, render_python # this will go away when VALIDHACKS does
# div is different in cl than python
render_cl = render_python.copy()
# OpenCL "/" truncates toward zero like C, unlike Python's floor division,
# so override the DivNode renderer with a C-style "/" form.  Renderers now
# take (self, ops, ctx); the stale pre-commit 2-arg duplicate is removed.
render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
from tinygrad.helpers import getenv
CUDA = getenv("CUDA", 0)

View File

@@ -1,69 +1,31 @@
from __future__ import annotations
import math
import functools
from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
from tinygrad.helpers import prod
from tinygrad.shape import ShapeTracker, ZeroView
from tinygrad.shape import ShapeTracker
from tinygrad.ops import LazyOp
from tinygrad.ast import ASTKernel
import ctypes
import numpy as np
from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST
from tinygrad.runtime.llvm import LLVM, ir
from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode
def int_const(x):
  """Wrap the Python int *x* in an LLVM 64-bit integer constant."""
  return ir.Constant(ir.IntType(64), x)
# this is only used on the crappy path
def idx_deref(builder, buf, ptr, idx):
  """Emit LLVM IR that loads the element of `ptr` addressed by logical index `idx`.

  Rewrites `idx` through every view in buf's ShapeTracker except the last
  (iterated in reverse list order; the skipped view is presumably already
  handled by the caller's index computation -- TODO confirm).  ZeroView
  layers contribute an i1 validity mask instead of an index rewrite; when
  the mask is false the load is redirected to element 0 and the result
  forced to 0.0.

  builder: llvmlite IRBuilder positioned where the IR should be emitted
  buf:     buffer whose .st (ShapeTracker) describes the view stack
  ptr:     LLVM pointer argument to load from
  idx:     LLVM i64 index into the buffer's outer view
  """
  if DEBUG >= 1:
    print("viewcount:", len(buf.st.views), buf.st.expr(), ptr, "on", buf.shape)
  # TODO: unify this with expr in ShapeTracker
  valid = None  # i1 validity mask; only materialized once a ZeroView is seen
  for v in buf.st.views[0:-1][::-1]:
    if isinstance(v, ZeroView):
      if valid is None:
        valid = ir.Constant(ir.IntType(1), 1)
      # acc tracks the product of processed (innermost-first) dim sizes,
      # so idx/acc%size recovers each dimension's coordinate
      acc = 1
      for s,(x,y) in list(zip(v.old_shape, v.arg))[::-1]:
        if x < 0 or y > s:
          # recover this dimension's coordinate from the flat index
          lr = idx
          if acc != 1:
            lr = builder.sdiv(lr, int_const(acc))
          lr = builder.srem(lr, int_const(y-x))
          if x != 0:
            lr = builder.add(lr, int_const(x))
          # coordinates falling in the zero padding invalidate the load
          if x < 0:
            valid = builder.and_(valid, builder.icmp_signed(">=", lr, int_const(0)))
          if y > s:
            valid = builder.and_(valid, builder.icmp_signed("<", lr, int_const(s)))
        acc *= y-x
    else:
      # plain view: rebuild the flat index as offset + sum(coord * stride)
      acc = 1
      ret = int_const(v.offset)
      if DEBUG >= 2:
        print(f"expanding index {v.shape_strides}")
      for i,(d,s) in enumerate(v.shape_strides[::-1]):
        if d != 1 and s != 0:
          # slow path
          lr = idx
          if acc != 1:
            lr = builder.sdiv(lr, int_const(acc))
          if acc*d != prod(buf.shape):
            lr = builder.srem(lr, int_const(d))
          if s != 1:
            lr = builder.mul(lr, int_const(s))
          ret = builder.add(ret, lr)
        acc *= d
      idx = ret
  if valid is not None:
    # this always does the load, so we have it load *0 if the arg won't be used
    # TODO: would control flow be faster?
    aug_idx = builder.select(valid, idx, int_const(0))
    return builder.select(valid, builder.load(builder.gep(ptr, [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
  else:
    return builder.load(builder.gep(ptr, [idx], inbounds=True))
def _llvm_fold(opname):
  # Build a renderer that left-folds a RedNode's children with the named
  # IRBuilder binary op (e.g. "add" for SumNode, "and_" for AndNode).
  def _render(self, ops, ctx):
    out = self.nodes[0].render(ops, ctx)
    for node in self.nodes[1:]:
      out = getattr(ctx, opname)(out, node.render(ops, ctx))
    return out
  return _render

# Renderers that lower symbolic nodes to LLVM IR values.
# Each takes (node, op table, IRBuilder) per the Node.render protocol.
render_llvm = {
  Variable: lambda self,ops,ctx: self.expr,
  NumNode: lambda self,ops,ctx: int_const(self.b),
  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
  GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)),
  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
  SumNode: _llvm_fold("add"),
  AndNode: _llvm_fold("and_")
}
# TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
class LLVMBuffer(ExplicitExecAST):
op_lookup : ClassVar = {
UnaryOps.NOOP: lambda builder,x: x,
@@ -213,9 +175,15 @@ class LLVMBuffer(ExplicitExecAST):
m = kernel_output_type(ir.Undefined)
buf_index = k.bufs.index(x)
for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)):
if len(x.st.views) > 1:
if DEBUG >= 1: print(f"WARNING: {x} has buffers with more than 1 view, can't optimize")
element = idx_deref(builder, x, func.args[buf_index], idx)
# first view is already implictly handled
idx, valid = x.st._expr_idx(Variable(idx, 0, prod(x.st.shape)))
idx = idx.render(render_llvm, builder)
if valid.min == 0:
valid = valid.render(render_llvm, builder)
# this always does the load, so we have it load *0 if the arg won't be used
# TODO: would control flow be faster?
aug_idx = builder.select(valid, idx, int_const(0))
element = builder.select(valid, builder.load(builder.gep(func.args[buf_index], [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
else:
element = builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True))
m = element if kernel_output_dim == 1 else builder.insert_element(m, element, int_const(i))

View File

@@ -7,10 +7,10 @@ class Node:
b: int
min: int
max: int
def render(self, ops=None):
def render(self, ops=None, ctx=None):
if ops is None: ops = render_python
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops)
return ops[type(self)](self, ops)
if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
return ops[type(self)](self, ops, ctx)
def __add__(self, b:int): return Variable.sum([self, Variable.num(b)])
def __mul__(self, b:int):
if b == 0: return NumNode(0)
@@ -119,13 +119,13 @@ class SumNode(RedNode): minmax = staticmethod(lambda nodes: (sum([x.min for x in
class AndNode(RedNode): minmax = staticmethod(lambda nodes: (min([x.min for x in nodes]), max([x.max for x in nodes])))
# Default renderers: lower symbolic nodes to Python-syntax expression strings.
# Each takes (node, op table, ctx) per the Node.render protocol; ctx is
# unused here and exists only to match the 3-arg renderer signature.
# (Extraction had merged the pre- and post-commit dict entries into one
# literal; only the post-commit 3-arg entries are kept.)
render_python : Dict[Type, Callable] = {
  Variable: lambda self,ops,ctx: f"{self.expr}",
  NumNode: lambda self,ops,ctx: f"{self.b}",
  MulNode: lambda self,ops,ctx: f"({self.a.render(ops)}*{self.b})",
  DivNode: lambda self,ops,ctx: f"({self.a.render(ops)}//{self.b})",
  ModNode: lambda self,ops,ctx: f"({self.a.render(ops)}%{self.b})",
  GeNode: lambda self,ops,ctx: f"({self.a.render(ops)}>={self.b})",
  LtNode: lambda self,ops,ctx: f"({self.a.render(ops)}<{self.b})",
  SumNode: lambda self,ops,ctx: f"({'+'.join([x.render(ops) for x in self.nodes])})",
  AndNode: lambda self,ops,ctx: f"({'&&'.join([x.render(ops) for x in self.nodes])})"
}