mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
This reverts commit 3b777a9e05.
This commit is contained in:
@@ -892,13 +892,13 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
|
||||
@unittest.skipUnless(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_overflow_sym(self):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
|
||||
def test_regular(self):
|
||||
self.do_op_then_assert(dtypes.int, 64, 64, 64)
|
||||
|
||||
def test_regular_sym(self):
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 1, 64).bind(32))
|
||||
self.do_op_then_assert(dtypes.int, 2048, 2048, UOp.variable("dim3", 0, 64).bind(32))
|
||||
|
||||
@unittest.skipIf(PTX, "PTX always convert Ops.INDEX to int64")
|
||||
def test_symfold(self):
|
||||
@@ -910,7 +910,7 @@ class TestIdxUpcast(unittest.TestCase):
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow_sym(self):
|
||||
with self.assertRaises(KeyError):
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 1, 2048).bind(32))
|
||||
self.do_op_then_assert(dtypes.long, 2048, 2048, UOp.variable("dim3", 0, 2048).bind(32))
|
||||
|
||||
@unittest.skipIf(is_dtype_supported(dtypes.long), "int64 is supported")
|
||||
def test_int64_unsupported_overflow(self):
|
||||
|
||||
@@ -16,7 +16,6 @@ from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexin
|
||||
ReduceContext, correct_load_store, pm_render
|
||||
from tinygrad.codegen.optional import get_late_rewrite_patterns
|
||||
from tinygrad.codegen.linearize import block_create, pm_blockend_merge, block_merge, pm_finalize, BlockContext
|
||||
from tinygrad.opt import pm_optimize
|
||||
|
||||
@dataclass
|
||||
class RewriteStep:
|
||||
@@ -43,10 +42,6 @@ def get_rewrites_for_renderer(opts:Renderer, linearizer:bool=True) -> list[Rewri
|
||||
def _get_rewrites_for_renderer(opts:Renderer, linearizer:bool, _QUANTIZE, _DEVECTORIZE, _TRANSCENDENTAL) -> list[RewriteStep]:
|
||||
# ** lowerer (rewrite_shapetracker_with_index) **
|
||||
ret: list[RewriteStep] = []
|
||||
|
||||
# this is kernel.py
|
||||
ret.append(RewriteStep(pm_optimize, ctx=lambda _: opts, name="optimize ast"))
|
||||
|
||||
if _QUANTIZE and opts.device in {"CPU", "DSP"}: ret.append(RewriteStep(pm_quant, name="quantize"))
|
||||
ret.append(RewriteStep(pm_lowerer, get_index, name="lowerer", bottom_up=True))
|
||||
|
||||
|
||||
@@ -7,7 +7,9 @@ from tinygrad.uop.ops import Ops, PatternMatcher, UOp, UPat, Variable, sym_infer
|
||||
from tinygrad.device import Device, Buffer
|
||||
from tinygrad.renderer import Renderer, ProgramSpec, Estimates
|
||||
from tinygrad.engine.schedule import ScheduleItem
|
||||
from tinygrad.opt import get_optimized_ast
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
# **************** Program Creation ****************
|
||||
|
||||
@@ -25,13 +27,16 @@ def get_program(ast:UOp, renderer:Renderer) -> ProgramSpec:
|
||||
"""
|
||||
|
||||
if getenv("VIZ"): graph_rewrite(ast, PatternMatcher([]), name="View Base AST")
|
||||
modified_ast = get_optimized_ast(ast, renderer) if ast.arg is None or ast.arg.opts_to_apply is not None else ast
|
||||
if __debug__: type_verify(list(modified_ast.toposort()))
|
||||
|
||||
# linearize
|
||||
try:
|
||||
uops = full_rewrite(ast, renderer)
|
||||
uops = full_rewrite(modified_ast, renderer)
|
||||
except RuntimeError:
|
||||
print("***** LINEARIZE FAILURE *****")
|
||||
print(f"ast = {ast}")
|
||||
print(f"opts = {modified_ast.arg.applied_opts}")
|
||||
raise
|
||||
assert uops[-1].op is Ops.SINK, "last uop must be sink"
|
||||
|
||||
|
||||
@@ -2,10 +2,9 @@
|
||||
|
||||
from tinygrad.opt.kernel import Kernel
|
||||
from tinygrad.opt.heuristic import hand_coded_optimizations
|
||||
from tinygrad.uop.ops import UOp, PatternMatcher, UPat, Ops
|
||||
from tinygrad.uop.ops import UOp
|
||||
from tinygrad.helpers import NOOPT, BEAM, USE_TC, getenv
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.uop.spec import type_verify
|
||||
|
||||
def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
|
||||
"""
|
||||
@@ -28,11 +27,4 @@ def get_optimized_ast(ast:UOp, renderer:Renderer) -> UOp:
|
||||
kb = Kernel(ast, opts=renderer)
|
||||
rawbufs = bufs_from_lin(kb, allocate=False)
|
||||
k = beam_search(kb, rawbufs, BEAM.value, bool(getenv("BEAM_ESTIMATE", 1)))
|
||||
ret = k.get_optimized_ast()
|
||||
if __debug__: type_verify(list(ret.toposort()))
|
||||
return ret
|
||||
|
||||
pm_optimize = PatternMatcher([
|
||||
(UPat(Ops.SINK, name="ast"), lambda ctx,ast:
|
||||
get_optimized_ast(ast, ctx) if (ast.arg is None or ast.arg.opts_to_apply is not None) and ast.src[0].st is not None else None),
|
||||
])
|
||||
return k.get_optimized_ast()
|
||||
|
||||
@@ -440,7 +440,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
all_vars = set([x for x in self.toposort() if x.op is Ops.DEFINE_VAR])
|
||||
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
|
||||
def variables(self) -> list[Variable]:
|
||||
st_vars: list[set[Variable]] = [x.arg.vars() for x in self.toposort() if x.op is Ops.VIEW]
|
||||
st_vars: list[set[Variable]] = [x.st_arg.vars() for x in self.toposort() if x.op in GroupOp.Buffer]
|
||||
return sorted(set.union(*st_vars, set([x.unbind()[0] if x.op is not Ops.DEFINE_VAR else x for x in self.vars()])), key=lambda v: v.arg)
|
||||
|
||||
# *** uop symbolic stuff ***
|
||||
|
||||
Reference in New Issue
Block a user