mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
const cleanup with dtype.Scalar (#3257)
moved Scalar to dtype.py. assert in _broadcasted when y is a Scalar and fix some tests
This commit is contained in:
@@ -250,8 +250,8 @@ class TestJit(unittest.TestCase):
|
||||
# save [2] in the caches
|
||||
cache.good(zero, two)
|
||||
cache.bad(zero, two)
|
||||
np.testing.assert_equal([2], cache.good_cache)
|
||||
np.testing.assert_equal([2], cache.bad_cache)
|
||||
np.testing.assert_equal([2], cache.good_cache.numpy())
|
||||
np.testing.assert_equal([2], cache.bad_cache.numpy())
|
||||
|
||||
# verify the jitted calls read 2 from the cache
|
||||
np.testing.assert_equal([2], cache.good_jitted(zero).numpy())
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict
|
||||
from typing import Final, Optional, ClassVar, Set, Tuple, Dict, Union
|
||||
from dataclasses import dataclass
|
||||
import numpy as np # TODO: remove numpy
|
||||
import functools
|
||||
|
||||
Scalar = Union[float, int, bool]
|
||||
|
||||
@dataclass(frozen=True, order=True)
|
||||
class DType:
|
||||
priority: int # this determines when things get upcasted
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
from typing import Optional, Union, Any, Tuple, List
|
||||
import functools
|
||||
from tinygrad.helpers import all_same, dedup, round_up, DEBUG
|
||||
from tinygrad.dtype import DType
|
||||
from tinygrad.dtype import DType, Scalar
|
||||
from tinygrad.ops import BinaryOps, LoadOps, UnaryOps, TernaryOps, ReduceOps
|
||||
from tinygrad.lazy import LazyBuffer, create_schedule
|
||||
from tinygrad.shape.shapetracker import ShapeTracker, sint
|
||||
@@ -50,7 +50,7 @@ class MultiLazyBuffer:
|
||||
# passthroughs
|
||||
def schedule(self, seen=None): return create_schedule(self.lbs, seen)
|
||||
def cast(self, dtype:DType, bitcast:bool=False): return MultiLazyBuffer([x.cast(dtype, bitcast) for x in self.lbs], self.axis)
|
||||
def const(self, val:Union[float, int]) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
|
||||
def const(self, val:Scalar) -> MultiLazyBuffer: return MultiLazyBuffer([x.const(val) for x in self.lbs], self.axis)
|
||||
def contiguous(self): return MultiLazyBuffer([x.contiguous() for x in self.lbs], self.axis)
|
||||
|
||||
# elementwise is simple
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
import sys, math
|
||||
from collections import defaultdict
|
||||
from typing import Union, Optional, Any, Tuple, List, Set, Dict, DefaultDict, cast
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType
|
||||
from tinygrad.dtype import dtypes, DType, ImageDType, Scalar
|
||||
from tinygrad.helpers import prod, flatten, getenv, dedup, DEBUG, all_int, all_same, GRAPH
|
||||
from tinygrad.ops import LoadOps, UnaryOps, BinaryOps, TernaryOps, ReduceOps, BufferOps, Op, LazyOp, ConstBuffer, MemBuffer, ScheduleItem
|
||||
from tinygrad.shape.symbolic import sint, Variable
|
||||
@@ -56,7 +56,7 @@ class LazyBuffer:
|
||||
def loadop(op, shape:Tuple[sint,...], dtype:DType, device:str, arg=None, src:Optional[LazyBuffer]=None, enable_cache=False) -> LazyBuffer:
|
||||
return create_lazybuffer(device, ShapeTracker.from_shape(shape), dtype, op, arg, (src,) if src is not None else (), enable_cache=enable_cache)
|
||||
|
||||
def const(self, val:Union[float, int], shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
|
||||
def const(self, val:Scalar, shape:Optional[Tuple[sint,...]]=None) -> LazyBuffer:
|
||||
shape = self.shape if shape is None else shape
|
||||
return LazyBuffer.loadop(LoadOps.CONST, tuple(), self.dtype, self.device, arg=val).reshape((1,)*len(shape)).expand(shape)
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections import defaultdict
|
||||
from functools import partialmethod, reduce
|
||||
import numpy as np
|
||||
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, least_upper_float, least_upper_dtype
|
||||
from tinygrad.dtype import DType, dtypes, ImageDType, Scalar, least_upper_float, least_upper_dtype
|
||||
from tinygrad.helpers import argfix, make_pair, getenv, IMAGE, DEBUG, WINO, flatten, prod, all_int, round_up, merge_dicts, fully_flatten
|
||||
from tinygrad.lazy import LazyBuffer, create_schedule
|
||||
from tinygrad.features.multi import MultiLazyBuffer
|
||||
@@ -46,8 +46,6 @@ def _fromcpu(x: np.ndarray) -> LazyBuffer:
|
||||
ret.realized = Buffer("CPU", prod(x.shape), dtypes.from_np(x.dtype), x.flatten())
|
||||
return ret
|
||||
|
||||
Scalar = Union[float, int, bool]
|
||||
|
||||
class Tensor:
|
||||
__slots__ = "lazydata", "requires_grad", "grad", "_ctx"
|
||||
__deletable__ = ('_ctx',)
|
||||
@@ -779,7 +777,7 @@ class Tensor:
|
||||
x: Tensor = self
|
||||
if not isinstance(y, Tensor):
|
||||
# make y a Tensor
|
||||
if 0 in self.shape: return self, self.full_like(y)
|
||||
assert isinstance(y, (float, int, bool)), f"{type(y)=}, {y=}"
|
||||
if isinstance(self.dtype, ImageDType) or dtypes.is_float(x.dtype) or (dtypes.is_int(x.dtype) and isinstance(y, int)): y_dtype = x.dtype
|
||||
else: y_dtype = dtypes.from_py(y)
|
||||
y = Tensor(y, self.device, y_dtype, requires_grad=False)
|
||||
|
||||
Reference in New Issue
Block a user