Mirror of https://github.com/tinygrad/tinygrad.git (synced 2026-01-10 07:28:15 -05:00)
unrealized consts everywhere (#1963)
* unrealized consts everywhere
* don't import device from lazy
* Device isn't in Lazy
* same issue
* disable jit random
.github/workflows/test.yml (vendored, 2 changed lines)
@@ -268,7 +268,7 @@ jobs:
 - name: Install dependencies
   run: pip install -e '.[testing${{matrix.backend=='llvm'&&',llvm'||matrix.backend=='cuda'&&',cuda'||matrix.backend=='ptx'&&',cuda'||matrix.backend=='triton'&&',triton'||''}}]' --extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/
 - name: Check Device.DEFAULT
-  run: python -c "from tinygrad.lazy import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
+  run: python -c "from tinygrad.ops import Device; assert Device.DEFAULT in ['LLVM','CLANG','CUDA','GPU'], Device.DEFAULT"
 - name: Run pytest (not cuda)
   if: matrix.backend!='cuda' && matrix.backend!='ptx' && matrix.backend!='triton'
   run: python -m pytest -n=auto test/ -k '${{matrix.backend=='llvm'&&'not (test_nn.py and test_conv_transpose2d)'||'test'}}' -m 'not exclude_${{matrix.backend}}'
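
Only the import location changes in this step: Device now comes from tinygrad.ops instead of tinygrad.lazy. A quick way to reproduce the CI check locally (a sketch, assuming a working tinygrad install):

    # reproduce the "Check Device.DEFAULT" step from the workflow above
    from tinygrad.ops import Device

    # Device.DEFAULT is the backend string tinygrad picked for this machine
    assert Device.DEFAULT in ['LLVM', 'CLANG', 'CUDA', 'GPU'], Device.DEFAULT
    print(f"default backend: {Device.DEFAULT}")
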
@@ -8,8 +8,8 @@ from tinygrad.helpers import prod, dtypes

 # *** first, we implement the atan2 op at the lowest level ***
 # `atan2_gpu` for GPUBuffers and `atan2_cpu` for CPUBuffers
-from tinygrad.lazy import LazyBuffer, create_lazybuffer, Device
-from tinygrad.ops import ASTRunner
+from tinygrad.lazy import LazyBuffer, create_lazybuffer
+from tinygrad.ops import ASTRunner, Device
 from tinygrad.shape.shapetracker import ShapeTracker
 import pytest

@@ -134,6 +134,7 @@ class TestJit(unittest.TestCase):
     assert output2 != expect2
     assert len(f.jit_cache) == 1

+  @unittest.skip("random isn't working in JIT")
   def test_jit_random_regen(self):
     def f(a, b):
       rn = Tensor.randn(*a.shape)
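
The new @unittest.skip line implements the "disable jit random" bullet from the commit message. Below is a rough sketch of what the skipped test exercises; only the first lines of the real test are visible in this hunk, so the body and the TinyJit wrapper are my reconstruction, not the actual test code:

    import numpy as np
    from tinygrad.tensor import Tensor
    from tinygrad.jit import TinyJit

    @TinyJit
    def f(a, b):
      rn = Tensor.randn(*a.shape)      # a LoadOps.RAND buffer created inside the jitted function
      return (a * b + rn).realize()    # the JIT wants a realized output

    a, b = Tensor.randn(10, 10), Tensor.randn(10, 10)
    outs = [f(a, b).numpy() for _ in range(5)]
    # once the JIT starts replaying its cache, the RAND load is not re-run, so later
    # outputs can repeat instead of regenerating, which is what the skip marker flags
    print([np.allclose(outs[0], o) for o in outs[1:]])
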
@@ -1,7 +1,8 @@
 #!/usr/bin/env python
 import numpy as np
 import unittest
-from tinygrad.lazy import LazyBuffer, Device
+from tinygrad.lazy import LazyBuffer
+from tinygrad.ops import Device
 from tinygrad.tensor import Tensor
 from tinygrad.shape.symbolic import Variable
 from tinygrad.jit import CacheCollector
@@ -5,7 +5,7 @@ from weakref import ref, WeakSet, WeakValueDictionary

 import numpy as np
 from tinygrad.helpers import prod, getenv, DType, dtypes, flatten, ImageDType, partition, all_int, dedup, merge_dicts
-from tinygrad.ops import Device, Compiled, UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
+from tinygrad.ops import UnaryOps, BinaryOps, TernaryOps, ReduceOps, MovementOps, LoadOps, OpType, LazyOp, MemBuffer, ConstBuffer, BufferOps
 from tinygrad.shape.shapetracker import ShapeTracker, get_contraction
 from tinygrad.shape.symbolic import Variable, sint
@@ -123,7 +123,7 @@ class LazyBuffer:
   @property
   def base(self): return self._base if self._base is not None else self

-  def is_unrealized_const(self): return not self.realized and (self.base.op.op == LoadOps.CONST and isinstance(Device[self.device], Compiled))
+  def is_unrealized_const(self): return not self.realized and self.base.op.op == LoadOps.CONST

   @property
   def realized(self): return self.base._realized
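
This is the heart of "unrealized consts everywhere": with the isinstance(Device[self.device], Compiled) guard dropped, a const is reported as unrealized on any backend, which is what allows it to be folded into later kernels rather than allocated. A minimal sketch of what the predicate reports, assuming the Tensor front end exposes its LazyBuffer as .lazydata:

    from tinygrad.tensor import Tensor

    print(Tensor(2.0).lazydata.is_unrealized_const())         # True: a scalar const, no device buffer yet
    print(Tensor.zeros(4, 4).lazydata.is_unrealized_const())  # True: zeros is a broadcast const, still unrealized
    print(Tensor.rand(4, 4).lazydata.is_unrealized_const())   # False: LoadOps.RAND is not a const
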
@@ -1,5 +1,6 @@
 from __future__ import annotations
 import time, importlib, inspect, functools, pathlib
+import numpy as np
 from enum import Enum, auto
 from typing import TYPE_CHECKING, Union, Type, Tuple, Any, List, Optional, Dict, Callable, cast, Mapping
 from tinygrad.helpers import ansilen, prod, DEBUG, getenv, GlobalCounters, DType, colored
@@ -113,9 +114,12 @@ class Interpreted:
     self.codegen = None

   def exec_ast(self, ast:LazyOp, output=None, inputs=None, var_vals=None, context=None, **kwargs):
-    if ast.op == BufferOps.MEM and BufferOps.MEM not in self.fxn_for_op:
-      assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
-      buf = self.to_underlying(inputs[ast.arg.idx-1])
+    if ast.op in BufferOps and ast.op not in self.fxn_for_op:
+      if ast.op == BufferOps.MEM:
+        assert inputs[ast.arg.idx-1].dtype == ast.arg.dtype, "dtype mismatch"
+        buf = self.to_underlying(inputs[ast.arg.idx-1])
+      elif ast.op == BufferOps.CONST:
+        buf = self.to_underlying(self.buffer.fromCPU(np.array(ast.arg.val, dtype=ast.arg.dtype.np)))
       for mop,arg in ast.arg.st.to_movement_ops(): buf = self.fxn_for_op[mop](buf, arg)
       return self.from_underlying(buf)
     if TernaryOps.MULACC in self.fxn_for_op and ast.op == ReduceOps.SUM and isinstance(ast.src[0], LazyOp) and ast.src[0].op == BinaryOps.MUL:
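
For interpreted backends, BufferOps.CONST is now handled inline in exec_ast: the scalar is materialized with np.array and then pushed through the ShapeTracker's movement ops like a regular input. A NumPy-only sketch of that flow, with reshape and broadcast standing in for the real movement-op dispatch:

    import numpy as np

    def materialize_const(val, np_dtype, shape):
      buf = np.array(val, dtype=np_dtype)      # ~ self.buffer.fromCPU(np.array(ast.arg.val, dtype=ast.arg.dtype.np))
      buf = buf.reshape((1,) * len(shape))     # ~ a RESHAPE from ast.arg.st.to_movement_ops()
      return np.broadcast_to(buf, shape)       # ~ an EXPAND from ast.arg.st.to_movement_ops()

    print(materialize_const(2.0, np.float32, (2, 3)))
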
@@ -39,7 +39,7 @@ def _realize_custom(buffer: LazyBuffer) -> None:
   buffer.realized = buffer.op.arg(buffer, *[x.realize() for x in buffer.op.src])

 def _realize_from(buffer: LazyBuffer) -> None:
-  rawbuf = buffer.op.src[0].realize()
+  rawbuf = cast(LazyBuffer, buffer.op.src[0]).contiguous().realize()
   assert rawbuf.realized, "realize failed?"
   if DEBUG >= 3: print(f"*** copy {buffer.device} <- {rawbuf.device} size {rawbuf.realized.size} dtype {rawbuf.realized.dtype}")
   # TODO: make this generic
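
The FROM path now realizes a contiguous copy of its source before the cross-device transfer, presumably so the raw bytes being copied match the logical tensor even when the source is a strided view. Roughly the same idea at the Tensor level (a sketch, not the realize.py internals):

    from tinygrad.tensor import Tensor

    t = Tensor.rand(4, 4).permute(1, 0)  # a strided, non-contiguous view
    c = t.contiguous().realize()         # force a packed buffer, as _realize_from now does for its source
    print(c.shape)
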
@@ -56,17 +56,10 @@ def _realize_empty(buffer: LazyBuffer) -> None:
   assert all_int(buffer.shape), "does not support symbolic shape"
   buffer.realized = Device[buffer.device].buffer(prod(buffer.shape), buffer.dtype, **buffer._device_extra_args())

-def _gen_rand(rng, shape, dt): return rng.random(size=shape, dtype=np.float32).astype(dtype=dt, copy=False)
 def _realize_rand(buffer: LazyBuffer) -> None:
   assert all_int(buffer.shape), "does not support symbolic shape"
   rng = np.random.default_rng(buffer.op.arg)
-  buffer.realized = Device[buffer.device].buffer.fromCPU(_gen_rand(rng, buffer.shape, buffer.dtype.np), **buffer._device_extra_args()) # type: ignore
-
-  # Jit support
-  from tinygrad.jit import CacheCollector
-  CacheCollector.add(lambda args, vars, jit: args[0]._copyin(_gen_rand(*args[1:])), [buffer.realized, rng, buffer.shape, buffer.dtype.np], {})
-
-def _realize_const(buffer: LazyBuffer) -> None:
-  buffer.realized = Device[buffer.device].buffer.fromCPU(np.array(buffer.op.arg, dtype=buffer.dtype.np), **buffer._device_extra_args())
+  buffer.realized = Device[buffer.device].buffer.fromCPU(rng.random(size=prod(buffer.shape), dtype=np.float32).astype(dtype=buffer.dtype.np, copy=False), **buffer._device_extra_args())

 LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
   LoadOps.CONTIGUOUS: _realize_contiguous,
@@ -74,5 +67,4 @@ LOAD_OPS_DISPATCHER: Dict[LoadOps, Callable] = {
   LoadOps.FROM: _realize_from,
   LoadOps.EMPTY: _realize_empty,
   LoadOps.RAND: _realize_rand,
-  LoadOps.CONST: _realize_const,
 }
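
With consts left unrealized, _realize_const and its LoadOps.CONST dispatcher entry are gone, and _realize_rand no longer registers a CacheCollector entry (random under the JIT is simply disabled, as the skipped test above shows). What the slimmed-down rand path computes, sketched with plain NumPy rather than tinygrad's buffer API:

    import numpy as np
    from math import prod

    def rand_like(seed, shape, np_dtype):
      rng = np.random.default_rng(seed)  # buffer.op.arg carries the seed in the real code
      # flat float32 sample, cast to the target dtype without an extra copy when possible
      return rng.random(size=prod(shape), dtype=np.float32).astype(np_dtype, copy=False)

    print(rand_like(42, (2, 2), np.float16).reshape(2, 2))
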