Mirror of https://github.com/tinygrad/tinygrad.git, synced 2026-04-29 03:00:14 -04:00

Merge branch 'master' into retinanet_mlperf
@@ -4,12 +4,11 @@ from extra.mcts_search import mcts_search
 from examples.mlperf.helpers import get_mlperf_bert_model
 from tinygrad import Tensor, Device, dtypes, nn
 from tinygrad.codegen.kernel import Kernel
-from tinygrad.ops import UOps
+from tinygrad.ops import UOps, sym_infer
 from tinygrad.device import Compiled
 from tinygrad.engine.schedule import create_schedule
 from tinygrad.engine.search import time_linearizer, beam_search, bufs_from_lin
 from tinygrad.helpers import DEBUG, ansilen, getenv, colored, TRACEMETA
-from tinygrad.shape.symbolic import sym_infer

 def get_sched_resnet():
   mdl = ResNet50()
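Across this merge, the symbolic helpers (`sym_infer`, `Variable`, `NumNode`, `sint`) move from `tinygrad.shape.symbolic` into `tinygrad.ops`, with `Variable` also re-exported at the package root. A minimal sketch of the new import surface, assuming the post-merge tree:

```python
from tinygrad import Variable          # re-exported at the package root
from tinygrad.ops import sym_infer    # moved here from tinygrad.shape.symbolic

v = Variable("batch", 1, 16)           # bounded symbolic variable
expr = v * 2 + 1                       # symbolic arithmetic builds a UOp expression
print(sym_infer(expr, {v: 4}))         # -> 9: evaluate the expression at batch=4
```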
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -3,7 +3,7 @@
 export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_green"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=4 BEAM_UOPS_MAX=2000 BEAM_UPCAST_MAX=64 BEAM_LOCAL_MAX=512
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
@@ -2,7 +2,7 @@

 export PYTHONPATH="."
 export MODEL="bert"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
@@ -3,7 +3,7 @@
 export PYTHONPATH="."
 export MODEL="bert"
 export SUBMISSION_PLATFORM="tinybox_red"
-export DEFAULT_FLOAT="HALF" GPUS=6 BS=66 EVAL_BS=6
+export DEFAULT_FLOAT="HALF" GPUS=6 BS=72 EVAL_BS=6

 export BEAM=3
 export IGNORE_JIT_FIRST_BEAM=1
@@ -3,7 +3,7 @@ from tinygrad.codegen.kernel import UOps, MemOp, UOp
 from tinygrad.ops import BinaryOps, UnaryOps
 from tinygrad.dtype import DType, dtypes
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
+from tinygrad.ops import Variable, NumNode, MulNode, DivNode, ModNode, LtNode, SumNode, AndNode
 import functools
 import math
 from collections import defaultdict
@@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple
 from tinygrad.helpers import init_c_var, round_up
 from tinygrad.device import Buffer, BufferOptions
 from tinygrad.device import Compiled, Device
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.runtime.ops_hsa import HSADevice, PROFILE, Profiler
 from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
 from tinygrad.engine.jit import MultiGraphRunner, GraphException
@@ -6,7 +6,7 @@ from dataclasses import dataclass
 from tinygrad.helpers import dedup, prod
 from tinygrad.ops import ReduceOps, UnaryOps, BinaryOps, TernaryOps, UOp, UOps, pretty_print
 from tinygrad.dtype import ImageDType, PtrDType, dtypes, DType, ConstType
-from tinygrad.shape.symbolic import Variable, sint
+from tinygrad.ops import Variable, sint
 from tinygrad.shape.shapetracker import ShapeTracker

 # these ops are deleted after AST is UOp
@@ -15,7 +15,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
@@ -8,7 +8,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
@@ -1,12 +1,13 @@
 # stuff needed to unpack a kernel
 from typing import Tuple
 from extra.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, BufferOps, MemBuffer, ConstBuffer, MetaOps
+from tinygrad import Variable
 from tinygrad.codegen.kernel import Opt, OptOps
 from tinygrad.ops import UOp, UOps, KernelInfo
 from tinygrad.dtype import dtypes, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.ops import NumNode
 inf, nan = float('inf'), float('nan')

 # kernel unpacker
@@ -12,7 +12,7 @@ from tinygrad.ops import LazyOp, TernaryOps, BinaryOps, UnaryOps, ReduceOps, Buf
 from tinygrad.dtype import dtypes
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 inf, nan = float('inf'), float('nan')
 from tinygrad.codegen.kernel import Opt, OptOps
@@ -2,8 +2,7 @@ import unittest
 import numpy as np

 from tinygrad.helpers import BEAM, Timing, CI
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor
 from tinygrad.nn import Conv2d

 def rand(*shape):
@@ -6,7 +6,7 @@ from extra.optimization.helpers import load_worlds, ast_str_to_ast
 from tinygrad.helpers import prod, tqdm
 from tinygrad.ops import UOp, UOps
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import sym_infer, Node
+from tinygrad.ops import sym_infer, Node

 class MovementOps(Enum): RESHAPE = auto(); PERMUTE = auto(); EXPAND = auto(); PAD = auto(); SHRINK = auto(); STRIDE = auto(); AS_STRIDED = auto() # noqa: E702
test/external/external_benchmark_schedule.py (vendored, 3 lines changed):
@@ -1,6 +1,6 @@
 from typing import List
 from extra.models.resnet import ResNet50
-from tinygrad import Tensor, Device
+from tinygrad import Tensor, Device, nn
 from tinygrad.helpers import Profiling, Timing, getenv, BEAM, NOOPT, DEBUG, Context, ansilen
 from tinygrad.ops import UOps
 from tinygrad.codegen.kernel import Kernel
@@ -11,6 +11,7 @@ from tinygrad.engine.search import beam_search, bufs_from_lin

 if __name__ == "__main__":
   mdl = ResNet50()
+  for p in nn.state.get_parameters(mdl): p.replace(Tensor.empty(p.shape))
   img = Tensor.empty(64, 3, 224, 224)

   PROFILE = getenv("PROFILE", 0)
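The added loop swaps every ResNet parameter for an unrealized empty tensor, so the benchmark measures scheduling rather than weight initialization. The idiom in isolation (a sketch; `Tiny` is a made-up module):

```python
from tinygrad import Tensor, nn

class Tiny:
  def __init__(self): self.w = Tensor.randn(4, 4)

mdl = Tiny()
# replace real (random) weights with uninitialized buffers of the same shape
for p in nn.state.get_parameters(mdl): p.replace(Tensor.empty(p.shape))
print(mdl.w.shape)  # (4, 4), now backed by an unrealized empty buffer
```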
test/external/fuzz_symbolic.py (vendored, 3 lines changed):
@@ -1,7 +1,8 @@
 import itertools
 import random
+from tinygrad import Variable
 from tinygrad.helpers import DEBUG
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.ops import NumNode
 random.seed(42)

 def add_v(expr, rng=None):
test/external/fuzz_uops.py (vendored, 2 lines changed):
@@ -7,7 +7,7 @@ from tinygrad.ops import END_FOR_UOP, UOp, print_uops
 from tinygrad.device import Buffer, Device
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.helpers import DEBUG, colored
-from tinygrad.shape.symbolic import Variable
+from tinygrad.ops import Variable
 from tinygrad.tensor import _to_np_dtype
 from test.external.fuzz_schedule import FUZZ_SCHEDULE_MAX_PATHS, find_all_toposorts
@@ -10,7 +10,7 @@ from tinygrad.engine.realize import Runner
 from tinygrad.dtype import ConstType, DType
 from tinygrad.nn.state import get_parameters
 from tinygrad.helpers import Context, CI, OSX, getenv
-from tinygrad.shape.symbolic import sint
+from tinygrad.ops import sint

 def derandomize_model(model):
   with Context(GRAPH=0):
@@ -32,7 +32,6 @@ def assert_jit_cache_len(fxn, expected_len):
     assert len(fxn.jit_cache[0].prg.jit_cache) == expected_len

 def is_dtype_supported(dtype: DType, device: str = Device.DEFAULT):
-  if dtype == dtypes.pyint and device != "PYTHON": return False
   if dtype == dtypes.bfloat16:
     # NOTE: this requires bf16 buffer support
    return device in {"AMD"} or (device in {"CUDA", "NV"} and not CI and not getenv("PTX"))
@@ -3,9 +3,8 @@ import numpy as np
 from tinygrad.nn import optim
 from tinygrad.nn.state import get_parameters
 from tinygrad.engine.jit import TinyJit
-from tinygrad import Tensor, Device, GlobalCounters, dtypes
+from tinygrad import Tensor, Device, GlobalCounters, dtypes, Variable
 from tinygrad.helpers import CI, Context
-from tinygrad.shape.symbolic import Variable
 from extra.lr_scheduler import OneCycleLR
 from test.helpers import derandomize_model, is_dtype_supported
@@ -12,7 +12,7 @@ from test.helpers import is_dtype_supported, rand_for_dtype
 settings.register_profile("my_profile", max_examples=200, deadline=None, derandomize=getenv("DERANDOMIZE_CI", False))
 settings.load_profile("my_profile")

-core_dtypes = list([v for k,v in DTYPES_DICT.items() if k != 'pyint'])
+core_dtypes = list(DTYPES_DICT.values())
 if Device.DEFAULT == "CPU": core_dtypes.remove(dtypes.bfloat16) # NOTE: this is for teenygrad, don't remove
 dtype_ints = [dt for dt in core_dtypes if dtypes.is_int(dt) and is_dtype_supported(dt)]
 dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_supported(dt)]
@@ -20,7 +20,7 @@ dtype_floats = [dt for dt in core_dtypes if dtypes.is_float(dt) and is_dtype_sup
 def get_available_cast_dtypes(dtype: DType) -> List[DType]:
   if not is_dtype_supported(dtype): return []
   # dont cast internal dtypes
-  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_") and k != 'pyint']
+  return [v for k, v in DTYPES_DICT.items() if v != dtype and is_dtype_supported(v) and not k.startswith("_")]

 def _test_to_np(a:Tensor, np_dtype, target):
   if DEBUG >= 2: print(a)
@@ -10,7 +10,7 @@ from tinygrad.ops import UOp, UOps, BinaryOps, TernaryOps, UnaryOps
 from tinygrad.device import Device, Buffer
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-# from tinygrad.shape.symbolic import Variable
+# from tinygrad.ops import Variable
 from tinygrad.tensor import Tensor, _to_np_dtype
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule
 from tinygrad.engine.realize import run_schedule, lower_schedule, CompiledRunner
@@ -147,7 +147,7 @@ class TestProfiler(unittest.TestCase):

     transfer_node_1 = helper_profile_filter_node(profile, name=f"{Device.DEFAULT} -> {Device.DEFAULT}:1")[0]
     helper_validate_node(transfer_node_1, profile=profile, pid_name=Device.DEFAULT, tid_name="DMA")
-    assert 80 < transfer_node_1['dur'] < (16000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"
+    assert 80 < transfer_node_1['dur'] < (20000 if CI else 1400), f"Duration is not in the range: {transfer_node_1['dur']}"

   @unittest.skipIf(MOCKGPU and Device.DEFAULT == "AMD", "AMD mockgpu with indirect buffers does not support queue wait interrupts")
   def test_profile_deps(self):
@@ -1,7 +1,6 @@
 import unittest
 import numpy as np
-from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Tensor, Variable

 class TestSample(unittest.TestCase):
   def test_sample(self):
@@ -7,19 +7,17 @@ import numpy as np
 import functools
 from typing import List, Optional, Union, cast

-from tinygrad import nn, dtypes
-from tinygrad.device import Device
+from tinygrad import nn, dtypes, Device, Tensor
 from tinygrad.dtype import DType, PtrDType
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.shape.view import View
-from tinygrad.tensor import Tensor
 from tinygrad.ops import BinaryOps, MetaOps, UOp, UnaryOps, UOps, graph_rewrite
 from tinygrad.helpers import CI, DEBUG, FUSE_ARANGE, GlobalCounters, flatten, getenv, SPLIT_REDUCEOP, unwrap, prod
 from tinygrad.codegen.kernel import Kernel, verify_ast
 from tinygrad.engine.schedule import BUF_LIMIT, create_schedule, reduceop_fusor, st_fixup
 from tinygrad.engine.realize import CompiledRunner, run_schedule
-from test.helpers import ast_const, is_dtype_supported, Context, timeit
 from tinygrad.engine.lazy import LazyBuffer, view_supported_devices
+from test.helpers import ast_const, is_dtype_supported, Context, timeit
 from extra.models.llama import precompute_freqs_cis

 class KernelCountException(Exception): pass
@@ -876,11 +874,16 @@ class TestSchedule(unittest.TestCase):
     Tensor.manual_seed(0)
     x = Tensor.randn(4, 12, 64, 64).realize()
     out = x.softmax()
-    # run_schedule(check_schedule(out, 2))
+    run_schedule(check_schedule(out, 3))
     expected = (x_exp:=np.exp(x.numpy()-x.numpy().max(-1, keepdims=True)))/x_exp.sum(-1, keepdims=True)
     np.testing.assert_allclose(out.numpy(), expected, atol=1e-4, rtol=1e-4)

+  def test_softmax_backward(self):
+    Tensor.manual_seed(0)
+    x = Tensor.randn(4, 12, 64, 64, requires_grad=True).realize()
+    x.softmax().sum().backward()
+    run_schedule(check_schedule(x.grad, 4))
+
   # changed by: multireduce spec
   def test_layernorm_onelayer_fusion(self):
     Tensor.manual_seed(0)
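The reference expectation in this test is the numerically stable softmax (subtract the row max before exponentiating). The same check as a standalone NumPy sketch:

```python
import numpy as np

def ref_softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
  # subtract the max along the axis for numerical stability, then normalize
  e = np.exp(x - x.max(axis=axis, keepdims=True))
  return e / e.sum(axis=axis, keepdims=True)

x = np.random.randn(4, 12, 64, 64).astype(np.float32)
np.testing.assert_allclose(ref_softmax(x).sum(-1), 1.0, rtol=1e-5)  # rows sum to one
```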
@@ -1,9 +1,7 @@
 import unittest

 from test.helpers import assert_jit_cache_len
-from tinygrad.engine.jit import TinyJit
-from tinygrad.shape.symbolic import Variable
-from tinygrad.tensor import Tensor
+from tinygrad import Variable, Tensor, TinyJit
 import numpy as np

 class TestSymbolicJit(unittest.TestCase):
@@ -1,5 +1,5 @@
 import unittest
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from tinygrad.helpers import getenv
 from tinygrad.tensor import Tensor
 from examples.gpt2 import Attention
@@ -1,6 +1,7 @@
 import unittest
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.ops import NumNode
 from tinygrad.tensor import Tensor

 class TestSymbolic(unittest.TestCase):
@@ -172,10 +172,10 @@ class TestGraphRewrite(unittest.TestCase):
     self.assertEqual(nout.src[1].arg, 3.0)

   def test_consts_go_last(self):
-    a = UOp.define_var('a', dtypes.int, 0, 1)
-    b = UOp.define_var('b', dtypes.int, 0, 1)
-    c = UOp.define_var('c', dtypes.int, 0, 1)
-    d = UOp.define_var('d', dtypes.int, 0, 1)
+    a = UOp.variable('a', 0, 1)
+    b = UOp.variable('b', 0, 1)
+    c = UOp.variable('c', 0, 1)
+    d = UOp.variable('d', 0, 1)
     outs = [2+a, 2+a+d+3+b+c+4, UOp(UOps.ALU, a.dtype, src=(a.const_like(2), a), arg=BinaryOps.ADD), (4+d)+c+(2+a)+b]
     for out in outs:
       sink = graph_rewrite(out, sym)
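This diff renames `UOp.define_var(name, dtype, min, max)` to `UOp.variable(name, min, max[, dtype])`: the dtype argument moves to the end and defaults to an integer dtype, so the common case gets shorter. A sketch of both call shapes, assuming the post-merge helper:

```python
from tinygrad.ops import UOp
from tinygrad.dtype import dtypes

a = UOp.variable('a', 0, 1)                      # int variable bounded to [0, 1]
g = UOp.variable('g', False, True, dtypes.bool)  # non-default dtypes now go last
print(a.vmin, a.vmax)                            # bounds live on the UOp itself
```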
@@ -196,7 +196,7 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(out.arg, 3.0)

   def test_where_same_fold(self):
-    v = UOp.define_var('tmp', dtypes.int, 0, 1)
+    v = UOp.variable('tmp', 0, 1)
     c0 = UOp(UOps.CONST, dtypes.int, arg=0)
     vc = UOp(UOps.ALU, dtypes.bool, (v, c0), BinaryOps.CMPNE)
     c1 = UOp(UOps.CONST, dtypes.float, arg=1.0)
@@ -290,7 +290,7 @@ class TestUOpGraph(unittest.TestCase):
     for i in [2, 4, 8]:
       vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
       var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
-      acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
+      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
       wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (vec, var, acc))
       uops = to_uops_list([wmma])
       assert_equiv_uops(uops[0], acc)
@@ -299,7 +299,7 @@ class TestUOpGraph(unittest.TestCase):
     for i in [2, 4, 8]:
       var = UOp(UOps.DEFINE_VAR, dtypes.half.vec(i))
       vec = UOp(UOps.VECTORIZE, dtypes.half.vec(i), tuple(UOp.const(dtypes.half, 0.0) for _ in range(i)))
-      acc = UOp.define_var('acc', dtypes.half.vec(i), 0, 1)
+      acc = UOp.variable('acc', 0, 1, dtypes.half.vec(i))
       wmma = UOp(UOps.WMMA, dtypes.half.vec(i), (var, vec, acc))
       uops = to_uops_list([wmma])
       assert_equiv_uops(uops[0], acc)
@@ -366,7 +366,7 @@ class TestUOpGraph(unittest.TestCase):
     self.assertEqual(len([x for x in uops if x.op is UOps.CAST]), 1)

   def test_depth_2_const_fold(self):
-    v = UOp.define_var("tmp", dtypes.int, 0, 1)
+    v = UOp.variable("tmp", 0, 1)
     c2 = UOp(UOps.CONST, dtypes.int, arg=2)
     c4 = UOp(UOps.CONST, dtypes.int, arg=4)
     vc = UOp(UOps.ALU, dtypes.int, (v, c2), BinaryOps.ADD)
@@ -620,8 +620,8 @@ class TestLoadStoreFolder(unittest.TestCase):

   def test_simple_load_dont_fold_different_gated(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
-    gate2 = UOp.define_var("g2", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
+    gate2 = UOp.variable("g2", False, True, dtypes.bool)
     load = [UOp(UOps.LOAD, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
     sink = UOp(UOps.VECTORIZE, dtypes.float.vec(len(load)), tuple(load))
     sink = float4_rewrite(sink)
@@ -636,7 +636,7 @@ class TestLoadStoreFolder(unittest.TestCase):

   def test_simple_store_fold_gate(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
     load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate)) for i in range(4)]
     sink = UOp(UOps.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)
@@ -647,8 +647,8 @@ class TestLoadStoreFolder(unittest.TestCase):

   def test_simple_store_dont_fold(self):
     buf = UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float))
-    gate = UOp.define_var("g1", dtypes.bool, False, True)
-    gate2 = UOp.define_var("g2", dtypes.bool, False, True)
+    gate = UOp.variable("g1", False, True, dtypes.bool)
+    gate2 = UOp.variable("g2", False, True, dtypes.bool)
     load = [UOp(UOps.STORE, dtypes.float, (buf, UOp.const(dtypes.int, i), UOp.const(dtypes.float, i), gate if i == 0 else gate2)) for i in range(4)]
     sink = UOp(UOps.SINK, dtypes.void, tuple(load))
     sink = float4_rewrite(sink)
@@ -12,7 +12,6 @@ from tinygrad.engine.schedule import create_schedule, reduceop_fusor
 from tinygrad.engine.realize import CompiledRunner, lower_schedule_item, get_kernel
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite, sym
-from tinygrad.shape.symbolic import Variable
 from test.helpers import is_dtype_supported, assert_equiv_uops

 def to_uops_list(u:List[UOp], opts=None, skip_check=False) -> List[UOp]: return linearize_uop(full_graph_rewrite(UOp.sink(*u), opts), skip_check)
@@ -360,7 +359,7 @@ class TestUOpMethod(unittest.TestCase):
     assert (add < mul) or (mul < add), "add and mul with same src should have an order"

   def test_uop_variables(self):
-    a = Variable("a", 1, 10)
+    a = UOp.variable("a", 1, 10)
     uop_var = UOp.const(dtypes.int, a)
     st_var = UOp(UOps.LOAD, dtypes.float, (UOp(UOps.DEFINE_GLOBAL, PtrDType(dtypes.float), (), 0),
                                            ShapeTracker.from_shape((2, a)).to_uop()))
@@ -94,20 +94,20 @@ class TestFoldingAndReduction(unittest.TestCase):

 class TestModuloAndDivisionFolding(unittest.TestCase):
   def test_full_graph_rewrite_modulo_folding_with_define_var(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 0, 100)
+    x_var_uop = UOp.variable('x', 0, 100)
     optimized_mod_uop = apply_rewrite(((x_var_uop * 4) + 2) % 4)
     self.assertEqual(optimized_mod_uop.op, UOps.CONST)
     self.assertEqual(optimized_mod_uop.arg, 2)

   def test_full_graph_rewrite_division_folding_with_define_var(self):
-    n_var_uop = UOp.define_var('n', dtypes.int32, 1, 1000)
+    n_var_uop = UOp.variable('n', 1, 1000)
     optimized_div_uop = apply_rewrite((n_var_uop * 6) // 3)
     self.assertEqual(optimized_div_uop.op, UOps.ALU)
     self.assertEqual(optimized_div_uop.arg, BinaryOps.MUL)
     self.assertEqual(optimized_div_uop.src[1].arg, 2)

   def test_full_graph_rewrite_complex_mod_div_folding(self):
-    k_var_uop = UOp.define_var('k', dtypes.int32, 0, 50)
+    k_var_uop = UOp.variable('k', 0, 50)
     optimized_div_uop = apply_rewrite(((k_var_uop * 12 + 8) % 6) // 2)
     self.assertEqual(optimized_div_uop.op, UOps.CONST)
     self.assertEqual(optimized_div_uop.arg, 1)
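The algebra these tests pin down holds for every non-negative integer, which is why the rewriter may fold each expression to a constant or a single multiply; a direct check in plain Python:

```python
# (4x + 2) % 4 == 2, (6x) // 3 == 2x, and ((12x + 8) % 6) // 2 == 1 for x >= 0
for v in range(0, 101):
  assert ((v * 4) + 2) % 4 == 2
  assert (v * 6) // 3 == 2 * v
  assert ((v * 12 + 8) % 6) // 2 == 1
```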
@@ -124,17 +124,17 @@ class TestModuloAndDivisionFolding(unittest.TestCase):
       if opt.op is UOps.VECTORIZE: self.assertFalse(all_same(opt.src))

   def test_full_graph_rewrite_modulo_large_divisor(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
+    x_var_uop = UOp.variable('x', 1, 5)
     self.assertIs(apply_rewrite(x_var_uop % 10), x_var_uop)

   def test_full_graph_rewrite_division_with_remainder(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 7, 9)
+    x_var_uop = UOp.variable('x', 7, 9)
     optimized_sink = apply_rewrite(x_var_uop // 2)
     for x_value in range(7, 10):
       self.assertEqual(x_value // 2, evaluate_uop(optimized_sink, {'x': x_value}))

   def test_full_graph_rewrite_complex_mod_div_expression(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 10)
+    x_var_uop = UOp.variable('x', 1, 10)
     optimized_sink = apply_rewrite(((x_var_uop * 5) % 3) // 2)
     for x_value in range(1, 11):
       original_result = ((x_value * 5) % 3) // 2
@@ -152,14 +152,14 @@ class TestEdgeCasesAndSpecialOperations(unittest.TestCase):

   @unittest.skip("broken")
   def test_full_graph_rewrite_modulo_negative_dividend(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, -5, -1)
+    x_var_uop = UOp.variable('x', -5, -1)
     optimized_sink = full_graph_rewrite((x_var_uop % 3).sink())
     for x_value in range(-5, 0):
       self.assertEqual(x_value % 3, evaluate_uop(optimized_sink.src[0], {'x': x_value}))

   @unittest.skip("broken")
   def test_full_graph_rewrite_division_negative_divisor(self):
-    x_var_uop = UOp.define_var('x', dtypes.int32, 1, 5)
+    x_var_uop = UOp.variable('x', 1, 5)
     optimized_sink = full_graph_rewrite((x_var_uop // -2).sink())
     for x_value in range(1, 6):
       self.assertEqual(x_value // -2, evaluate_uop(optimized_sink.src[0], {'x': x_value}))
@@ -1,10 +1,11 @@
 import gzip, unittest
 from PIL import Image
+from tinygrad import Variable
 from tinygrad.helpers import Context, ContextVar
 from tinygrad.helpers import merge_dicts, strip_parens, prod, round_up, fetch, fully_flatten, from_mv, to_mv
 from tinygrad.tensor import get_shape
 from tinygrad.codegen.lowerer import get_contraction
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad.ops import NumNode
 import numpy as np

 VARIABLE = ContextVar("VARIABLE", 0)
@@ -4,13 +4,14 @@ import numpy as np
 from tinygrad.dtype import dtypes
 from tinygrad.helpers import prod
 from tinygrad.shape.shapetracker import ShapeTracker, View
-from tinygrad.shape.symbolic import Variable, NumNode
+from tinygrad import Variable
+from tinygrad.ops import NumNode
 from tinygrad.ops import UOp, UOps, graph_rewrite
 from tinygrad.codegen.uopgraph import sym
 from itertools import product

 def shapetracker_getitem(st:ShapeTracker, val:int):
-  idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.pyint, val)])
+  idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, val)])
   idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
   assert idx.op is UOps.CONST and valid.op is UOps.CONST
   return idx.arg, valid.arg
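The helper above shows the pattern for evaluating a ShapeTracker at a concrete flat index: build the index/valid UOps, then simplify them to constants with the `sym` rewriter. A sketch of that in isolation, using the same calls this diff exercises:

```python
from tinygrad.dtype import dtypes
from tinygrad.shape.shapetracker import ShapeTracker
from tinygrad.ops import UOp, UOps, graph_rewrite
from tinygrad.codegen.uopgraph import sym

st = ShapeTracker.from_shape((3, 4)).permute((1, 0))   # a strided 4x3 view
idx, valid = st.reshape((st.size,)).to_indexed_uops([UOp.const(dtypes.int, 5)])
idx, valid = graph_rewrite(idx, sym), graph_rewrite(valid, sym)
print(idx.op is UOps.CONST, idx.arg)                   # the buffer offset for flat index 5
```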
@@ -3,7 +3,7 @@ from typing import List
 from tinygrad.helpers import prod
 from tinygrad.shape.view import View
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 from test.unit.test_shapetracker import shapetracker_getitem

 class MultiShapeTracker:
@@ -32,7 +32,7 @@ def render(uop:UOp) -> str:
   return fxn.split("val0 = ")[1].split(";")[0]

 def Special(expr, nmax): return UOp(UOps.SPECIAL, dtypes.int, (), (expr, nmax))
-def Variable(expr, nmin, nmax): return UOp.define_var(expr, dtypes.int, nmin, nmax)
+def Variable(expr, nmin, nmax): return UOp.variable(expr, nmin, nmax)
 def Range(n, nmax):
   return UOp(UOps.RANGE, dtypes.int, arg=(n, True), src=(UOp.const(dtypes.int, 0), UOp.const(dtypes.int, nmax),))
@@ -44,7 +44,7 @@ class TestUOpResolve(unittest.TestCase):
     self.assertEqual((8 * UOp.const(dtypes.int, 4)).ssimplify(), 32)

   def test_ambiguous_less_than(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10)
+    u = UOp.variable("i", 1, 10)
     self.assertTrue(resolve(u < 4))
     self.assertFalse(resolve(u < 4, False))
     self.assertTrue(resolve(u < 11, False))
@@ -56,64 +56,64 @@ class TestUOpResolve(unittest.TestCase):
     self.assertEqual(float(u), 11.5)

   def test_var_cmp_t(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) < 20
+    u = UOp.variable("i", 1, 10) < 20
     self.assertTrue(u)

   def test_var_cmp_t2(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10)//2 < 20
+    u = UOp.variable("i", 1, 10)//2 < 20
     self.assertTrue(u)

   def test_var_cmp_f(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) < 1
+    u = UOp.variable("i", 1, 10) < 1
     self.assertFalse(u)

   def test_var_cmp_f2(self):
-    u = UOp.define_var("i", dtypes.pyint, 1, 10) > 11
+    u = UOp.variable("i", 1, 10) > 11
     self.assertFalse(u)

   def test_or_true(self):
-    u = UOp.define_var("b", dtypes.bool, False, True) | True
+    u = UOp.variable("b", False, True, dtypes.bool) | True
     self.assertTrue(u)

   def test_or_false(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("b", dtypes.bool, False, True) | False
+      u = UOp.variable("b", False, True, dtypes.bool) | False
       self.assertTrue(u)

   def test_and_false(self):
-    u = UOp.define_var("b", dtypes.bool, False, True) & False
+    u = UOp.variable("b", False, True, dtypes.bool) & False
     self.assertFalse(u)

   def test_max(self):
-    x = UOp.define_var("x", dtypes.pyint, 1, 10)
-    y = UOp.define_var("y", dtypes.pyint, 5, 10)
+    x = UOp.variable("x", 1, 10)
+    y = UOp.variable("y", 5, 10)
     u = x.max(y)
     self.assertTrue(u < 20)
     self.assertFalse(u < 3)

   def test_x_lt_x(self):
-    x = UOp.define_var("i", dtypes.pyint, 1, 10)
+    x = UOp.variable("i", 1, 10)
     self.assertFalse(x < x)

   @unittest.expectedFailure
   def test_x_lt_xp1(self):
-    x = UOp.define_var("i", dtypes.pyint, 1, 10)
+    x = UOp.variable("i", 1, 10)
     self.assertTrue(x < (x+1))

   def test_and_true(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("b", dtypes.bool, False, True) & True
+      u = UOp.variable("b", False, True, dtypes.bool) & True
       self.assertFalse(u)

   @unittest.expectedFailure
   def test_var_cmp_range(self):
-    v = UOp.define_var("i", dtypes.pyint, 1, 10)
+    v = UOp.variable("i", 1, 10)
     u = (v > 4) | (v < 6)
     self.assertTrue(u)

   def test_var_cmp_assert(self):
     with self.assertRaises(ValueError):
-      u = UOp.define_var("i", dtypes.pyint, 1, 10) < 5
+      u = UOp.variable("i", 1, 10) < 5
       self.assertFalse(u)

 if __name__ == '__main__':
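`resolve` (and `bool()` on a comparison UOp) answers only when the variable's bounds make the comparison unambiguous; otherwise it raises or falls back to a default, as the tests above encode. A compact sketch of the same behavior:

```python
from tinygrad.ops import UOp, resolve

u = UOp.variable("i", 1, 10)
print(resolve(u < 11))          # True: holds for every value in [1, 10]
print(resolve(u < 1))           # False: impossible in [1, 10]
print(resolve(u < 4))           # ambiguous -> default answer (True)
print(resolve(u < 4, False))    # same comparison with explicit default False
```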
@@ -1,7 +1,7 @@
 #!/usr/bin/env python
 import unittest, pickle
 from typing import Tuple
-#from tinygrad.shape.symbolic import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node
+#from tinygrad.ops import MulNode, SumNode, Variable, NumNode, LtNode, ModNode, Node, sym_render, sym_infer, create_lt_node, create_ge_node

 # TODO: fix all the @unittest.expectedFailure
@@ -12,7 +12,7 @@ from tinygrad.dtype import dtypes, PtrDType, ConstType
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
 from tinygrad.ops import BinaryOps, UOp, UOps, print_uops
-from tinygrad.shape.symbolic import Variable
+from tinygrad import Variable
 import functools

 def render(self) -> Tuple[str, ConstType, ConstType]:
@@ -19,36 +19,36 @@ class TestVminVmaxProperties(unittest.TestCase):

   def test_vmin_vmax_addition_with_variable(self):
     # vmin and vmax for addition with a variable
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x + 5
     self.assertEqual(uop.vmin, 15)
     self.assertEqual(uop.vmax, 25)

   def test_vmin_vmax_multiplication_with_variable(self):
     # vmin and vmax for multiplication with a variable
-    x = UOp.define_var('x', dtypes.int32, -3, 4)
+    x = UOp.variable('x', -3, 4)
     uop = x * 2
     self.assertEqual(uop.vmin, -6)
     self.assertEqual(uop.vmax, 8)

   def test_vmin_vmax_with_negative_multiplication(self):
     # vmin and vmax when multiplying by a negative number
-    x = UOp.define_var('x', dtypes.int32, 2, 5)
+    x = UOp.variable('x', 2, 5)
     uop = x * -3
     self.assertEqual(uop.vmin, -15)
     self.assertEqual(uop.vmax, -6)

   def test_vmin_vmax_nested_min_max(self):
     # vmin and vmax with nested min/max operations
-    x = UOp.define_var('x', dtypes.int32, 0, 10)
+    x = UOp.variable('x', 0, 10)
     uop = x.max(5).min(8)
     self.assertEqual(uop.vmin, 5)
     self.assertEqual(uop.vmax, 8)

   def test_vmin_vmax_where(self):
-    x = UOp.define_var('x', dtypes.int, 0, 10)
-    y = UOp.define_var('y', dtypes.int, 1, 11)
-    z = UOp.define_var('z', dtypes.int, 2, 12)
+    x = UOp.variable('x', 0, 10)
+    y = UOp.variable('y', 1, 11)
+    z = UOp.variable('z', 2, 12)
     uop = x.lt(5).where(y, z)
     self.assertEqual(uop.vmin, 1)
     self.assertEqual(uop.vmax, 12)
@@ -56,21 +56,21 @@ class TestVminVmaxProperties(unittest.TestCase):
 class TestVminVmaxDivMod(unittest.TestCase):
   def test_vmin_vmax_division_positive(self):
     # vmin and vmax for division of a variable by a positive constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // 2
     self.assertEqual(uop.vmin, 5)
     self.assertEqual(uop.vmax, 10)

   def test_vmin_vmax_division_negative(self):
     # vmin and vmax for division of a variable by a negative constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // -2
     self.assertEqual(uop.vmin, -10)
     self.assertEqual(uop.vmax, -5)

   def test_vmin_vmax_mod_positive(self):
     # vmin and vmax for modulo of a variable by a positive constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x % 3
     self.assertEqual(uop.vmin, 0)
     self.assertEqual(uop.vmax, 2)
@@ -78,21 +78,21 @@ class TestVminVmaxDivMod(unittest.TestCase):
   @unittest.skip("broken")
   def test_vmin_vmax_mod_negative(self):
     # vmin and vmax for modulo of a variable by a negative constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x % -3
     self.assertEqual(uop.vmin, -2)
     self.assertEqual(uop.vmax, 0)

   def test_vmin_vmax_division_with_mixed_range(self):
     # vmin and vmax for division of a variable with a range crossing zero
-    x = UOp.define_var('x', dtypes.int32, -10, 10)
+    x = UOp.variable('x', -10, 10)
     uop = x // 3
     self.assertEqual(uop.vmin, -4)  # -10//3 = -4
     self.assertEqual(uop.vmax, 3)   # 10//3 = 3

   def test_vmin_vmax_mod_with_mixed_range(self):
     # vmin and vmax for modulo of a variable with a range crossing zero
-    x = UOp.define_var('x', dtypes.int32, -10, 10)
+    x = UOp.variable('x', -10, 10)
     uop = x % 4
     self.assertEqual(uop.vmin, 0)  # modulo always positive or zero when divisor is positive
     self.assertEqual(uop.vmax, 3)  # max possible mod is 3 when dividing by 4
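`vmin`/`vmax` are interval bounds carried on every UOp and propagated through arithmetic, which is what lets the rewriter prove comparisons like the ones above. A quick sketch:

```python
from tinygrad.ops import UOp

x = UOp.variable('x', -10, 10)
print((x // 3).vmin, (x // 3).vmax)  # -4 3: floor-division bounds
print((x % 4).vmin, (x % 4).vmax)    # 0 3: a positive divisor keeps the result non-negative
```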
@@ -146,26 +146,26 @@ class TestConstFactor(unittest.TestCase):

   def test_const_factor_with_variable(self):
     # const_factor for an expression involving a variable
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x * 3
     self.assertEqual(uop.const_factor(), 3)

   def test_const_factor_division(self):
     # const_factor for an expression with division
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x // 4
     self.assertEqual(uop.const_factor(), 1)  # Division reduces the const_factor to 1

   def test_const_factor_multiplication_of_var_and_const(self):
     # const_factor for multiplication of a variable and a constant
-    x = UOp.define_var('x', dtypes.int32, 6, 18)
+    x = UOp.variable('x', 6, 18)
     uop = x * 4
     self.assertEqual(uop.const_factor(), 4)  # Constant factor 4

   @unittest.skip("broken")
   def test_const_factor_multiplication_of_consts_and_vars(self):
     # Multiplying constants and variables
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = (x * 3) * 5
     self.assertEqual(uop.const_factor(), 15)  # Constant multipliers are combined (3 * 5 = 15)
@@ -186,7 +186,7 @@ class TestDivides(unittest.TestCase):
   @unittest.skip("broken")
   def test_divides_variable_and_constant(self):
     # Multiplying a variable by a constant, then dividing by the same constant
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = x * 6
     result = uop.divides(6)
     self.assertIsNotNone(result)
@@ -194,7 +194,7 @@ class TestDivides(unittest.TestCase):

   def test_divides_complex_expression(self):
     # Dividing a more complex expression
-    x = UOp.define_var('x', dtypes.int32, 10, 20)
+    x = UOp.variable('x', 10, 20)
     uop = (x * 6) + 18
     result = uop.divides(6)
     self.assertIsNotNone(result)
@@ -202,7 +202,7 @@ class TestDivides(unittest.TestCase):

   def test_divides_with_inexact_factors(self):
     # Multiplying by a constant but dividing by a non-exact divisor
-    x = UOp.define_var('x', dtypes.int32, 15, 45)
+    x = UOp.variable('x', 15, 45)
     uop = x * 4
     result = uop.divides(3)
     self.assertIsNone(result)  # Cannot divide by 3, since 4 is not divisible by 3
@@ -4,7 +4,8 @@ if int(os.getenv("TYPED", "0")):
   install_import_hook(__name__)
 from tinygrad.tensor import Tensor # noqa: F401
 from tinygrad.engine.jit import TinyJit # noqa: F401
-from tinygrad.shape.symbolic import Variable # noqa: F401
+from tinygrad.ops import UOp
+Variable = UOp.variable
 from tinygrad.dtype import dtypes # noqa: F401
 from tinygrad.helpers import GlobalCounters, fetch, Context, getenv # noqa: F401
 from tinygrad.device import Device # noqa: F401
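After this hunk the package-level `Variable` is literally `UOp.variable`, so the old `Variable(name, min, max)` spelling keeps working but now produces a UOp rather than a node from the deleted symbolic module. A sketch:

```python
from tinygrad import Variable

v = Variable("n", 1, 100)   # same call shape as before the refactor
print(type(v).__name__)     # 'UOp': Variable is no longer its own class
print(v.vmin, v.vmax)       # 1 100
```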
@@ -6,14 +6,13 @@ from typing import Optional, List, Tuple, cast, Dict, Final, DefaultDict
 from enum import Enum, auto

 from tinygrad.ops import UNSAFE_PAD_OPS, BUFFER_UOPS, BinaryOps, KernelInfo, UOp, UOps, PatternMatcher, print_uops, type_verify, resolve, \
-  graph_rewrite, track_rewrites
+  graph_rewrite, track_rewrites, Variable, sint
 from tinygrad.device import Device
 from tinygrad.renderer import Renderer, TensorCore, Program
 from tinygrad.dtype import ImageDType, PtrDType
 from tinygrad.helpers import all_same, colored, ansilen, dedup, getenv, prod, round_up, all_int, to_function_name, diskcache_put
 from tinygrad.helpers import DEBUG, TC_OPT, USE_TC, AMX
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.shape.view import strides_for_shape
 from tinygrad.codegen.linearize import linearize_uop
 from tinygrad.codegen.uopgraph import full_graph_rewrite
@@ -603,7 +602,8 @@ class Kernel:
     # kernel name (before late upcast)
     name = ("r" if self.reduceop is not None else ("C" if all(x.op in BUFFER_UOPS for x in self.ast.parents) else "E")) + \
          (f"{len(self.ast.src)}_" if len(self.ast.src) > 1 else "_") + \
-         colored('_', 'BLACK').join([colored(str(x), c) for x,c in zip(self.full_shape, self.colors())])
+         colored('_', 'BLACK').join([colored(str(x.render() if isinstance(x, UOp) else x), c) \
+           for x,c in zip(self.full_shape, self.colors())])

     # name the function something unique
     Kernel.kernel_cnt[(function_name := to_function_name(name))] += 1
@@ -3,10 +3,10 @@ from __future__ import annotations
 import functools, itertools, operator
 from dataclasses import dataclass
 from typing import List, Tuple, cast, Optional
-from tinygrad.shape.shapetracker import ShapeTracker, variable_to_uop
-from tinygrad.shape.symbolic import sint
+from tinygrad.shape.shapetracker import ShapeTracker
+from tinygrad.shape.view import variable_to_uop
 from tinygrad.dtype import dtypes
-from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve
+from tinygrad.ops import KernelInfo, BinaryOps, UOp, UOps, graph_rewrite, PatternMatcher, UPat, resolve, sint
 from tinygrad.renderer import Renderer
 from tinygrad.helpers import all_int, prod, partition, flatten
@@ -33,7 +33,7 @@ def _limit_dims(dims:Tuple[sint, ...], max_sizes:Tuple[int, ...]):
 def get_grouped_dims(prefix, dims:Tuple[sint, ...], max_sizes:Optional[Tuple[int, ...]], reverse=False) -> List[UOp]:
   if reverse: dims = dims[::-1]
   limited = _limit_dims(dims, max_sizes) if max_sizes is not None else dims
-  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.pyint, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
+  ret = raw_idxs = [UOp(UOps.SPECIAL, dtypes.int, (), (f"{prefix}{i}", s)) for i,s in enumerate(limited)]
   if limited != dims:
     ret = []
     # cast for mypy, get_contraction won't be None
@@ -75,22 +75,22 @@ def get_index(ast:UOp, opts:Renderer) -> IndexContext:
       get_grouped_dims("lidx", full_shape[global_dims:first_reduce+group_for_reduces], opts.local_max)
   else:
     # all loops are RANGES
-    idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, False))
+    idxs = [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, False))
       for i,g in enumerate(full_shape[:first_reduce])]

   # reduce loops
-  idxs += [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(g)), (i, True))
+  idxs += [UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(g)), (i, True))
     for i,g in enumerate(full_shape[first_reduce+group_for_reduces:first_upcasted], start=first_reduce+group_for_reduces)]

   # upcast loops
   for i,g in enumerate(full_shape[first_upcasted:], start=first_upcasted):
     assert isinstance(g, int), "needs to be int to upcast/unroll"
-    idxs.append(UOp(UOps.EXPAND, dtypes.pyint, (UOp.const(dtypes.pyint.vec(g), tuple(range(g))),), ((i,g),)))
+    idxs.append(UOp(UOps.EXPAND, dtypes.int, (UOp.const(dtypes.int.vec(g), tuple(range(g))),), ((i,g),)))

   # late indexes (group for reduce)
   ridxs = idxs[:]
   for a in range(first_reduce, first_reduce+group_for_reduces):
-    ridxs[a] = UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(full_shape[a])), (1000+a, True))
+    ridxs[a] = UOp(UOps.RANGE, dtypes.int, (UOp.const(dtypes.int, 0), variable_to_uop(full_shape[a])), (1000+a, True))

   return IndexContext(idxs, ridxs)
@@ -118,9 +118,9 @@ def idx_given_valid(valid:UOp, idx:UOp) -> Optional[UOp]:
     candidates = []
     if uop.op is UOps.ALU and uop.arg is BinaryOps.ADD and all(is_irreducible(u) and u.vmin == 0 for u in _get_chain(uop, BinaryOps.ADD)):
       # if the constraint is a simplex: X0 + X1 + ... > 0, we can check if all Xi > 0 simplify into the same output
-      candidates.append([(Xi, UOp.define_var("fake", Xi.dtype, 1, Xi.vmax)) for Xi in _get_chain(uop, BinaryOps.ADD)])
+      candidates.append([(Xi, UOp.variable("fake", 1, Xi.vmax, Xi.dtype)) for Xi in _get_chain(uop, BinaryOps.ADD)])
     # try checking the whole clause
-    candidates.append([(uop, UOp.define_var("fake", uop.dtype, uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1]))])
+    candidates.append([(uop, UOp.variable("fake", uop.vmin if v[0] is None else v[0], uop.vmax if v[1] is None else v[1], uop.dtype))])

     for candidate in candidates:
       newidxs:List[List[UOp]] = [[], []]
@@ -538,9 +538,6 @@ reducer = PatternMatcher([
   (UPat(UOps.LOAD, src=(UPat.var("buf"), UPat()), allow_any_len=True, name="load"), simplify_valid_image_load),
 ])

-no_pyint = PatternMatcher([(UPat((UOps.CONST, UOps.VCONST, UOps.ALU, UOps.SPECIAL, UOps.RANGE, UOps.EXPAND, UOps.VECTORIZE, UOps.DEFINE_VAR),
-  name="x"), lambda x: UOp(x.op, dtypes.int32.vec(x.dtype.count), x.src, x.arg) if x.dtype.scalar() == dtypes.pyint else None)])
-
 # *** uop graph ***

 linearize_cnt = 0
@@ -552,9 +549,6 @@ def full_graph_rewrite(sink:UOp, opts:Optional[Renderer]=None) -> UOp:
   acc_number = 0
   sink = graph_rewrite(sink, sym)

-  # rewrite pyint to int32
-  sink = graph_rewrite(sink, no_pyint)
-
   # expand
   linearize_cnt += 1
   if linearize_cnt != (de:=getenv("DEBUG_EXPAND", 0)) and de != -1:
@@ -84,9 +84,7 @@ class dtypes:
     return {dtypes.float16: (5, 10), dtypes.bfloat16: (8, 7), dtypes.float32: (8, 23), dtypes.float64: (11, 52)}[dtype]
   @staticmethod
   def fields() -> Dict[str, DType]: return DTYPES_DICT
-  # TODO: priority should be higher than bool
   void: Final[DType] = DType(-1, 0, "void", None, 1)
-  pyint: Final[DType] = DType(-1, 8, "pyint", None, 1) # arbitrary precision integer, same itemsize to int64 so min/max works
   bool: Final[DType] = DType(0, 1, "bool", '?', 1)
   int8: Final[DType] = DType(1, 1, "char", 'b', 1)
   uint8: Final[DType] = DType(2, 1, "unsigned char", 'B', 1)
@@ -118,7 +116,7 @@ class dtypes:

   floats = (float16, bfloat16, float32, float64)
   uints = (uint8, uint16, uint32, uint64)
-  sints = (int8, int16, int32, int64, pyint)
+  sints = (int8, int16, int32, int64)
   ints = uints + sints

 if (env_default_float := getenv("DEFAULT_FLOAT", "")):
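Taken together, these hunks delete the internal `pyint` dtype (an int64-sized "arbitrary precision" index type): index UOps are now built as `dtypes.int` directly, the `no_pyint` rewrite pass disappears, and the dtype table shrinks to fixed-width ints. A quick check of the resulting table (post-merge assumption):

```python
from tinygrad import dtypes

print(dtypes.sints)                                 # (int8, int16, int32, int64): no pyint entry
print(dtypes.ints == dtypes.uints + dtypes.sints)   # True, per the class definition above
```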
@@ -6,9 +6,8 @@ from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.helpers import flatten, merge_dicts, DEBUG, Context, GRAPH, BEAM, getenv, colored, JIT, dedup, partition
 from tinygrad.device import Buffer, Compiled, Device
 from tinygrad.dtype import DType
-from tinygrad.ops import UOp, ssimplify
+from tinygrad.ops import UOp, ssimplify, Variable, sint, sym_infer
 from tinygrad.shape.shapetracker import ShapeTracker
-from tinygrad.shape.symbolic import Variable, sint, sym_infer
 from tinygrad.engine.realize import ExecItem, capturing, EmptyOp, ViewOp, BufferXfer, CompiledRunner, Runner, _internal_memory_planner
 from tinygrad.nn.state import get_parameters
 from dataclasses import dataclass
@@ -3,8 +3,7 @@ from typing import Union, Optional, Any, Tuple, List, get_args
 from tinygrad.dtype import dtypes, DType, ConstType, to_dtype, ImageDType
 from tinygrad.helpers import prod, getenv, all_int, all_same, DEBUG, _METADATA, Metadata, SPLIT_REDUCEOP
 from tinygrad.ops import MetaOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, Op, exec_alu, python_alu, REDUCE_ALU
-from tinygrad.ops import identity_element, MathTrait, resolve, UOp
-from tinygrad.shape.symbolic import sint
+from tinygrad.ops import identity_element, MathTrait, resolve, UOp, sint
 from tinygrad.shape.shapetracker import ShapeTracker
 from tinygrad.device import Buffer
 from weakref import ref, ReferenceType, WeakValueDictionary
@@ -4,10 +4,9 @@ from collections import defaultdict
 from dataclasses import dataclass, replace
 from tinygrad.helpers import colored, getenv, DEBUG, GlobalCounters, ansilen, BEAM, NOOPT, all_int, CAPTURING, Metadata, Context, TRACEMETA, dedup
 from tinygrad.helpers import NO_MEMORY_PLANNER
-from tinygrad.ops import UOps, UOp
+from tinygrad.ops import UOps, UOp, Variable, sym_infer, sint
 from tinygrad.dtype import dtypes
 from tinygrad.device import Device, Buffer
-from tinygrad.shape.symbolic import Variable, sym_infer, sint
 from tinygrad.renderer import Renderer, Program
 from tinygrad.codegen.kernel import Kernel
 from tinygrad.engine.schedule import ScheduleItem
@@ -164,9 +163,10 @@ class ExecItem:
   prg: Runner
   bufs: List[Optional[Buffer]]
   metadata: Optional[Tuple[Metadata, ...]] = None
-  def run(self, var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+  def run(self, _var_vals:Optional[Dict[Variable, int]]=None, wait=False, jit=False, do_update_stats=True) -> Optional[float]:
+    var_vals = {} if _var_vals is None else _var_vals
     bufs = [cast(Buffer, x) for x in self.bufs] if jit else [cast(Buffer, x).ensure_allocated() for x in self.bufs]
-    et = self.prg(bufs, var_vals if var_vals is not None else {}, wait=wait or DEBUG >= 2)
+    et = self.prg(bufs, var_vals, wait=wait or DEBUG >= 2)
    if do_update_stats:
       GlobalCounters.kernel_count += 1
       GlobalCounters.global_ops += (op_est:=sym_infer(self.prg.op_estimate, var_vals))
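The point of the `_var_vals` rename is to normalize the Optional dict once at the top of `run`, so later uses such as `sym_infer(self.prg.op_estimate, var_vals)` never see `None` (previously only the program call guarded against it). The general pattern, in isolation:

```python
from typing import Optional, Dict

def run(var_vals: Optional[Dict[str, int]] = None) -> Dict[str, int]:
  var_vals = {} if var_vals is None else var_vals  # one normalization point
  # ... every later consumer (timing, op estimates) can assume a real dict
  return var_vals

print(run())           # {}
print(run({"n": 4}))   # {'n': 4}
```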
@@ -3,10 +3,9 @@ from collections import defaultdict, deque
 from dataclasses import dataclass
 from typing import Callable, Tuple, List, Dict, Optional, DefaultDict, cast
 from tinygrad.ops import REDUCE_ALU, UNSAFE_PAD_OPS, MetaOps, ReduceOps, TernaryOps, UnaryOps, UOp, UOps, PatternMatcher, UPat, resolve, \
-  graph_rewrite, track_rewrites
+  graph_rewrite, track_rewrites, Variable, sint
 from tinygrad.helpers import GRAPH, DEBUG, MULTIOUTPUT, SAVE_SCHEDULE, FUSE_CONV_BW, FUSE_ARANGE, GlobalCounters, Metadata, all_same, \
   colored, diskcache_put, prod, dedup, all_int, merge_dicts, getenv, unwrap
-from tinygrad.shape.symbolic import Variable, sint
 from tinygrad.dtype import ImageDType, dtypes
 from tinygrad.engine.lazy import LazyBuffer
 from tinygrad.shape.shapetracker import ShapeTracker
@@ -162,20 +161,19 @@ def _recursive_uop(buf:LazyBuffer, st:ShapeTracker, outputs:Tuple[LazyBuffer, ..
     if buf not in assign_targets and buf not in inputs: inputs.append(buf)
     return UOp(UOps.LOAD, dtype, (ubuf, unbound_st.to_uop()))

-  # reduce ops change ShapeTracker
-  if buf.op in ReduceOps:
-    rsrc = _recursive_uop(buf.srcs[0], ShapeTracker.from_shape(buf.srcs[0].shape), outputs, var_vals, inputs, buf_uops, assign_targets, cache)
-    return cache.setdefault((buf, st), UOp(UOps.REDUCE_AXIS, dtype, (rsrc,), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st))
-
-  # elementwise ops pass shapetracker
-  in_uops = tuple(_recursive_uop(x, st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs)
-  if buf.op is MetaOps.CONTIGUOUS:
+  # only reduceop changes shape
+  src_st = ShapeTracker.from_shape(buf.srcs[0].shape) if buf.op in ReduceOps else st
+  src: List[UOp] = [_recursive_uop(x, src_st, outputs, var_vals, inputs, buf_uops, assign_targets, cache) for x in buf.srcs]
+  if buf.op in ReduceOps: ret = UOp(UOps.REDUCE_AXIS, dtype, tuple(src), (REDUCE_ALU[cast(ReduceOps, buf.op)], buf.arg)).view(st)
+  elif buf.op is MetaOps.CONTIGUOUS:
     assert buf in outputs, f"{buf.op} must be writable"
-    return in_uops[0]
-  if buf.op is MetaOps.ASSIGN: return cache.setdefault((buf, st), UOp(UOps.ASSIGN, dtype, (in_uops[1].src[0], in_uops[0])))
-  if buf.op is UnaryOps.CAST: return cache.setdefault((buf, st), UOp(UOps.CAST, dtype, in_uops))
-  if buf.op is UnaryOps.BITCAST: return cache.setdefault((buf, st), UOp(UOps.BITCAST, dtype, in_uops))
-  return cache.setdefault((buf, st), UOp(UOps.ALU, dtype, in_uops, buf.op))
+    ret = src[0]
+  elif buf.op is MetaOps.ASSIGN: ret = UOp(UOps.ASSIGN, dtype, (src[1].src[0], src[0]))
+  elif buf.op is UnaryOps.CAST: ret = src[0].cast(dtype)
+  elif buf.op is UnaryOps.BITCAST: ret = src[0].bitcast(dtype)
+  else: ret = UOp(UOps.ALU, dtype, tuple(src), buf.op)
+  cache[(buf, st)] = ret
+  return ret

 def _lower_lazybuffer(outs:List[LazyBuffer], buf_uops:Dict[Buffer, UOp]) -> Tuple[LBScheduleItem, Dict[Variable, int]]:
   """describe the computation for a LazyBuffer with UOp + inputs + var_vals"""
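The refactor collapses the per-op early returns into one `if/elif` chain whose single exit writes the memo cache, so every lowered UOp, not just some branches, is recorded under `(buf, st)`. The shape of that pattern (a generic sketch, names hypothetical):

```python
from typing import Dict, Tuple

def lower(key: Tuple[str, int], cache: Dict[Tuple[str, int], str]) -> str:
  if key in cache: return cache[key]
  # ... pick ret in one of several branches ...
  ret = f"lowered-{key[0]}"
  cache[key] = ret   # single exit point records the memo for every branch
  return ret

c: Dict[Tuple[str, int], str] = {}
print(lower(("relu", 0), c), c)
```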
@@ -2,14 +2,12 @@ from typing import Dict, List, cast, DefaultDict, Optional, Tuple, Callable
 import itertools, functools, random, math, time, multiprocessing, traceback, signal
 from collections import defaultdict
 from dataclasses import replace
-from tinygrad.ops import UOp, UOps
+from tinygrad.ops import UOp, UOps, Variable, sym_infer
 from tinygrad.device import Device, Buffer, Compiler
 from tinygrad.helpers import prod, flatten, DEBUG, CACHELEVEL, diskcache_get, diskcache_put, getenv, Context, colored, to_function_name
 from tinygrad.dtype import ImageDType
-from tinygrad.codegen.kernel import Kernel
-from tinygrad.codegen.kernel import Opt, OptOps, KernelOptError
+from tinygrad.codegen.kernel import Kernel, Opt, OptOps, KernelOptError
 from tinygrad.tensor import Tensor
-from tinygrad.shape.symbolic import Variable, sym_infer
 from tinygrad.engine.realize import CompiledRunner
 from tinygrad.renderer import Program
@@ -3,10 +3,9 @@ import math
 from typing import Tuple, Optional
 from tinygrad.helpers import argsort
 from tinygrad.dtype import dtypes, DType, sum_acc_dtype
-from tinygrad.ops import ReduceOps, resolve
+from tinygrad.ops import ReduceOps, resolve, sint
 from tinygrad.tensor import Function
 from tinygrad.engine.lazy import LazyBuffer
-from tinygrad.shape.symbolic import sint

 class Contiguous(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer: return x.contiguous()
@@ -40,7 +39,6 @@ class Sin(Function):
   def backward(self, grad_output:LazyBuffer) -> LazyBuffer: return (math.pi/2 - self.x).sin() * grad_output

-# NOTE: maximum(x, 0) behaves differently where x=0
 class Relu(Function):
   def forward(self, x:LazyBuffer) -> LazyBuffer:
     self.ret = x.max(0)
@@ -1,5 +1,6 @@
+from __future__ import annotations
 import math
-from typing import Optional, Union, Tuple
+from typing import Optional, Union, Tuple, List
 from tinygrad.tensor import Tensor
 from tinygrad.helpers import prod, make_pair
 from tinygrad.nn import optim, state, datasets # noqa: F401
@@ -39,7 +40,7 @@ class BatchNorm:
     if track_running_stats: self.running_mean, self.running_var = Tensor.zeros(sz, requires_grad=False), Tensor.ones(sz, requires_grad=False)

   def calc_stats(self, x:Tensor) -> Tuple[Tensor, Tensor]:
-    shape_mask = [1, -1, *([1]*(x.ndim-2))]
+    shape_mask: List[int] = [1, -1, *([1]*(x.ndim-2))]
     if self.track_running_stats and not Tensor.training: return self.running_mean, self.running_var.reshape(shape=shape_mask).expand(x.shape)
     # This requires two full memory accesses to x
     # https://github.com/pytorch/pytorch/blob/c618dc13d2aa23625cb0d7ada694137532a4fa33/aten/src/ATen/native/cuda/Normalization.cuh
@@ -49,7 +50,7 @@ class BatchNorm:
     batch_var = (y*y).mean(axis=reduce_axes)
     return batch_mean, batch_var

-  def __call__(self, x:Tensor):
+  def __call__(self, x:Tensor) -> Tensor:
     batch_mean, batch_var = self.calc_stats(x)
     # NOTE: wow, this is done all throughout training in most PyTorch models
     if self.track_running_stats and Tensor.training:
@@ -59,7 +60,7 @@ class BatchNorm:
     return x.batchnorm(self.weight, self.bias, batch_mean, batch_var.add(self.eps).rsqrt())
 BatchNorm2d = BatchNorm3d = BatchNorm

-def Conv1d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
+def Conv1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding:Union[int, str]=0, dilation=1, groups=1, bias=True) -> Conv2d:
   """
   Applies a 1D convolution over an input signal composed of several input planes.
@@ -93,22 +94,24 @@ class Conv2d:
|
||||
print(t.numpy())
|
||||
```
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
|
||||
self.kernel_size = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding:Union[int, str]=0,
|
||||
dilation=1, groups=1, bias=True):
|
||||
self.kernel_size: Tuple[int, ...] = (kernel_size, kernel_size) if isinstance(kernel_size, int) else tuple(kernel_size)
|
||||
if isinstance(padding, str):
|
||||
if padding.lower() != 'same': raise ValueError(f"Invalid padding string {padding!r}, only 'same' is supported")
|
||||
if stride != 1: raise ValueError("padding='same' is not supported for strided convolutions")
|
||||
self.padding = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)]
|
||||
self.padding: Union[int, List[int]] = [p for d,k in zip(make_pair(dilation,len(self.kernel_size)), self.kernel_size[::-1]) for p in (d*(k-1)//2, d*(k-1) - d*(k-1)//2)] #noqa:E501
|
||||
else: self.padding = padding
|
||||
self.stride, self.dilation, self.groups = stride, dilation, groups
|
||||
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
|
||||
self.weight = Tensor.uniform(out_channels, in_channels//groups, *self.kernel_size, low=-scale, high=scale)
|
||||
self.bias = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
|
||||
self.bias: Optional[Tensor] = Tensor.uniform(out_channels, low=-scale, high=scale) if bias else None
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
def __call__(self, x:Tensor) -> Tensor:
|
||||
return x.conv2d(self.weight, self.bias, padding=self.padding, stride=self.stride, dilation=self.dilation, groups=self.groups)
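As a quick, hedged illustration of the padding='same' path added above (assuming the tinygrad API at this commit, where Conv1d builds a Conv2d under the hood), a stride-1 convolution should preserve the spatial extent:

```python
from tinygrad import Tensor
from tinygrad.nn import Conv1d

# padding='same' pre-computes per-edge pads from dilation and kernel size,
# so a stride-1 conv keeps the spatial size: (1, 4, 32) -> (1, 8, 32)
c = Conv1d(4, 8, kernel_size=3, padding='same')
print(c(Tensor.randn(1, 4, 32)).shape)  # expected: (1, 8, 32)
```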

def ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
def ConvTranspose1d(in_channels:int, out_channels:int, kernel_size:int, stride=1, padding=0, output_padding=0, dilation=1,
groups=1, bias=True) -> ConvTranspose2d:
"""
Applies a 1D transposed convolution operator over an input signal composed of several input planes.

@@ -142,13 +145,14 @@ class ConvTranspose2d(Conv2d):
print(t.numpy())
```
"""
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, output_padding=0, dilation=1, groups=1, bias=True):
def __init__(self, in_channels:int, out_channels:int, kernel_size:Union[int, Tuple[int, ...]], stride=1, padding=0, output_padding=0,
dilation=1, groups=1, bias=True):
super().__init__(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias)
scale = 1 / math.sqrt(in_channels * prod(self.kernel_size))
self.weight = Tensor.uniform(in_channels, out_channels//groups, *self.kernel_size, low=-scale, high=scale)
self.output_padding = output_padding

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
return x.conv_transpose2d(self.weight, self.bias, padding=self.padding, output_padding=self.output_padding, stride=self.stride,
dilation=self.dilation, groups=self.groups)

@@ -168,12 +172,12 @@ class Linear:
print(t.numpy())
```
"""
def __init__(self, in_features, out_features, bias=True):
def __init__(self, in_features:int, out_features:int, bias=True):
bound = 1 / math.sqrt(in_features)
self.weight = Tensor.uniform(out_features, in_features, low=-bound, high=bound)
self.bias = Tensor.uniform(out_features, low=-bound, high=bound) if bias else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
return x.linear(self.weight.transpose(), self.bias)

class GroupNorm:

@@ -193,12 +197,12 @@ class GroupNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_groups:int, num_channels:int, eps:float=1e-5, affine:bool=True):
def __init__(self, num_groups:int, num_channels:int, eps=1e-5, affine=True):
self.num_groups, self.num_channels, self.eps = num_groups, num_channels, eps
self.weight: Optional[Tensor] = Tensor.ones(num_channels) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_channels) if affine else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
# reshape for layernorm to work as group norm
# subtract mean and divide stddev
x = x.reshape(x.shape[0], self.num_groups, -1).layernorm(eps=self.eps).reshape(x.shape)

@@ -224,12 +228,12 @@ class InstanceNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, num_features:int, eps:float=1e-5, affine:bool=True):
def __init__(self, num_features:int, eps=1e-5, affine=True):
self.num_features, self.eps = num_features, eps
self.weight: Optional[Tensor] = Tensor.ones(num_features) if affine else None
self.bias: Optional[Tensor] = Tensor.zeros(num_features) if affine else None

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
x = x.reshape(x.shape[0], self.num_features, -1).layernorm(eps=self.eps).reshape(x.shape)
if self.weight is None or self.bias is None: return x
return x * self.weight.reshape(1, -1, *[1] * (len(x.shape)-2)) + self.bias.reshape(1, -1, *[1] * (len(x.shape)-2))

@@ -251,12 +255,12 @@ class LayerNorm:
print(t.mean().item(), t.std().item())
```
"""
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps:float=1e-5, elementwise_affine:bool=True):
self.normalized_shape = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
def __init__(self, normalized_shape:Union[int, Tuple[int, ...]], eps=1e-5, elementwise_affine=True):
self.normalized_shape: Tuple[int, ...] = (normalized_shape,) if isinstance(normalized_shape, int) else tuple(normalized_shape)
self.axis, self.eps, self.elementwise_affine = tuple(-1-i for i in range(len(self.normalized_shape))), eps, elementwise_affine
self.weight, self.bias = (Tensor.ones(*self.normalized_shape), Tensor.zeros(*self.normalized_shape)) if elementwise_affine else (None, None)

def __call__(self, x:Tensor):
def __call__(self, x:Tensor) -> Tensor:
assert self.normalized_shape == x.shape[-len(self.normalized_shape):], f"last dimensions of {x.shape} must match {self.normalized_shape}"
x = x.layernorm(eps=self.eps, axis=self.axis)
if not self.elementwise_affine: return x

@@ -278,7 +282,7 @@ class LayerNorm2d(LayerNorm):
print(t.mean().item(), t.std().item())
```
"""
def __call__(self, x): return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
def __call__(self, x: Tensor) -> Tensor: return super().__call__(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

class RMSNorm:
"""

@@ -296,9 +300,9 @@ class RMSNorm:
print(norm(t).numpy())
```
"""
def __init__(self, dim, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)
def __init__(self, dim:int, eps=1e-6): self.eps, self.weight = eps, Tensor.ones(dim)

def _norm(self, x:Tensor): return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()
def _norm(self, x:Tensor) -> Tensor: return x * (x.square().mean(-1, keepdim=True) + self.eps).rsqrt()

def __call__(self, x:Tensor) -> Tensor: return self._norm(x.float()).cast(x.dtype) * self.weight

tinygrad/ops.py

@@ -7,7 +7,6 @@ from weakref import WeakValueDictionary
from tinygrad.dtype import ConstType, ImageDType, PtrDType, dtypes, DType, truncate
from tinygrad.helpers import ContextVar, prod, getenv, all_same, Context
if TYPE_CHECKING:
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.shapetracker import ShapeTracker

# wrapper around IntEnum that preserves Enum.__str__ and makes auto() unique across all FastEnum subclasses

@@ -154,7 +153,6 @@ COMMUTATIVE = {BinaryOps.ADD, BinaryOps.MUL, BinaryOps.MAX, BinaryOps.CMPNE, Bin
END_FOR_UOP = {UOps.IF:(UOps.STORE, UOps.ENDIF), UOps.RANGE:(UOps.ASSIGN, UOps.ENDRANGE)}

# With True as the default, this matches the old symbolic behavior
# python3 -c 'from tinygrad.shape.symbolic import Variable; print(bool(Variable("a", 1, 10) < 4))' -> True
def resolve(x, default:bool=True):
if not isinstance(x, UOp): return bool(x)
assert x.dtype is dtypes.bool, "UOp in resolve must be bool"
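A hedged sketch of resolve()'s contract, using the UOp.variable constructor introduced later in this diff: a comparison provable from the variable's min/max folds to a real answer, while an undecidable one falls back to the default (True, matching the old symbolic behavior):

```python
from tinygrad.ops import UOp, resolve

a = UOp.variable("a", 1, 10)         # DEFINE_VAR with range [1, 10]
print(resolve(a < 20))               # expected True: provable, vmax(a) = 10 < 20
print(resolve(a < 4, default=True))  # expected True via default; 1-10 straddles 4
```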

@@ -185,6 +183,7 @@ class UOp(MathTrait):
__slots__ = ["op", "dtype", "src", "arg"]
def __init__(self, op: UOps, dtype:DType=dtypes.void, src: Tuple[UOp,...]=tuple(), arg:Any=None):
# TODO: instant check rules here make debugging easier
#assert op in UOps and isinstance(dtype, DType), f"bad UOp creation with {op} {dtype}"
#if op is UOps.ALU and arg is BinaryOps.CMPNE: assert dtype.scalar() == dtypes.bool
#if op is UOps.VECTORIZE and dtype != dtypes.void: assert len(src) == dtype.count, f"{len(src)} invalid for {dtype}"
#if op is UOps.ALU and arg not in (BinaryOps.CMPNE, BinaryOps.CMPLT, TernaryOps.WHERE): assert all_same([dtype] + [x.dtype for x in src])

@@ -195,6 +194,18 @@ class UOp(MathTrait):
new_args = (kwargs.get("op", self.op), kwargs.get("dtype", self.dtype), kwargs.get("src", self.src), kwargs.get("arg", self.arg))
if (self.op, self.dtype, self.src, self.arg) == new_args: return self
return UOp(*new_args)
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}

# *** uop shape stuff ***

@property
def has_st(self) -> bool: return self.op not in {UOps.DEFINE_LOCAL, UOps.DEFINE_GLOBAL, UOps.BUFFER, UOps.CONST, UOps.DEFINE_VAR}
@functools.cached_property

@@ -207,11 +218,11 @@ class UOp(MathTrait):
from tinygrad.shape.shapetracker import ShapeTracker
return ShapeTracker.from_shape(src_sts[0].reduce(self.axis_arg)) if self.op is UOps.REDUCE_AXIS else src_sts[0]
@functools.cached_property
def key(self) -> bytes:
return hashlib.sha256(str((self.op, self.dtype, self.arg)).encode() + b"".join([s.key for s in self.src])).digest()
def __repr__(self): return pretty_print(self, lambda x: f"{type(self).__name__}({x.op}, {x.dtype}, arg={x.argstr()}, src=(%s))")
def argstr(self): return f'({", ".join(map(str, self.arg))})' if self.op is UOps.REDUCE_AXIS else self.arg
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))

# *** uop evaluation ***

def simplify(self):
with Context(TRACK_MATCH_STATS=0):
return graph_rewrite(self, symbolic)

@@ -226,7 +237,10 @@ class UOp(MathTrait):
def __bool__(self): return self._eval((dtypes.bool,), bool)
def __int__(self): return self._eval(dtypes.ints, int)
def __float__(self): return self._eval(dtypes.floats, float)
# *** uop syntactic sugar
def substitute(self, dvars:Dict[UOp, UOp]): return graph_rewrite(self, _substitute, dvars)

# *** uop syntactic sugar ***

@property
def st_arg(self) -> ShapeTracker:
assert self.op in BUFFER_UOPS, f"st_arg called on {self.op}"

@@ -268,37 +282,34 @@ class UOp(MathTrait):
out_dtype = dtypes.bool.vec(out_dtype.count) if out_dtype.count > 1 else dtypes.bool
return UOp(UOps.ALU, out_dtype, (self,)+src, arg)
@staticmethod
def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable): return UOp._const(dtype, b)
@staticmethod
def _const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
# TODO: fix dtype of b.max after Variable is just an UOp
#if isinstance(b, Variable): return UOp.define_var(b.expr, dtype, b.min, cast(int, b.max))
def const(dtype:DType, b:Tuple[ConstType, ...]|ConstType|Variable):
if isinstance(b, UOp): return b.unbind()[0] if b.op is UOps.BIND else b
if isinstance(b, tuple) and all_same(b): b = b[0] # doesn't have to be a VCONST if they are all the same
return UOp(UOps.VCONST if isinstance(b, tuple) else UOps.CONST, dtype, arg=dtypes.as_const(b, dtype) if dtype is not None else b) # type: ignore
@staticmethod
def define_var(name:str, dtype:DType, min_val:ConstType, max_val:ConstType): return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
def range(dtype:DType, start:ConstType|UOp, end:ConstType|UOp, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start) if not isinstance(start, UOp) else start,
UOp.const(dtype, end) if not isinstance(end, UOp) else end), arg=idx)
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)

# *** uop Variable stuff ***

@staticmethod
def variable(name:str, min_val:ConstType, max_val:ConstType, dtype:DType=dtypes.int):
return UOp(UOps.DEFINE_VAR, dtype, arg=(name, min_val, max_val))
@property
def expr(self):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
return self.arg[0]
def bind(self, val:int):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}, need DEFINE_VAR"
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
def unbind(self) -> Tuple[Variable, int]:
assert self.op is UOps.BIND and self.src[0].op is UOps.DEFINE_VAR and self.src[1].op is UOps.CONST, f"can't unbind {self}"
from tinygrad.shape.symbolic import Variable
return cast(Variable, self.src[0]), self.src[1].arg
return self.src[0], self.src[1].arg
@property
def val(self) -> int: return self.unbind()[1]
# TODO: this is context rewrite
def substitute(self, dvars:Dict[UOp, UOp]):
if self in dvars: return dvars[self]
return self.replace(src=tuple(x.substitute(dvars) for x in self.src))
@staticmethod
def range(dtype:DType, start:ConstType, end:ConstType, idx:int):
return UOp(UOps.RANGE, dtype=dtype, src=(UOp.const(dtype, start), UOp.const(dtype, end)), arg=(idx,))
def reduce(self, op:BinaryOps, *rng:UOp): return UOp(UOps.REDUCE, self.dtype, (self,) + rng, op)
@functools.cached_property
def parents(self) -> Dict[UOp, None]: return {**{x:None for x in self.src}, **{k:None for x in self.src for k in x.parents}}
@property # parents with self
def sparents(self) -> Dict[UOp, None]: return {**self.parents, self:None}
@functools.cached_property
def full_shape(self) -> Tuple[sint, ...]:
return self.arg.shape if self.op is UOps.VIEW else tuple(smax(x) for x in zip(*[x.full_shape for x in self.src if x.has_st]))
def vars(self) -> Set[UOp]:
bound_vars = set([x for x in self.sparents if x.op is UOps.BIND and x.src[0].op is UOps.DEFINE_VAR])
bound_var_base = set(x.src[0] for x in bound_vars)
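The DEFINE_VAR/BIND lifecycle above replaces the old Variable class; a minimal hedged sketch of how the pieces compose under the new API:

```python
from tinygrad.ops import UOp

v = UOp.variable("n", 0, 100)  # DEFINE_VAR uop (dtypes.int by default)
b = v.bind(7)                  # BIND(DEFINE_VAR, CONST 7); 7 must lie in [0, 100]
var, val = b.unbind()          # recovers the DEFINE_VAR and the bound value
assert val == 7 and b.val == 7
```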

@@ -306,8 +317,10 @@ class UOp(MathTrait):
return bound_vars.union(set([x for x in all_vars if x not in bound_var_base]))
def variables(self) -> List[Variable]:
st_vars: List[Set[Variable]] = [x.st_arg.vars() for x in self.sparents if x.op in BUFFER_UOPS]
from tinygrad.shape.symbolic import Variable
return sorted(set.union(*st_vars, [x.unbind()[0] if not isinstance(x, Variable) else x for x in self.vars()]), key=lambda v: v.arg)
return sorted(set.union(*st_vars, [x.unbind()[0] if x.op is not UOps.DEFINE_VAR else x for x in self.vars()]), key=lambda v: v.arg)

# *** uop symbolic stuff ***

def const_factor(self) -> int:
"""largest known int that divides self"""
if self.op is UOps.CONST: return self.arg

@@ -341,17 +354,20 @@ class UOp(MathTrait):
if self.op is UOps.SPECIAL: return 0, self.arg[1]-1 if isinstance(self.arg[1], int) else dtypes.max(self.dtype)
if self.op is UOps.CONST: return self.arg, self.arg
if self.op is UOps.VCONST: return (min(self.arg), max(self.arg))
if self.op is UOps.ALU and self.dtype.count == 1:
if self.op is UOps.ALU:
s0,s1,s2 = [cast(UOp, self.src[i] if i < len(self.src) else None) for i in range(3)]
if self.arg is BinaryOps.ADD: return s0.vmin+s1.vmin, s0.vmax+s1.vmax
if self.arg is BinaryOps.MUL:
# both are non-positive
if (s0.vmax <= 0 and s1.vmax <= 0): return s0.vmax*s1.vmax, s0.vmin*s1.vmin
# at lease one is non-negative
# at least one is non-negative
if (s0.vmin >= 0 or s1.vmin >= 0):
Lmin, Lmax = (s0.vmin, s0.vmax) if s1.vmin >= 0 else (s0.vmax, s0.vmin)
Rmin, Rmax = (s1.vmin, s1.vmax) if s0.vmin >= 0 else (s1.vmax, s1.vmin)
return Lmin*Rmin, Lmax*Rmax
# arbitrary
products = [s0.vmin * s1.vmin, s0.vmin * s1.vmax, s0.vmax * s1.vmin, s0.vmax * s1.vmax]
return min(products), max(products)
if self.arg is BinaryOps.MOD and s1.vmin > 0: return 0, s1.vmax-1
if self.arg is BinaryOps.IDIV and s1.op is UOps.CONST:
if s1.arg > 0: return s0.vmin//s1.arg, s0.vmax//s1.arg
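The MUL branches above are sign-aware refinements of standard interval arithmetic; the "arbitrary" fallback is exactly the four-product rule, sketched standalone here (illustrative, not library code):

```python
# bounds of x*y given x in [lo0, hi0] and y in [lo1, hi1]
def mul_bounds(lo0, hi0, lo1, hi1):
    products = [lo0*lo1, lo0*hi1, hi0*lo1, hi0*hi1]
    return min(products), max(products)

assert mul_bounds(-2, 3, -5, 4) == (-15, 12)  # extremes come from corner products
```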

@@ -402,13 +418,6 @@ def exec_alu(op:Op, dtype:DType, operands):
return tuple([exec_alu(op, dtype.scalar(), [x[i] if isinstance(x, tuple) else x for x in operands]) for i in range(dtype.count)])
return truncate.get(dtype, lambda x: x)(python_alu[op](*operands))

def uop_alu_resolve(u:UOp) -> sint:
if u.op is UOps.CONST: return u.arg
if u.op is UOps.DEFINE_VAR: return u
#if u.op is UOps.DEFINE_VAR: return Variable(u.arg[0], u.arg[1], u.arg[2])
if u.op is UOps.ALU: return exec_alu(u.arg, u.dtype, tuple(map(uop_alu_resolve, u.src)))
raise RuntimeError(f"ALU resolve fail @ {u.op}")

# ***** uop helpers *****

def print_uops(uops:List[UOp]):

@@ -435,7 +444,7 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
for u in uops:
if u.op is UOps.RANGE:
mult_stack.append(mults)
mults *= uop_alu_resolve(u.src[1] - u.src[0])
mults *= (u.src[1] - u.src[0]).ssimplify()
elif u.op is UOps.ENDRANGE:
mults = mult_stack.pop(-1)
elif u.op is UOps.SPECIAL:

@@ -683,9 +692,6 @@ spec = PatternMatcher([
(UPat(UOps.RANGE, src=(UPat(name="x"), UPat(name="y")), name="rng"), lambda rng,x,y: rng.dtype == x.dtype == y.dtype),
(UPat(UOps.SPECIAL, src=()), lambda: True),

# no pyint allowed here!
(UPat(UOps.ALU, dtype=dtypes.pyint), lambda: False),

# TODO: confirm the args of both of these are shapetrackers
(UPat(UOps.VIEW, src=()), lambda: True),
(UPat(UOps.VIEW, src=(UPat(),)), lambda: True),

@@ -953,6 +959,8 @@ symbolic_flat = symbolic+PatternMatcher([
((UPat.var("x", dtypes.ints) + UPat.var("y")) * UPat.cvar("c"), lambda x,y,c: x*c+y*c),
])

_substitute = PatternMatcher([(UPat(tuple(UOps), name="x"), lambda ctx,x: ctx.get(x,None))])

# for debug
renderer = PatternMatcher([
(UPat(UOps.DEFINE_VAR, name="x"), lambda x: UOp(UOps.NOOP, arg=x.arg[0])),

@@ -965,3 +973,12 @@ renderer = PatternMatcher([
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPLT, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}<{x.src[1].arg})")),
(UPat(UOps.ALU, src=UPat(UOps.NOOP), arg=BinaryOps.CMPNE, name="x"), lambda x: UOp(UOps.NOOP, arg=f"({x.src[0].arg}!={x.src[1].arg})")),
])

# *** what was symbolic.py ***

sint = Union[int, UOp]
Variable = UOp

def NumNode(val:int): return UOp.const(dtypes.int, val)
def sym_infer(uop: Union[UOp, int], var_vals: Dict[UOp, int]) -> int:
return int(uop.substitute({k:k.const_like(v) for k,v in var_vals.items()})) if isinstance(uop, UOp) else uop
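With symbolic.py folded into ops.py, Variable is now literally an alias for UOp and sym_infer evaluates by substituting constants and folding. A minimal hedged usage sketch:

```python
from tinygrad.ops import UOp, sym_infer

i = UOp.variable("i", 1, 10)   # DEFINE_VAR
expr = i * 2 + 1               # MathTrait operators build the ALU tree
print(sym_infer(expr, {i: 4})) # substitutes i=4, folds to 9
```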

@@ -2,8 +2,7 @@ from typing import Optional, List, Tuple, Dict, Callable, Any
import functools
from dataclasses import dataclass, field
from tinygrad.helpers import to_function_name, dedup, prod
from tinygrad.ops import Op, UOps, UOp, flops_mem
from tinygrad.shape.symbolic import sym_infer, sint, Variable
from tinygrad.ops import Op, UOps, UOp, flops_mem, sym_infer, sint, Variable
from tinygrad.dtype import DType

@dataclass(frozen=True)

@@ -43,7 +42,7 @@ class Program:
if not self._ran_post_init and self.uops is not None:
# single pass through the uops
for u in self.uops:
if u.op is UOps.DEFINE_VAR: self.vars.append(Variable(u.arg[0], u.arg[1], u.arg[2]))
if u.op is UOps.DEFINE_VAR: self.vars.append(u)
if u.op is UOps.DEFINE_GLOBAL: self.globals.append(u.arg)
if u.op is UOps.STORE: self.outs.extend([x.arg for x in u.src[0].sparents if x.op is UOps.DEFINE_GLOBAL])
if u.op is UOps.SPECIAL:

@@ -294,7 +294,7 @@ class MetalRenderer(CStyleLanguage):
return {arg[3].name}2(c.thread_elements()[0], c.thread_elements()[1]);\n}}""")
return super().render_kernel(function_name, kernel, bufs, uops, prefix)

code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"1/{x}",
code_for_op_half = {UnaryOps.RECIP: lambda x,dtype: f"hrcp({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"(1/{x})",
BinaryOps.MAX: lambda a,b,dtype: f"__hmax({a},{b})" if dtype in (dtypes.half, dtypes.bfloat16) else f"max({a},{b})",
UnaryOps.SQRT: lambda x,dtype: f"hsqrt({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sqrt({x})",
UnaryOps.SIN: lambda x,dtype: f"hsin({x})" if dtype in (dtypes.half, dtypes.bfloat16) else f"sin({x})",

@@ -369,16 +369,6 @@ code_for_op_hip = { UnaryOps.SQRT: lambda x,dtype: f"__ocml_sqrt_f{ {dtypes.half
# TODO: MAX with int uses fmax_f32?
BinaryOps.MAX: lambda a,b,dtype: f"__ocml_fmax_f{ {dtypes.half:16, dtypes.double:64}.get(dtype, 32) }({a},{b})",}

def _make_hip_code_for_op():
def wrapper(key, func):
def cast_bf16(*args):
if args[-1] == dtypes.bfloat16:
operands = tuple(f"(float)({arg})" for arg in (args[1:-1] if key is TernaryOps.WHERE else args[:-1]))
return f"(hip_bfloat16)({func(*(((args[0],) if key is TernaryOps.WHERE else ()) + operands), dtypes.float)})"
return func(*args)
return cast_bf16
return { k:wrapper(k,v) for k,v in {**CStyleLanguage.code_for_op, **code_for_op_hip}.items() }

class AMDRenderer(CStyleLanguage):
device = "AMD"
shared_max = 65536

@@ -397,13 +387,18 @@ class AMDRenderer(CStyleLanguage):
kernel_prefix += '\nextern "C" __attribute__((global))'
code_for_workitem = {"g": lambda x: f"__ockl_get_group_id({x})", "l": lambda x: f"__ockl_get_local_id({x})",
"i": lambda x: f"(__ockl_get_group_id({x})*__ockl_get_local_size({x})+__ockl_get_local_id({x}))"}
code_for_op = _make_hip_code_for_op()
code_for_op = { **CStyleLanguage.code_for_op, **code_for_op_hip }
smem_prefix = "__attribute__((shared))"
barrier = '__builtin_amdgcn_fence(__ATOMIC_RELEASE, "workgroup");' + '__builtin_amdgcn_s_barrier();' + \
'__builtin_amdgcn_fence(__ATOMIC_ACQUIRE, "workgroup");'
float4 = "make_float4"
uses_ptr_arithmetic = False # NOTE: this fixes TestLinearizerOverflowAlt
type_map = {dtypes.bfloat16: "hip_bfloat16"}
extra_matcher = PatternMatcher([
(UPat(UOps.ALU, arg=TernaryOps.WHERE, src=(UPat.var("b"), UPat.var("x", dtype=dtypes.bfloat16), UPat.var("y", dtype=dtypes.bfloat16))),
lambda b,x,y: UOp(UOps.ALU, arg=TernaryOps.WHERE, dtype=dtypes.float, src=(b,x.cast(dtypes.float),y.cast(dtypes.float))).cast(dtypes.bfloat16)),
*[(UPat(UOps.ALU, dtype=dtypes.bfloat16, name="x"),
lambda x: (UOp(x.op, dtypes.float, tuple(vv.cast(dtypes.float) for vv in x.src), x.arg).cast(dtypes.bfloat16)))]]) + extra_pm

def render_vector_prefix(self, dtype:DType) -> str:
vec, scal = self.render_dtype(dtype), self.render_dtype(dtype.scalar())

@@ -4,7 +4,7 @@ from tinygrad.helpers import dedup, cpu_time_execution, DEBUG
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.device import Buffer, Device
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.shape.symbolic import Variable
from tinygrad.ops import Variable
from tinygrad.runtime.ops_clang import ClangProgram
from tinygrad.renderer.cstyle import ClangRenderer
render_dtype = ClangRenderer().render_dtype

@@ -4,7 +4,7 @@ import tinygrad.runtime.autogen.cuda as cuda
from tinygrad.helpers import init_c_var, dedup
from tinygrad.device import Buffer, Device
from tinygrad.runtime.ops_cuda import CUDADevice, check, encode_args, cu_time_execution
from tinygrad.shape.symbolic import Variable
from tinygrad.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner, GraphException

@@ -3,7 +3,7 @@ from typing import List, Any, Dict, cast, Optional, Tuple, Set
from tinygrad.helpers import round_up, PROFILE, memsize_to_str
from tinygrad.runtime.support.hcq import HCQCompiled, HCQAllocator, HCQSignal, HCQBuffer, HWCommandQueue, HWComputeQueue, HWCopyQueue, HCQArgsState
from tinygrad.device import Buffer, BufferOptions, Compiled, Device
from tinygrad.shape.symbolic import Variable
from tinygrad.ops import Variable
from tinygrad.engine.realize import ExecItem, BufferXfer, CompiledRunner
from tinygrad.engine.jit import MultiGraphRunner

@@ -5,7 +5,7 @@ from tinygrad.helpers import dedup, getenv
from tinygrad.device import Buffer
from tinygrad.engine.realize import ExecItem, CompiledRunner
from tinygrad.engine.jit import GraphRunner, GraphException
from tinygrad.shape.symbolic import Variable
from tinygrad.ops import Variable
from tinygrad.runtime.ops_metal import wait_check, msg, libobjc, to_struct, objc_instance,\
MTLResourceOptions, elapsed_time, objc_id

@@ -281,8 +281,11 @@ class QCOMProgram(HCQProgram):
if hasattr(self, 'lib_gpu'): self.device.allocator.free(self.lib_gpu, self.lib_gpu.size, options=BufferOptions(cpu_access=True, nolru=True))

class QCOMBuffer(HCQBuffer):
def __init__(self, va_addr:int, size:int, desc=None, ibo=None, pitch=None, real_stride=None):
self.va_addr, self.size, self.desc, self.ibo, self.pitch, self.real_stride = va_addr, size, desc, ibo, pitch, real_stride
def __init__(self, va_addr:int, size:int, info=None, mapped=False, desc=None, ibo=None, pitch=None, real_stride=None, **kwargs):
self.va_addr, self.size, self.info, self.mapped = va_addr, size, info, mapped

# Texture specific definitions
self.desc, self.ibo, self.pitch, self.real_stride = [0] * 16, [0] * 16, pitch, real_stride

class QCOMAllocator(HCQAllocator):
def _alloc(self, size:int, options:BufferOptions) -> HCQBuffer:

@@ -298,8 +301,7 @@ class QCOMAllocator(HCQAllocator):
if options.external_ptr: texture = QCOMBuffer(options.external_ptr, size)
else: texture = self.device._gpu_alloc(pitch * imgh, kgsl.KGSL_MEMTYPE_TEXTURE, map_to_cpu=True)

# Extend HCQBuffer with texture-related info.
texture.pitch, texture.real_stride, texture.desc, texture.ibo = pitch, real_stride, [0] * 16, [0] * 16
texture.pitch, texture.real_stride = pitch, real_stride

tex_fmt = adreno.FMT6_32_32_32_32_FLOAT if options.image.itemsize == 4 else adreno.FMT6_16_16_16_16_FLOAT
texture.desc[0] = qreg.a6xx_tex_const_0(swiz_x=0, swiz_y=1, swiz_z=2, swiz_w=3, fmt=tex_fmt)

@@ -318,12 +320,12 @@ class QCOMAllocator(HCQAllocator):
src_off, dest_off = src_off+src_stride, dest_off+dest_stride

def copyin(self, dest:HCQBuffer, src:memoryview):
if hasattr(qd:=cast(QCOMBuffer, dest), 'pitch'): self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
if (qd:=cast(QCOMBuffer, dest)).pitch is not None: self._do_copy(mv_address(src), qd.va_addr, len(src), qd.real_stride, qd.real_stride, qd.pitch)
else: ctypes.memmove(dest.va_addr, mv_address(src), src.nbytes)

def copyout(self, dest:memoryview, src:HCQBuffer):
self.device.synchronize()
if hasattr(qs:=cast(QCOMBuffer, src), 'pitch'): self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
if (qs:=cast(QCOMBuffer, src)).pitch is not None: self._do_copy(qs.va_addr, mv_address(dest), qs.size, qs.real_stride, qs.pitch, qs.real_stride)
else: ctypes.memmove(from_mv(dest), src.va_addr, dest.nbytes)
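For context on the pitch check above: image textures are allocated with padded rows, so copies walk row by row, moving real_stride bytes per row while the destination advances by the padded pitch. A hedged standalone sketch of that pattern (names illustrative, not the _do_copy signature):

```python
# copy 'rows' rows of 'row_bytes' each into a destination padded to 'pitch'
def pitched_copyin(src: bytes, row_bytes: int, pitch: int, rows: int) -> bytearray:
    dest = bytearray(rows * pitch)
    for r in range(rows):
        dest[r*pitch:r*pitch + row_bytes] = src[r*row_bytes:(r+1)*row_bytes]
    return dest
```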

def as_buffer(self, src:HCQBuffer) -> memoryview:

@@ -379,7 +381,7 @@ class QCOMDevice(HCQCompiled):
va_addr = libc.mmap(va_addr, va_len, mmap.PROT_READ|mmap.PROT_WRITE, mmap.MAP_SHARED|MAP_FIXED, self.fd, alloc.id * 0x1000)
if fill_zeroes: ctypes.memset(va_addr, 0, va_len)

return SimpleNamespace(va_addr=va_addr, size=size, mapped=map_to_cpu, info=alloc)
return QCOMBuffer(va_addr=va_addr, size=size, mapped=map_to_cpu, info=alloc)

def _gpu_free(self, mem):
kgsl.IOCTL_KGSL_GPUOBJ_FREE(self.fd, id=mem.info.id)

@@ -3,21 +3,9 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set
from tinygrad.helpers import merge_dicts, getenv
from tinygrad.shape.symbolic import Variable, sint
from tinygrad.shape.view import View, strides_for_shape
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, resolve, _get_chain, symbolic_flat

def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.pyint, x) if isinstance(x, int) else x
def _uop_view(view:View, idxs:List[UOp], vexpr:UOp) -> Tuple[UOp, UOp]:
# TODO: dtypes.realint
iexpr = variable_to_uop(view.offset)
for idx,sh,st,m in zip(idxs, view.shape, view.strides, view.mask if view.mask is not None else [None]*len(view.shape)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*variable_to_uop(st)
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * idx.ge(variable_to_uop(m[0]))
if resolve(m[1] != sh): vexpr = vexpr * idx.lt(variable_to_uop(m[1]))
return iexpr, vexpr
from tinygrad.ops import UOp, UOps, BinaryOps, graph_rewrite, _get_chain, symbolic_flat, Variable, sint

@dataclass(frozen=True)
class ShapeTracker:

@@ -55,17 +43,14 @@ class ShapeTracker:
def to_uop(self) -> UOp: return UOp(UOps.VIEW, dtypes.void, (), self)

def to_indexed_uops(self, _idxs:Optional[List[UOp]]=None) -> Tuple[UOp, UOp]:
idxs = [UOp(UOps.RANGE, dtypes.pyint, (UOp.const(dtypes.pyint, 0), variable_to_uop(s)), i) for i,s in enumerate(self.shape)] \
if _idxs is None else _idxs
idx, valid = _uop_view(self.views[-1], idxs, UOp.const(dtypes.bool, True))
idx, valid = self.views[-1].to_indexed_uops(_idxs)
for view in reversed(self.views[0:-1]):
view = view.minify()
acc, idxs = 1, []
for _d in reversed(view.shape):
d = variable_to_uop(_d)
for d in reversed(view.shape):
idxs.append((idx//acc)%d)
acc *= d
idx, valid = _uop_view(view, idxs[::-1], valid)
idx, valid = view.to_indexed_uops(idxs[::-1], valid)
return idx, valid

def real_size(self) -> int:

@@ -1,30 +0,0 @@
from __future__ import annotations
from typing import Union, Optional, Dict, cast
from tinygrad.dtype import dtypes
from tinygrad.ops import UOp, UOps, exec_alu, ConstType

sint = Union[int, UOp]

def NumNode(val:int): return UOp.const(dtypes.int, val)
class Variable(UOp):
def __reduce__(self): return Variable, self.arg
def __new__(cls, expr:str, nmin:ConstType, nmax:ConstType): # pylint: disable=signature-differs
return super().__new__(cls, UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
def __init__(self, expr:str, nmin:ConstType, nmax:ConstType):
super().__init__(UOps.DEFINE_VAR, dtypes.int, arg=(expr, nmin, nmax))
def bind(self, val:int):
assert self.op is UOps.DEFINE_VAR, f"op is {self.op}"
assert self.arg[1] <= val and val <= self.arg[2], f"bind {val} not in range {self.arg[1]}-{self.arg[2]}"
return UOp(UOps.BIND, self.dtype, (self, self.const_like(val)))
@property
def expr(self): return self.arg[0]

def sym_infer(uop: Union[UOp, int], var_vals: Optional[Dict[Variable, int]]) -> int:
if isinstance(uop, (int, float)): return uop # TODO: ugh, the float is a hack for qcom
if uop.op == UOps.CONST: return uop.arg
if uop.op == UOps.DEFINE_VAR and var_vals is not None: return var_vals[cast(Variable, uop)]
if uop.op == UOps.BIND: return uop.src[1].arg # bound variable returns bound value
if uop.op == UOps.ALU:
src_values = [sym_infer(src, var_vals) for src in uop.src]
return exec_alu(uop.arg, uop.dtype, src_values)
raise NotImplementedError(f"Unsupported UOp {uop.op}")

@@ -2,9 +2,9 @@ from __future__ import annotations
import functools, operator, itertools, math
from dataclasses import dataclass
from typing import Tuple, List, Optional, Dict, Set, cast, Union
from tinygrad.ops import resolve, UOp
from tinygrad.dtype import dtypes
from tinygrad.ops import resolve, UOp, NumNode, Variable, sint, sym_infer
from tinygrad.helpers import prod, all_int, argsort
from tinygrad.shape.symbolic import NumNode, Variable, sint, sym_infer

@functools.lru_cache(maxsize=None)
def canonicalize_strides(shape:Tuple[sint, ...], strides:Tuple[sint, ...]) -> Tuple[sint, ...]:

@@ -82,6 +82,8 @@ def un1d(shape:Tuple[sint, ...], offs:sint) -> List[sint]:
offs -= here * stride
return result

def variable_to_uop(x, ctx=None) -> UOp: return UOp.const(dtypes.int, x) if isinstance(x, int) else x

@dataclass(frozen=True)
class View:
shape:Tuple[sint, ...]

@@ -90,6 +92,16 @@ class View:
mask:Optional[Tuple[Tuple[sint, sint], ...]]
contiguous:bool

def to_indexed_uops(self:View, _idxs:Optional[List[UOp]]=None, vexpr:UOp=UOp.const(dtypes.bool, True)) -> Tuple[UOp, UOp]:
idxs = [UOp.range(dtypes.int, 0, s, i) for i,s in enumerate(self.shape)] if _idxs is None else _idxs
iexpr = variable_to_uop(self.offset)
for idx,sh,st,m in zip(idxs, self.shape, self.strides, self.mask if self.mask is not None else [None]*len(self.shape)):
if resolve(sh != 1) and resolve(st != 0): iexpr = iexpr + idx*st
if m is not None:
if resolve(m[0] != 0): vexpr = vexpr * idx.ge(m[0])
if resolve(m[1] != sh): vexpr = vexpr * idx.lt(m[1])
return iexpr, vexpr
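A hedged sketch of what the relocated View.to_indexed_uops computes: for a contiguous view it is just the strided flat index, and the valid expression stays constant-true (assuming View.create from tinygrad.shape.view):

```python
from tinygrad.shape.view import View

v = View.create(shape=(3, 4))     # strides (4, 1), offset 0, no mask
idx, valid = v.to_indexed_uops()
print(idx)    # an index tree equivalent to idx0*4 + idx1
print(valid)  # CONST True, since there is no mask to guard
```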

@functools.lru_cache(maxsize=None) # pylint: disable=method-cache-max-size-none
def size(self) -> int:
# NOTE: Variable and the Node derived from it in symbolic shapes can only have int as max.

@@ -164,7 +176,7 @@ class View:

# Merge dimensions in vm2 if required.
# NB: Merging too many dimensions can make it difficult to project vm2's mask, hence only combining when required.
idxs: List[UOp] = [Variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
idxs: List[UOp] = [UOp.variable(f"idx{i}", 0, s-1) for i,s in enumerate(vm1.shape)]
merged_size, merged_term = 1, NumNode(0)
extents: List[Tuple[sint, UOp]] = []
for term, s, o in zip(reversed(terms), reversed(vm2.shape), reversed(origin)):

@@ -9,9 +9,8 @@ from tinygrad.dtype import DType, DTypeLike, dtypes, ImageDType, ConstType, leas
from tinygrad.helpers import argfix, make_pair, flatten, prod, all_int, round_up, merge_dicts, argsort, getenv, all_same, fully_flatten, dedup
from tinygrad.helpers import IMAGE, DEBUG, WINO, _METADATA, Metadata, TRACEMETA, ceildiv
from tinygrad.multi import MultiLazyBuffer
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps
from tinygrad.ops import MetaOps, truncate, smax, resolve, UOp, UOps, BinaryOps, sint, Variable
from tinygrad.device import Device, Buffer, BufferOptions
from tinygrad.shape.symbolic import sint, Variable
from tinygrad.engine.lazy import LazyBuffer
from tinygrad.engine.realize import run_schedule, memory_planner
from tinygrad.engine.schedule import ScheduleItem, create_schedule_with_vars

@@ -1682,7 +1681,7 @@ class Tensor:

def _softmax(self, axis, dtype:Optional[DTypeLike]=None):
x = self.cast(dtype) if dtype is not None else self
m = x - x.max(axis=axis, keepdim=True)
m = x - x.max(axis=axis, keepdim=True).detach()
e = m.exp()
return m, e, e.sum(axis=axis, keepdim=True)
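The .detach() added above stops gradients flowing through the max-subtraction stabilizer; since softmax(x) = softmax(x - c) for any constant c, the derivative is unchanged while the backward graph gets smaller. A hedged sanity check:

```python
from tinygrad import Tensor

x = Tensor([[1.0, 2.0, 3.0]], requires_grad=True)
x.softmax(axis=-1).sum().backward()
print(x.grad.numpy())  # ~zeros: each softmax row sums to 1, a constant
```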

@@ -1,25 +0,0 @@
<svg xmlns="http://www.w3.org/2000/svg" viewBox="-10 -10 150 70" shape-rendering="crispEdges">
<g id="logo">
<!-- t -->
<polygon points="10,40 10,20 0,20 0,10 10,10 10,0 20,0 20,10 30,10 30,20 20,20 20,30 30,30 30,40" />
<!-- i -->
<polygon points="40,40 40,20 50,20 50,40" />
<polygon points="40,10 40,0 50,0 50,10" />
<!-- n -->
<polygon points="60,40 60,10 80,10 80,40 90,40 90,20 70,20 70,40" />
<!-- y -->
<polygon points="100,50 100,40 130,40 130,10 120,10 120,20 110,20 110,10 100,10 100,30 120,30 120,50" />
</g>
<style>
@media (prefers-color-scheme: dark) {
#logo {
fill: #fff;
}
}
@media (prefers-color-scheme: light) {
#logo {
fill: #000;
}
}
</style>
</svg>

@@ -3,7 +3,7 @@
<head>
<title>tinygrad viz</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link rel="icon" href="favicon.svg" type="image/svg+xml">
<link rel="icon" href="data:;base64,iVBORw0KGgo=">
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Noto+Sans+HK:wght@100..900&display=swap" rel="stylesheet">

@@ -89,13 +89,7 @@ def get_details(k:Any, ctx:TrackedRewriteContext, metadata:GraphRewriteMetadata)

class Handler(BaseHTTPRequestHandler):
def do_GET(self):
if (url:=urlparse(self.path)).path == "/favicon.svg":
self.send_response(200)
self.send_header("Content-type", "image/svg+xml")
self.end_headers()
with open(os.path.join(os.path.dirname(__file__), "favicon.svg"), "rb") as f:
ret = f.read()
if url.path == "/":
if (url:=urlparse(self.path)).path == "/":
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()