mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
fixup to_movement_ops and add back to CI (#3881)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -99,6 +99,8 @@ jobs:
|
||||
run: |
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker.py
|
||||
PYTHONPATH="." python test/external/fuzz_shapetracker_math.py
|
||||
- name: Test to_movement_ops
|
||||
run: PYTHONPATH="." python extra/to_movement_ops.py
|
||||
- name: Use as an external package
|
||||
run: |
|
||||
mkdir $HOME/test_external_dir
|
||||
|
||||
@@ -135,17 +135,16 @@ def test_rebuild(st: ShapeTracker):
|
||||
last_v2 = rebuilt_st.views[-1]
|
||||
assert last_v1.shape == last_v2.shape, f"{last_v1.shape} != {last_v2.shape}"
|
||||
|
||||
def test_interpret_ast(ast:LazyOp):
|
||||
def test_rebuild_bufferop_st(ast:LazyOp):
|
||||
if ast.op in BufferOps:
|
||||
test_rebuild(ast.arg.st)
|
||||
else:
|
||||
for src in ast.src: test_interpret_ast(src)
|
||||
for src in ast.src: test_rebuild_bufferop_st(src)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ast_strs = load_worlds(False, False, True)[:4000]
|
||||
ast_strs = load_worlds(False, False, True)[:2000]
|
||||
for ast_str in tqdm(ast_strs):
|
||||
for op in ast_str_to_ast(ast_str):
|
||||
test_interpret_ast(op)
|
||||
for ast in ast_str_to_ast(ast_str):
|
||||
test_rebuild_bufferop_st(ast)
|
||||
|
||||
print(f"avg length of mop = {sum(k*v for k,v in c.items()) / sum(c.values()):.2f}")
|
||||
Reference in New Issue
Block a user