From af5d77f6844fe9d002e880df51e1073cbba33e9b Mon Sep 17 00:00:00 2001 From: chenyu Date: Fri, 22 Nov 2024 11:15:02 -0500 Subject: [PATCH] move sint_to_uop from view.py to ops.py [pr] (#7848) both sint and uop are in ops.py --- tinygrad/codegen/lowerer.py | 3 +-- tinygrad/ops.py | 2 ++ tinygrad/shape/view.py | 4 +--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/tinygrad/codegen/lowerer.py b/tinygrad/codegen/lowerer.py index d0e5db725a..12ef528ebe 100644 --- a/tinygrad/codegen/lowerer.py +++ b/tinygrad/codegen/lowerer.py @@ -4,9 +4,8 @@ import functools, itertools, operator from dataclasses import dataclass from typing import List, Tuple, cast, Optional from tinygrad.shape.shapetracker import ShapeTracker -from tinygrad.shape.view import sint_to_uop from tinygrad.dtype import dtypes, PtrDType -from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element +from tinygrad.ops import KernelInfo, UOp, Ops, graph_rewrite, PatternMatcher, UPat, sint, identity_element, sint_to_uop from tinygrad.renderer import Renderer from tinygrad.helpers import all_int, prod, partition, flatten diff --git a/tinygrad/ops.py b/tinygrad/ops.py index e5d539f25f..fb61289615 100644 --- a/tinygrad/ops.py +++ b/tinygrad/ops.py @@ -1036,6 +1036,8 @@ def max_var_const(x:UOp, c1:UOp, c2:UOp): if x.vmin >= 0: return x*c1 if c1.arg >= c2.arg else x*c2 if x.vmax <= 0: return x*c2 if c1.arg >= c2.arg else x*c1 +def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x + symbolic_simple = PatternMatcher([ # ** self folding ** (UPat.var("x") + 0, lambda x: x), # x+0 -> x diff --git a/tinygrad/shape/view.py b/tinygrad/shape/view.py index 1034147548..301061421a 100644 --- a/tinygrad/shape/view.py +++ b/tinygrad/shape/view.py @@ -3,7 +3,7 @@ import functools, operator, itertools, math from dataclasses import dataclass from typing import Tuple, List, Optional, Dict, Set, cast from tinygrad.dtype import dtypes -from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin +from tinygrad.ops import resolve, UOp, Variable, sint, sym_infer, smax, smin, sint_to_uop from tinygrad.helpers import prod, all_int, argsort, flatten, ceildiv @functools.lru_cache(maxsize=None) @@ -81,8 +81,6 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]: offs -= here * stride return result -def sint_to_uop(x:sint) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x - @dataclass(frozen=True) class View: shape:Tuple[sint, ...]