mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
* openpilot image warp test * 0.4 ms on metal, 1 ms on CPU * new inputs each time * reshape
56 lines
1.8 KiB
Python
56 lines
1.8 KiB
Python
import time
|
|
from tinygrad.tensor import Tensor, Device
|
|
|
|
# Width and height of the warped output frame, in pixels.
MODEL_WIDTH, MODEL_HEIGHT = 512, 256
# 3/2 values per pixel -- presumably the size of a YUV 4:2:0 frame (TODO confirm).
MODEL_FRAME_SIZE = (MODEL_WIDTH * MODEL_HEIGHT * 3) // 2
# The trailing (128, 256) equals (MODEL_HEIGHT // 2, MODEL_WIDTH // 2);
# presumably the model's image input layout (TODO confirm).
IMG_INPUT_SHAPE = (1, 12, 128, 256)
|
|
|
|
def tensor_arange(end):
  """Return a 1-D Tensor holding 0.0, 1.0, ..., end-1 as Python floats."""
  values = list(map(float, range(end)))
  return Tensor(values)
|
|
def tensor_round(tensor:Tensor):
  """Round elementwise to the nearest integer value, computed as floor(x + 0.5)."""
  shifted = tensor + 0.5
  return shifted.floor()
|
|
|
|
# Source image size used for clipping sampled coordinates (height, width).
h_src = 1208
w_src = 1928
# Destination image size (height, width) comes from the model constants.
h_dst = MODEL_HEIGHT
w_dst = MODEL_WIDTH

# Per-pixel coordinate grids of shape (h_dst, w_dst):
# x holds the column index of every pixel, y holds the row index.
x = tensor_arange(w_dst).reshape(1, w_dst).expand(h_dst, w_dst)
y = tensor_arange(h_dst).reshape(h_dst, 1).expand(h_dst, w_dst)
ones = Tensor.ones_like(x)

# Homogeneous destination coordinates: rows are x, y, 1 and each column is one pixel,
# giving a (3, h_dst*w_dst) matrix.
dst_coords = (
  x.reshape(1, -1)
  .cat(y.reshape(1, -1))
  .cat(ones.reshape(1, -1))
)
|
|
|
|
def warp_perspective_tinygrad(src:Tensor, M_inv:Tensor) -> Tensor:
  """Warp `src` into an (h_dst, w_dst) image using nearest-neighbour sampling.

  Every destination pixel in the module-level `dst_coords` is multiplied by
  `M_inv`, normalised by its third (homogeneous) component, rounded to the
  nearest integer and clipped to [0, w_src-1] x [0, h_src-1]. The value at
  that location of `src` becomes the destination pixel.

  Args:
    src: image to sample from. It is indexed as a flattened buffer whose row
      stride is `src.shape[1]` -- presumably a 2-D (h_src, w_src) tensor.
      NOTE(review): clipping uses the module-level h_src/w_src, not
      `src.shape`; confirm callers pass a matching shape.
    M_inv: matrix applied to the homogeneous destination coordinates
      (presumably the 3x3 destination-to-source transform).

  Returns:
    Tensor of shape (h_dst, w_dst) with the sampled values.
  """
  # Project destination pixels into source space and de-homogenise.
  projected = M_inv @ dst_coords
  normalized = projected / projected[2:3, :]

  col_f = normalized[0].reshape(h_dst, w_dst)
  row_f = normalized[1].reshape(h_dst, w_dst)

  # Nearest source pixel, kept inside the source bounds.
  col_i = tensor_round(col_f).clip(0, w_src - 1).cast('int')
  row_i = tensor_round(row_f).clip(0, h_src - 1).cast('int')

  # TODO: make 2d indexing fast
  # Until then, gather through a flat index into the flattened source.
  flat_index = row_i*src.shape[1] + col_i
  gathered = src.flatten()[flat_index]
  return gathered.reshape(h_dst, w_dst)
|
|
|
|
if __name__ == "__main__":
  # Benchmark: run the JIT-wrapped warp 10 times on fresh random inputs and
  # print the enqueue time and the total (synchronized) time of every step.
  from tinygrad.engine.jit import TinyJit
  update_img_jit = TinyJit(warp_perspective_tinygrad, prune=True)

  step_times = []
  for _ in range(10):
    # New inputs every iteration, realized and synchronized up front so that
    # their generation is not part of the timed region.
    # NOTE(review): the source is created as (1928, 1208) while the module
    # sets h_src, w_src = 1208, 1928 -- confirm the intended orientation.
    src, M_inv = Tensor.randn(1928,1208), Tensor.randn(3,3)
    Tensor.realize(src, M_inv)
    Device.default.synchronize()

    # Timed region: st -> mt is the JIT call returning, st -> et additionally
    # covers realizing the output and waiting for the device.
    st = time.perf_counter()
    out = update_img_jit(src, M_inv)
    mt = time.perf_counter()
    val = out.contiguous().realize()
    Device.default.synchronize()
    et = time.perf_counter()

    # Times are reported in milliseconds.
    step_times.append((et-st)*1e3)
    print(f"enqueue {(mt-st)*1e3:6.2f} ms -- total run {step_times[-1]:6.2f} ms")
|