mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-02-18 18:35:12 -05:00
minor == to is touchups
This commit is contained in:
@@ -103,11 +103,11 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
|
||||
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
|
||||
if s == 1:
|
||||
new_strides.append(0) # all shape 1 can have stride 0
|
||||
elif this_dim.__class__ == NumNode and this_dim.b == 0:
|
||||
elif this_dim.__class__ is NumNode and this_dim.b == 0:
|
||||
new_strides.append(0)
|
||||
elif this_dim.__class__ == Variable:
|
||||
elif this_dim.__class__ is Variable:
|
||||
new_strides.append(1)
|
||||
elif this_dim.__class__ == MulNode and cast(MulNode, this_dim).a.__class__ == Variable:
|
||||
elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable:
|
||||
new_strides.append(this_dim.b)
|
||||
else:
|
||||
if DEBUG >= 4: print("can't simplify", s, this_dim.render())
|
||||
@@ -152,7 +152,7 @@ def get_unsafe_resize_offset(strides, arg):
|
||||
class ShapeTracker:
|
||||
__slots__ = "views"
|
||||
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
|
||||
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ == ShapeTracker else [view_from_shape(shape)])
|
||||
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)])
|
||||
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
|
||||
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
|
||||
|
||||
@@ -199,7 +199,7 @@ class ShapeTracker:
|
||||
return self._expr_idx(idx, valid)
|
||||
|
||||
def expr_node(self, idx='idx'):
|
||||
if idx.__class__ == str: idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
|
||||
return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
|
||||
|
||||
def needs_valid(self) -> bool:
|
||||
@@ -221,7 +221,6 @@ class ShapeTracker:
|
||||
if any([b or e for b, e in arg]):
|
||||
zvarg, mask = get_pad_args(self.shape, arg)
|
||||
self.__unsafe_resize(zvarg, mask=mask)
|
||||
return self
|
||||
return self
|
||||
|
||||
def shrink(self, arg: Tuple[Tuple[int, int], ...]):
|
||||
|
||||
Reference in New Issue
Block a user