mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
move symbolic and transcendental to uop [pr] (#10771)
This commit is contained in:
@@ -28,7 +28,7 @@ repos:
|
||||
pass_filenames: false
|
||||
- id: tests
|
||||
name: subset of tests
|
||||
entry: env MAX_BUFFER_SIZE=300000000 PYTHONPATH="." python3 -m pytest -n=4 --ignore=test/unit/test_keccak.py --ignore=test/unit/test_indexing.py test/unit/ test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py test/test_symbolic_shapetracker.py
|
||||
entry: env PYTHONPATH="." python3 -m pytest -n=4 test/test_ops.py test/test_dtype.py test/test_schedule.py test/test_assign.py
|
||||
language: system
|
||||
always_run: true
|
||||
pass_filenames: false
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import DType, ImageDType
|
||||
from tinygrad.shape.shapetracker import ShapeTracker
|
||||
from tinygrad.uop.ops import PatternMatcher, UOp, Ops, GroupOp, UPat, graph_rewrite, track_rewrites
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, SPLIT_REDUCEOP, GlobalCounters, Context, getenv, all_same, temp
|
||||
from tinygrad.engine.grouper import view_left, view_right, sym, get_kernelize_map, Kernel, create_ast, merge_views, create_kernels
|
||||
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, pytest
|
||||
from tinygrad import dtypes, Variable
|
||||
from tinygrad.helpers import DEBUG, Context
|
||||
from tinygrad.uop.ops import Ops, UOp, UPat, PatternMatcher, track_rewrites, graph_rewrite, GroupOp
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.codegen import full_rewrite, full_rewrite_to_sink
|
||||
from tinygrad.codegen.expander import expander
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ from tinygrad.renderer import ProgramSpec
|
||||
from tinygrad.engine.grouper import fix_kernel_ops
|
||||
from tinygrad.engine.realize import CompiledRunner, get_kernel
|
||||
from tinygrad.codegen import full_rewrite
|
||||
from tinygrad.codegen.symbolic import sym
|
||||
from tinygrad.uop.symbolic import sym
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.codegen.kernel import Kernel, Opt, OptOps
|
||||
|
||||
|
||||
@@ -204,7 +204,7 @@ class TestGEPAndVectorizeRewrite(unittest.TestCase):
|
||||
|
||||
import inspect
|
||||
from tinygrad.uop.ops import graph_rewrite, _substitute, track_rewrites
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
|
||||
class TestBottomUpRewrite(unittest.TestCase):
|
||||
def test_const_folding(self):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp, graph_rewrite_map, _substitute
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
|
||||
class TestRewriteMap(unittest.TestCase):
|
||||
def test_substitute(self):
|
||||
|
||||
@@ -3,7 +3,7 @@ import unittest, itertools
|
||||
from tinygrad.codegen import full_rewrite_to_sink
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.codegen.symbolic import simplify_valid
|
||||
from tinygrad.uop.symbolic import simplify_valid
|
||||
|
||||
def get_gated_load_uop(valid:UOp, idx:UOp):
|
||||
return UOp(Ops.LOAD, dtypes.float, (
|
||||
|
||||
@@ -2,8 +2,8 @@ import unittest, math
|
||||
import numpy as np
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops
|
||||
from tinygrad.codegen.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction
|
||||
from tinygrad.codegen.transcendental import frexp, rintk, xpow, xexp2, xlog2, trig_poly, pow2if
|
||||
from tinygrad.uop.transcendental import TRANSCENDENTAL_SUPPORTED_DTYPES, payne_hanek_reduction, cody_waite_reduction
|
||||
from tinygrad.uop.transcendental import frexp, rintk, xpow, xexp2, xlog2, trig_poly, pow2if
|
||||
from test.helpers import eval_uop
|
||||
|
||||
class TestTranscendentalFunctions(unittest.TestCase):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import unittest, decimal, json
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import TRACK_MATCH_STATS, TrackedPatternMatcher, UOp, graph_rewrite, track_rewrites, UPat, Ops
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.uop.ops import tracked_ctxs as contexts, tracked_keys as keys, _name_cnt, _substitute
|
||||
from tinygrad.device import ProfileDeviceEvent, ProfileRangeEvent, ProfileGraphEvent, ProfileGraphEntry
|
||||
from tinygrad.viz.serve import get_metadata, get_details, uop_to_json, to_perfetto
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.renderer import Renderer
|
||||
|
||||
# import all pattern matchers here
|
||||
from tinygrad.codegen.lowerer import pm_quant, pm_lowerer, get_index
|
||||
from tinygrad.codegen.symbolic import sym, symbolic_simple, gep_pushing
|
||||
from tinygrad.uop.symbolic import sym, symbolic_simple, gep_pushing
|
||||
from tinygrad.codegen.expander import migrate_indexing, pm_store_ignore, pm_move_ignore, pm_delete_ignore, expander
|
||||
from tinygrad.codegen.devectorizer import load_store_folding, load_store_indexing, devectorize, \
|
||||
pm_reduce, ReduceContext, correct_load_store, pm_render, get_late_rewrite_patterns
|
||||
|
||||
@@ -5,9 +5,9 @@ from dataclasses import dataclass
|
||||
from tinygrad.device import is_dtype_supported
|
||||
from tinygrad.dtype import dtypes, ImageDType, PtrDType, promo_lattice, DType
|
||||
from tinygrad.uop.ops import UOp, Ops, UPat, PatternMatcher, resolve, graph_rewrite, GroupOp, identity_element
|
||||
from tinygrad.codegen.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.uop.symbolic import split_uop, uop_given_valid, parse_valid, simplify_valid, sym, symbolic_flat
|
||||
from tinygrad.helpers import getenv, flatten, AMX, prod, partition
|
||||
from tinygrad.codegen.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.uop.transcendental import xexp2, xlog2, xsin, xpow, TRANSCENDENTAL_SUPPORTED_DTYPES
|
||||
from tinygrad.renderer import Renderer
|
||||
|
||||
# ***** image load valid simplification *****
|
||||
|
||||
@@ -6,7 +6,7 @@ from tinygrad.dtype import dtypes, PtrDType, least_upper_dtype
|
||||
from tinygrad.uop.ops import KernelInfo, UOp, Ops, PatternMatcher, UPat, sint, sint_to_uop
|
||||
from tinygrad.renderer import Renderer
|
||||
from tinygrad.helpers import all_int, prod, partition, flatten, unwrap
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
|
||||
# returns the axes to create new_shape if new_shape can be created by combining axis from old_shape
|
||||
def get_contraction(old_shape:tuple[sint, ...], new_shape:tuple[sint, ...]) -> list[list[int]]|None:
|
||||
|
||||
@@ -3,7 +3,7 @@ from tinygrad.uop.ops import UOp, Ops, GroupOp, PatternMatcher, UPat, graph_rewr
|
||||
from tinygrad.uop.ops import track_rewrites, _substitute
|
||||
from tinygrad.uop.spec import type_verify, tensor_uop_spec
|
||||
from tinygrad.codegen.lowerer import get_contraction_with_reduce
|
||||
from tinygrad.codegen.symbolic import symbolic_simple
|
||||
from tinygrad.uop.symbolic import symbolic_simple
|
||||
from tinygrad.helpers import Metadata, all_int, all_same, colored, prod, dedup, unwrap, getenv, pluralize
|
||||
from tinygrad.helpers import FUSE_CONV_BW, FUSE_ARANGE, DEBUG, DONT_REALIZE_EXPAND, DONT_GROUP_REDUCES, SPLIT_REDUCEOP
|
||||
from tinygrad.dtype import ImageDType
|
||||
|
||||
@@ -7,7 +7,7 @@ from tinygrad.helpers import merge_dicts, getenv
|
||||
from tinygrad.shape.view import View, strides_for_shape, unravel
|
||||
from tinygrad.dtype import dtypes
|
||||
from tinygrad.uop.ops import UOp, Ops, graph_rewrite, Variable, sint, sint_to_uop, Context, PatternMatcher, UPat, GroupOp
|
||||
from tinygrad.codegen.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
||||
from tinygrad.uop.symbolic import split_uop, symbolic_flat, uop_given_valid, simplify_valid
|
||||
|
||||
# If a node overflow, its srcs need to be checked to see if this overflow is the result of an ALU operation,
|
||||
# or that the node simply inherits the dtype from srcs. Upcast is either `Ops.CAST`+`replace` or just `replace`.
|
||||
|
||||
@@ -171,7 +171,7 @@ class UOp(MathTrait, metaclass=UOpMetaClass):
|
||||
|
||||
def simplify(self):
|
||||
# late import!
|
||||
from tinygrad.codegen.symbolic import symbolic
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
with Context(TRACK_MATCH_STATS=0):
|
||||
return graph_rewrite(self, symbolic)
|
||||
def ssimplify(self) -> Union[UOp, ConstType]: return ret.arg if (ret:=self.simplify()).op is Ops.CONST else ret
|
||||
|
||||
@@ -5,7 +5,7 @@ from collections import defaultdict
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat, UOp, GroupOp, exec_alu
|
||||
from tinygrad.dtype import ConstType, dtypes, PtrDType
|
||||
from tinygrad.helpers import partition, all_same, prod, flatten, get_single_element, cdiv, cmod, CORRECT_DIVMOD_FOLDING
|
||||
from tinygrad.codegen.transcendental import xpow
|
||||
from tinygrad.uop.transcendental import xpow
|
||||
|
||||
# ******** phase 1 of symbolic used to live in ops, it's the most generic folding rules ********
|
||||
|
||||
Reference in New Issue
Block a user