minor == to is touchups

This commit is contained in:
George Hotz
2023-06-15 17:11:12 -07:00
parent 041d96083c
commit dca084f227

View File

@@ -103,11 +103,11 @@ def merge_views(vm2:View, vm1:View) -> Optional[View]:
this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st) this_dim = View(vm2.shape, vm2.strides).expr_node(Variable('idx', 0, s-1)*st)
if s == 1: if s == 1:
new_strides.append(0) # all shape 1 can have stride 0 new_strides.append(0) # all shape 1 can have stride 0
elif this_dim.__class__ == NumNode and this_dim.b == 0: elif this_dim.__class__ is NumNode and this_dim.b == 0:
new_strides.append(0) new_strides.append(0)
elif this_dim.__class__ == Variable: elif this_dim.__class__ is Variable:
new_strides.append(1) new_strides.append(1)
elif this_dim.__class__ == MulNode and cast(MulNode, this_dim).a.__class__ == Variable: elif this_dim.__class__ is MulNode and cast(MulNode, this_dim).a.__class__ is Variable:
new_strides.append(this_dim.b) new_strides.append(this_dim.b)
else: else:
if DEBUG >= 4: print("can't simplify", s, this_dim.render()) if DEBUG >= 4: print("can't simplify", s, this_dim.render())
@@ -152,7 +152,7 @@ def get_unsafe_resize_offset(strides, arg):
class ShapeTracker: class ShapeTracker:
__slots__ = "views" __slots__ = "views"
def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None): def __init__(self, shape:Union[ShapeTracker, Tuple[int, ...]], views:Optional[List[View]]=None):
self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ == ShapeTracker else [view_from_shape(shape)]) self.views: List[View] = views if views is not None else ([*cast(ShapeTracker, shape).views] if shape.__class__ is ShapeTracker else [view_from_shape(shape)])
def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})" def __repr__(self): return f"ShapeTracker(shape={self.views[-1].shape}, views={self.views})"
def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views]) def copy(self) -> ShapeTracker: return ShapeTracker(self.views[-1].shape, [*self.views])
@@ -199,7 +199,7 @@ class ShapeTracker:
return self._expr_idx(idx, valid) return self._expr_idx(idx, valid)
def expr_node(self, idx='idx'): def expr_node(self, idx='idx'):
if idx.__class__ == str: idx = Variable(idx, 0, prod(self.shape)-1) if idx.__class__ is str: idx = Variable(idx, 0, prod(self.shape)-1)
return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx)) return self._expr_idx(self.views[-1].expr_node(idx), self.views[-1].expr_node_mask(idx))
def needs_valid(self) -> bool: def needs_valid(self) -> bool:
@@ -221,7 +221,6 @@ class ShapeTracker:
if any([b or e for b, e in arg]): if any([b or e for b, e in arg]):
zvarg, mask = get_pad_args(self.shape, arg) zvarg, mask = get_pad_args(self.shape, arg)
self.__unsafe_resize(zvarg, mask=mask) self.__unsafe_resize(zvarg, mask=mask)
return self
return self return self
def shrink(self, arg: Tuple[Tuple[int, int], ...]): def shrink(self, arg: Tuple[Tuple[int, int], ...]):