mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-23 05:48:08 -05:00
remove duplicated UOp in Tensor init types [pr] (#8177)
and a small comment
This commit is contained in:
@@ -12,8 +12,7 @@ from tinygrad.engine.schedule import ScheduleItem
|
||||
|
||||
logkerns, logkerns_level = open(getenv("LOGKERNS", ""), "a") if getenv("LOGKERNS", "") else None, getenv("LOGKERNS_LEVEL", 1)
|
||||
def get_kernel(renderer:Renderer, ast:UOp) -> Kernel:
|
||||
if DEBUG >= 5:
|
||||
print(ast)
|
||||
if DEBUG >= 5: print(ast)
|
||||
k = Kernel(ast, opts=renderer).required_optimizations()
|
||||
if not NOOPT:
|
||||
if not (used_tensor_cores:=k.apply_tensor_cores(getenv("TC", 1))): k.hand_coded_optimizations()
|
||||
|
||||
@@ -122,7 +122,7 @@ class Tensor(SimpleMathTrait):
|
||||
training: ClassVar[bool] = False
|
||||
no_grad: ClassVar[bool] = False
|
||||
|
||||
def __init__(self, data:Union[None, ConstType, UOp, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
||||
def __init__(self, data:Union[None, ConstType, bytes, List, Tuple, UOp, MultiLazyBuffer, 'np.ndarray', pathlib.Path], # type: ignore [name-defined] # noqa: F821
|
||||
device:Optional[Union[str, tuple, list]]=None, dtype:Optional[DTypeLike]=None, requires_grad:Optional[bool]=None):
|
||||
if dtype is not None: dtype = to_dtype(dtype)
|
||||
assert dtype is None or isinstance(dtype, DType), f"invalid dtype {dtype}"
|
||||
@@ -150,7 +150,7 @@ class Tensor(SimpleMathTrait):
|
||||
elif isinstance(data, (list, tuple)):
|
||||
if dtype is None:
|
||||
if (d := fully_flatten(data)) and all(isinstance(s, bool) for s in d): dtype = dtypes.bool
|
||||
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float
|
||||
else: dtype = dtypes.default_int if d and all_int(d) else dtypes.default_float # NOTE: this works because all_int([True, False]) is True
|
||||
if dtype == dtypes.bfloat16: data = Tensor(_frompy(data, dtypes.float32), device=device).cast(dtypes.bfloat16).lazydata
|
||||
else: data = _frompy(data, dtype)
|
||||
elif str(type(data)) == "<class 'numpy.ndarray'>":
|
||||
|
||||
Reference in New Issue
Block a user