mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-07 03:00:26 -04:00
restructure simplify_valid_image_load [run_process_replay] (#6581)
* restructure simplify_valid_image_load [run_process_replay] separated parsing valid / idx and simplification * space * type
This commit is contained in:
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Optional, Tuple, Dict, List, Set, cast, TYPE_CHECKING, Any, DefaultDict, Callable
|
||||
import functools, itertools, heapq, math, operator
|
||||
from collections import defaultdict
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType
|
||||
from tinygrad.dtype import dtypes, PtrDType, ImageDType, ConstType
|
||||
from tinygrad.ops import UnaryOps, BinaryOps, exec_alu, UOp, UOps, END_FOR_UOP, type_verify, print_uops, identity_element
|
||||
from tinygrad.ops import UPat, PatternMatcher, graph_rewrite
|
||||
from tinygrad.helpers import DEBUG, getenv, flatten, dedup, TRANSCENDENTAL, AMX, prod, CI, partition, all_same
|
||||
@@ -174,41 +174,51 @@ def simplify_valid_image_load(load:UOp, buf:UOp):
|
||||
new_valid = functools.reduce(operator.and_, ss) if (ss:=[s for s in _get_chain(valid, BinaryOps.AND) if s is not stmt]) else None
|
||||
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, new_valid)) if new_valid else UOp(UOps.LOAD, load.dtype, (buf, idx))
|
||||
|
||||
# We want to simplify expressions like (X*c+d)%m in the idx, with optional *c and +d. m is the total length of the row.
|
||||
# If the contraints in valid implies that it "spans" the whole row, and we can rewrite it to X*c+k for some k, and drop the valid.
|
||||
# first, parse valid into {expr: ((lower_bound, statement), (upper_bound, statement))}
|
||||
bounds:DefaultDict[UOp, List[Optional[Tuple[ConstType, UOp]]]] = defaultdict(lambda: [None, None])
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST:
|
||||
if (s:=stmt.src[0]).op is UOps.ALU and s.arg is BinaryOps.MUL and s.src[1].op is UOps.CONST and s.src[1].arg == -1:
|
||||
bounds[s.src[0]][0] = (-stmt.src[1].arg+1, stmt)
|
||||
else: bounds[s][1] = (stmt.src[1].arg-1, stmt)
|
||||
|
||||
for v in bounds.values():
|
||||
# some expr has lower bound > upper bound -> valid is an empty set
|
||||
if v[0] is not None and v[1] is not None and v[0][0] > v[1][0]:
|
||||
return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
|
||||
|
||||
# next, parse idx by the form ((X*c+d)%m, ((X*c+d)//m+e))
|
||||
# parse m
|
||||
m = mod.src[1].arg if (mod:=idx.src[0]).op is UOps.ALU and mod.arg is BinaryOps.MOD and mod.src[1].op is UOps.CONST else None
|
||||
if not m or m != buf_dtype.shape[1]: return None
|
||||
# parse idx.src[0]
|
||||
d = add.src[1].arg if (add:=mod.src[0]).op is UOps.ALU and add.arg is BinaryOps.ADD and add.src[1].op is UOps.CONST else 0
|
||||
mul = add.src[0] if d else add # + d is optional
|
||||
c = mul.src[1].arg if mul.op is UOps.ALU and mul.arg is BinaryOps.MUL and mul.src[1].op is UOps.CONST else 1
|
||||
X = mul.src[0] if c != 1 else mul # * c is optional
|
||||
# parse idx.src[1]
|
||||
e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0
|
||||
div = add1.src[0] if e else add1
|
||||
m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None
|
||||
if m_ != m or div.src[0] != add: return None
|
||||
|
||||
lower, upper = X.vmin, X.vmax
|
||||
# from valid, find the bound of X
|
||||
drop_stmt = []
|
||||
if X in bounds and (b0:=bounds[X][0]) is not None:
|
||||
lower = b0[0]
|
||||
drop_stmt.append(b0[1])
|
||||
else: lower = X.vmin
|
||||
if X in bounds and (b1:=bounds[X][1]) is not None:
|
||||
upper = b1[0]
|
||||
drop_stmt.append(b1[1])
|
||||
else: upper = X.vmax
|
||||
|
||||
for stmt in _get_chain(valid, BinaryOps.AND):
|
||||
if stmt.op is UOps.ALU and stmt.arg is BinaryOps.CMPLT and stmt.src[1].op is UOps.CONST:
|
||||
if stmt.src[0].key == X.key: # X < c
|
||||
upper = stmt.src[1].arg-1
|
||||
drop_stmt.append(stmt)
|
||||
elif stmt.src[0].key == (-X).key: # -X < -c -> X > c
|
||||
lower = -stmt.src[1].arg+1
|
||||
drop_stmt.append(stmt)
|
||||
|
||||
# valid is an empty set
|
||||
if upper < lower: return UOp(UOps.LOAD, load.dtype, (buf, idx, invalid_val, valid.const_like(False)))
|
||||
|
||||
# If the contraints in valid implies that it "spans" the whole row, and we can rewrite it to X*c+k for some k, and drop the valid.
|
||||
new_indx0, new_indx1 = None, None
|
||||
if (L:=(lower * c + d)) // m == (U:=(upper * c + d)) // m: # in the same row
|
||||
if (L % m - c < 0) and (U % m + c >= m): # spans the whole row
|
||||
new_indx0 = graph_rewrite(mul - ((L // m) * m - d), constant_folder)
|
||||
|
||||
# Because (X * c + d) % m spans the whole row, (X * c + d) // m has a fixed value.
|
||||
# check if idx1 is a div that can be simplified. idx1 = (add // m + e)
|
||||
e = add1.src[1].arg if (add1:=idx.src[1]).op is UOps.ALU and add1.arg is BinaryOps.ADD and add1.src[1].op is UOps.CONST else 0
|
||||
div = add1.src[0] if e else add1
|
||||
m_ = div.src[1].arg if div.op is UOps.ALU and div.arg is BinaryOps.IDIV and div.src[1].op is UOps.CONST else None
|
||||
if m_ == m and div.src[0] == add: new_indx1 = idx.src[1].const_like(L // m + e)
|
||||
new_indx1 = idx.src[1].const_like(L // m + e)
|
||||
|
||||
if new_indx0 and new_indx1:
|
||||
new_idx = UOp(UOps.VECTORIZE, dtypes.int.vec(2), (new_indx0, new_indx1))
|
||||
|
||||
Reference in New Issue
Block a user