LLVM backend uses shapetracker
@@ -9,7 +9,7 @@ from tinygrad.shape import ShapeTracker
 from tinygrad.shape.symbolic import ModNode, DivNode, render_python # this will go away when VALIDHACKS does
 # div is different in cl than python
 render_cl = render_python.copy()
-render_cl[DivNode] = lambda self,ops: f"({self.a.render(ops)}/{self.b})"
+render_cl[DivNode] = lambda self,ops,ctx: f"({self.a.render(ops)}/{self.b})"
 from tinygrad.helpers import getenv

 CUDA = getenv("CUDA", 0)
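
The "div is different" note refers to integer division semantics: C-style "/" in OpenCL truncates toward zero while Python's "//" floors, and the two agree here only because these index nodes are non-negative. A minimal usage sketch, assuming render_cl from this file is in scope:

from tinygrad.shape.symbolic import Variable

# symbolic index; each node type dispatches to its lambda in the table
gidx = Variable("gid", 0, 1023) // 4
print(gidx.render(render_cl))  # -> "(gid/4)"; render_python would emit "(gid//4)"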
@@ -1,69 +1,31 @@
 from __future__ import annotations
 import math
 import functools
 from typing import Tuple, Union, Dict, Any, List, ClassVar, Optional
 from tinygrad.helpers import prod
-from tinygrad.shape import ShapeTracker, ZeroView
+from tinygrad.shape import ShapeTracker
 from tinygrad.ops import LazyOp
 from tinygrad.ast import ASTKernel
 import ctypes
 import numpy as np
 from tinygrad.ops import DEBUG, UnaryOps, BinaryOps, ReduceOps, ExplicitExecAST
 from tinygrad.runtime.llvm import LLVM, ir
+from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, GeNode, LtNode, SumNode, AndNode

 def int_const(x): return ir.Constant(ir.IntType(64), x)

-# this is only used on the crappy path
-def idx_deref(builder, buf, ptr, idx):
-  if DEBUG >= 1:
-    print("viewcount:", len(buf.st.views), buf.st.expr(), ptr, "on", buf.shape)
-  # TODO: unify this with expr in ShapeTracker
-  valid = None
-  for v in buf.st.views[0:-1][::-1]:
-    if isinstance(v, ZeroView):
-      if valid is None:
-        valid = ir.Constant(ir.IntType(1), 1)
-      acc = 1
-      for s,(x,y) in list(zip(v.old_shape, v.arg))[::-1]:
-        if x < 0 or y > s:
-          lr = idx
-          if acc != 1:
-            lr = builder.sdiv(lr, int_const(acc))
-          lr = builder.srem(lr, int_const(y-x))
-          if x != 0:
-            lr = builder.add(lr, int_const(x))
-          if x < 0:
-            valid = builder.and_(valid, builder.icmp_signed(">=", lr, int_const(0)))
-          if y > s:
-            valid = builder.and_(valid, builder.icmp_signed("<", lr, int_const(s)))
-          acc *= y-x
-    else:
-      acc = 1
-      ret = int_const(v.offset)
-      if DEBUG >= 2:
-        print(f"expanding index {v.shape_strides}")
-      for i,(d,s) in enumerate(v.shape_strides[::-1]):
-        if d != 1 and s != 0:
-          # slow path
-          lr = idx
-          if acc != 1:
-            lr = builder.sdiv(lr, int_const(acc))
-          if acc*d != prod(buf.shape):
-            lr = builder.srem(lr, int_const(d))
-          if s != 1:
-            lr = builder.mul(lr, int_const(s))
-          ret = builder.add(ret, lr)
-          acc *= d
-      idx = ret
-  if valid is not None:
-    # this always does the load, so we have it load *0 if the arg won't be used
-    # TODO: would control flow be faster?
-    aug_idx = builder.select(valid, idx, int_const(0))
-    return builder.select(valid, builder.load(builder.gep(ptr, [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
-  else:
-    return builder.load(builder.gep(ptr, [idx], inbounds=True))
+render_llvm = {
+  Variable: lambda self,ops,ctx: self.expr,
+  NumNode: lambda self,ops,ctx: int_const(self.b),
+  MulNode: lambda self,ops,ctx: ctx.mul(self.a.render(ops,ctx), int_const(self.b)),
+  DivNode: lambda self,ops,ctx: ctx.sdiv(self.a.render(ops,ctx), int_const(self.b)),
+  ModNode: lambda self,ops,ctx: ctx.srem(self.a.render(ops,ctx), int_const(self.b)),
+  GeNode: lambda self,ops,ctx: ctx.icmp_signed(">=", self.a.render(ops,ctx), int_const(self.b)),
+  LtNode: lambda self,ops,ctx: ctx.icmp_signed("<", self.a.render(ops,ctx), int_const(self.b)),
+  SumNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.add(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx)),
+  AndNode: lambda self,ops,ctx: functools.reduce(lambda a,b: ctx.and_(a,b.render(ops,ctx)), self.nodes[1:], self.nodes[0].render(ops,ctx))
+}

 # TODO: Refactor LLVMBuffer and GPUBuffer into ShapeTrackedBuffer
 class LLVMBuffer(ExplicitExecAST):
   op_lookup : ClassVar = {
     UnaryOps.NOOP: lambda builder,x: x,
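
The render_llvm table added here threads an ir.IRBuilder through render as the new ctx argument, so each symbolic node lowers itself directly to LLVM instructions. A hypothetical standalone sketch, assuming render_llvm and int_const from this file are importable, and using the same trick as the kernel code below where a raw LLVM value rides in the Variable's expr slot:

from llvmlite import ir
from tinygrad.shape.symbolic import Variable

mod = ir.Module(name="demo")
func = ir.Function(mod, ir.FunctionType(ir.IntType(64), [ir.IntType(64)]), name="demo_idx")
builder = ir.IRBuilder(func.append_basic_block(name="entry"))

i = Variable(func.args[0], 0, 1023)  # expr holds an LLVM value here, not a string
out = ((i // 4) % 8).render(render_llvm, builder)  # emits sdiv then srem via the builder
builder.ret(out)
print(mod)  # dump the generated IR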
@@ -213,9 +175,15 @@ class LLVMBuffer(ExplicitExecAST):
         m = kernel_output_type(ir.Undefined)
         buf_index = k.bufs.index(x)
         for i, idx in enumerate(get_idxs(builder, idx_level[buf_index][level], buf_index)):
           if len(x.st.views) > 1:
             if DEBUG >= 1: print(f"WARNING: {x} has buffers with more than 1 view, can't optimize")
-          element = idx_deref(builder, x, func.args[buf_index], idx)
+          # first view is already implicitly handled
+          idx, valid = x.st._expr_idx(Variable(idx, 0, prod(x.st.shape)))
+          idx = idx.render(render_llvm, builder)
+          if valid.min == 0:
+            valid = valid.render(render_llvm, builder)
+            # this always does the load, so we have it load *0 if the arg won't be used
+            # TODO: would control flow be faster?
+            aug_idx = builder.select(valid, idx, int_const(0))
+            element = builder.select(valid, builder.load(builder.gep(func.args[buf_index], [aug_idx], inbounds=True)), ir.Constant(ir.FloatType(), 0))
+          else:
+            element = builder.load(builder.gep(func.args[buf_index], [idx], inbounds=True))
           m = element if kernel_output_dim == 1 else builder.insert_element(m, element, int_const(i))
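
The fast path above asks the ShapeTracker for a symbolic index plus a validity node; when valid.min is 1 the mask is statically true and a plain load is emitted with no select. A rough sketch of that API, assuming ShapeTracker.permute takes the new axis order and _expr_idx behaves as called above:

from tinygrad.helpers import prod
from tinygrad.shape import ShapeTracker
from tinygrad.shape.symbolic import Variable

st = ShapeTracker((4, 8))
st.permute(1, 0)  # strided but fully in-bounds, so the mask should stay constant
idx, valid = st._expr_idx(Variable("i", 0, prod(st.shape)))
print(idx.render(), valid.render())  # index expression and validity mask as strings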
@@ -7,10 +7,10 @@ class Node:
   b: int
   min: int
   max: int
-  def render(self, ops=None):
+  def render(self, ops=None, ctx=None):
     if ops is None: ops = render_python
-    if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops)
-    return ops[type(self)](self, ops)
+    if self.min == self.max and type(self) != NumNode: return NumNode(self.min).render(ops, ctx)
+    return ops[type(self)](self, ops, ctx)
   def __add__(self, b:int): return Variable.sum([self, Variable.num(b)])
   def __mul__(self, b:int):
     if b == 0: return NumNode(0)
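
Since ctx defaults to None, existing string-rendering call sites are unchanged; only backends that need state, like the LLVM builder, pass one. The min == max shortcut still constant-folds before dispatching. A small sketch of both behaviors:

from tinygrad.shape.symbolic import Variable

idx = Variable("i", 0, 9) * 3 + 1
print(idx.render())  # -> ((i*3)+1); render_python's lambdas ignore ctx=None
print((Variable("j", 0, 9) // 16).render())  # -> 0, folded because min == max == 0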
@@ -119,13 +119,13 @@ class SumNode(RedNode): minmax = staticmethod(lambda nodes: (sum([x.min for x in
 class AndNode(RedNode): minmax = staticmethod(lambda nodes: (min([x.min for x in nodes]), max([x.max for x in nodes])))

 render_python : Dict[Type, Callable] = {
-  Variable: lambda self,ops: f"{self.expr}",
-  NumNode: lambda self,ops: f"{self.b}",
-  MulNode: lambda self,ops: f"({self.a.render(ops)}*{self.b})",
-  DivNode: lambda self,ops: f"({self.a.render(ops)}//{self.b})",
-  ModNode: lambda self,ops: f"({self.a.render(ops)}%{self.b})",
-  GeNode: lambda self,ops: f"({self.a.render(ops)}>={self.b})",
-  LtNode: lambda self,ops: f"({self.a.render(ops)}<{self.b})",
-  SumNode: lambda self,ops: f"({'+'.join([x.render(ops) for x in self.nodes])})",
-  AndNode: lambda self,ops: f"({'&&'.join([x.render(ops) for x in self.nodes])})"
+  Variable: lambda self,ops,ctx: f"{self.expr}",
+  NumNode: lambda self,ops,ctx: f"{self.b}",
+  MulNode: lambda self,ops,ctx: f"({self.a.render(ops)}*{self.b})",
+  DivNode: lambda self,ops,ctx: f"({self.a.render(ops)}//{self.b})",
+  ModNode: lambda self,ops,ctx: f"({self.a.render(ops)}%{self.b})",
+  GeNode: lambda self,ops,ctx: f"({self.a.render(ops)}>={self.b})",
+  LtNode: lambda self,ops,ctx: f"({self.a.render(ops)}<{self.b})",
+  SumNode: lambda self,ops,ctx: f"({'+'.join([x.render(ops) for x in self.nodes])})",
+  AndNode: lambda self,ops,ctx: f"({'&&'.join([x.render(ops) for x in self.nodes])})"
 }
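
Every entry now shares the (self, ops, ctx) signature even though the string renderers ignore ctx, which keeps the single dispatch line in Node.render valid for every backend table. A hedged sketch of expected output, assuming Variable.ands combines masks the way ShapeTracker does:

from tinygrad.shape.symbolic import Variable

i = Variable("i", 0, 31)
print((i % 4).render())  # -> (i%4)
print(Variable.ands([i < 20, i >= 4]).render())  # -> ((i<20)&&(i>=4))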