more typing fixes

This commit is contained in:
Francis Lata
2025-02-28 15:42:11 +00:00
parent e9d1af26b2
commit 074e9f742b
6 changed files with 36 additions and 33 deletions

View File

@@ -1,5 +1,5 @@
import os, random, pickle, queue
from typing import List, Optional
from typing import List
from pathlib import Path
from multiprocessing import Queue, Process, shared_memory, connection, Lock, cpu_count
@@ -351,8 +351,8 @@ def batch_load_unet3d(preprocessed_dataset_dir:Path, batch_size:int=6, val:bool=
### RetinaNet
def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue,
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Optional[Tensor]=None,
anchors:Optional[Tensor]=None, seed:Optional[int]=None):
imgs:Tensor, boxes:Tensor, labels:Tensor, matches:Tensor|None=None,
anchors:Tensor|None=None, seed:int|None=None):
from extra.datasets.openimages import image_load, random_horizontal_flip, resize
from examples.mlperf.helpers import box_iou, find_matches, generate_anchors
import torch
@@ -386,7 +386,7 @@ def load_retinanet_data(base_dir:Path, val:bool, queue_in:Queue, queue_out:Queue
queue_out.put(idx)
queue_out.put(None)
def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, shuffle:bool=True, seed:Optional[int]=None):
def batch_load_retinanet(dataset, val:bool, base_dir:Path, batch_size:int=32, shuffle:bool=True, seed:int|None=None):
def _enqueue_batch(bc):
from extra.datasets.openimages import prepare_target
for idx in range(bc * batch_size, (bc+1) * batch_size):

View File

@@ -68,13 +68,15 @@ class LayerNormBert:
return (xn * self.weight.cast(dtypes.default_float) + self.bias.cast(dtypes.default_float))
class FrozenBatchNorm2d(nn.BatchNorm2d):
  """`nn.BatchNorm2d` subclass that clears `requires_grad` on its `weight` and `bias`."""
  def __init__(self, num_features:int):
    super().__init__(num_features)
    # both affine tensors get the same treatment: flagged as not requiring gradients
    for affine in (self.weight, self.bias):
      affine.requires_grad = False
class Conv2dNormal(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True, prior_prob=None):
def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
bias:bool=True, prior_prob:float|None=None):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.weight = Tensor.normal(*self.weight.shape, std=0.01)
if bias:
@@ -84,7 +86,9 @@ class Conv2dNormal(nn.Conv2d):
else: self.bias = Tensor.zeros_like(self.bias)
class Conv2dKaimingUniform(nn.Conv2d):
def __init__(self, in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True):
def __init__(self, in_channels:int, out_channels:int, kernel_size:int|tuple[int, ...],
stride:int=1, padding:int|tuple[int, ...]|str=0, dilation:int=1, groups:int=1,
bias:bool=True):
super().__init__(in_channels, out_channels, kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
self.weight = Tensor.kaiming_uniform(*self.weight.shape, a=1)
if bias: self.bias = Tensor.zeros_like(self.bias)

View File

@@ -353,6 +353,8 @@ def train_retinanet():
from pycocotools.coco import COCO
from pycocotools.cocoeval import COCOeval
from tinygrad.helpers import colored
from tinygrad.nn.optim import Optimizer
from typing import Iterator
import numpy as np
@@ -361,18 +363,18 @@ def train_retinanet():
NUM_CLASSES = len(MLPERF_CLASSES)
BASE_DIR = getenv("BASE_DIR", BASEDIR)
BENCHMARK = getenv("BENCHMARK")
config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 1))]
config["gpus"] = GPUS = [f"{Device.DEFAULT}:{i}" for i in range(getenv("GPUS", 6))]
for x in GPUS: Device[x]
print(f"training on {GPUS}")
def _freeze_backbone_layers(backbone:resnet.ResNet, trainable_layers:int):
  """Clear `requires_grad` on every backbone tensor that is not selected for training.

  Stage names are ranked "layer4", "layer3", "layer2", "layer1", "conv1"; the first
  `trainable_layers` names of that ranking are kept trainable. Every entry of
  `get_state_dict(backbone)` whose key does not start with one of the kept names gets
  `requires_grad = False`.
  """
  layers_to_train = ("layer4", "layer3", "layer2", "layer1", "conv1")[:trainable_layers]
  for k, v in get_state_dict(backbone).items():
    # str.startswith takes a tuple of prefixes; an empty tuple never matches, so
    # trainable_layers=0 still freezes every tensor
    if not k.startswith(layers_to_train):
      v.requires_grad = False
def _data_get(it, val=False):
def _data_get(it:Iterator[tuple[Tensor, ...]], val:bool=False):
if val:
x, img_ids, img_sizes, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), img_ids, img_sizes, cookie
@@ -380,7 +382,7 @@ def train_retinanet():
x, y_boxes, y_labels, matches, anchors, cookie = next(it)
return x.shard(GPUS, axis=0).realize(), y_boxes.shard(GPUS, axis=0), y_labels.shard(GPUS, axis=0), matches.shard(GPUS, axis=0), anchors.shard(GPUS, axis=0), cookie
def _create_lr_scheduler(optim, start_iter, warmup_iters, warmup_factor):
def _create_lr_scheduler(optim:Optimizer, start_iter:int, warmup_iters:int, warmup_factor:float):
def _lr_lambda(e):
e = e + start_iter
if e >= warmup_iters: return 1.0
@@ -446,7 +448,7 @@ def train_retinanet():
lr_scheduler = _create_lr_scheduler(optim, start_iter, warmup_iters, lr_warmup_factor)
# ** resume from checkpointing **
if ckpt:=getenv("RESUME", ""):
if (ckpt:=getenv("RESUME", "")):
load_training_state(model, optim, lr_scheduler, safe_load(ckpt))
start_epoch = int(lr_scheduler.epoch_counter.item() / steps_in_train_epoch)
print(f"resuming from {ckpt} at epoch {start_epoch}")
@@ -456,7 +458,7 @@ def train_retinanet():
import wandb
wandb_args = {"project": "MLPerf-RetinaNet"}
if (wandb_id := getenv("WANDB_RESUME", "")):
if (wandb_id:=getenv("WANDB_RESUME", "")):
wandb_args["id"] = wandb_id
wandb_args["resume"] = "must"

View File

@@ -6,7 +6,6 @@ from pathlib import Path
import boto3, botocore
from tinygrad import Tensor, dtypes
from tinygrad.helpers import fetch, tqdm, getenv
from typing import Optional, Dict, Tuple, Union, List
import pandas as pd
import concurrent.futures
@@ -186,7 +185,7 @@ def random_horizontal_flip(img, tgt, prob=0.5):
tgt["boxes"][:, [0, 2]] = w - tgt["boxes"][:, [2, 0]]
return img, tgt
def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, size:Tuple[int, int]=(800, 800)) -> Union[Tuple[np.ndarray, np.ndarray, Tuple], Tuple[np.ndarray, Tuple]]:
def resize(img:Image, tgt:dict[str, np.ndarray|tuple]|None=None, size:tuple[int, int]=(800, 800)) -> tuple[np.ndarray, np.ndarray, tuple]|tuple[np.ndarray, tuple]:
import torchvision.transforms.functional as F
img_size = img.size[::-1]
img = F.resize(img, size=size)
@@ -206,7 +205,7 @@ def resize(img:Image, tgt:Optional[Dict[str, Union[np.ndarray, Tuple]]]=None, si
return img, img_size
def normalize(img:Tensor, device:Optional[List[str]] = None):
def normalize(img:Tensor, device:list[str]|None = None):
mean = Tensor([0.485, 0.456, 0.406], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
std = Tensor([0.229, 0.224, 0.225], device=device, dtype=dtypes.float32).reshape(1, -1, 1, 1)
img = ((img.permute([0, 3, 1, 2]) / 255.0) - mean) / std

View File

@@ -92,7 +92,7 @@ class OneCycleLR(LR_Scheduler):
).cast(self.optimizer.lr.dtype)
class LambdaLR(LR_Scheduler):
def __init__(self, optimizer: Optimizer, lr_lambda: Callable[[int],float]):
def __init__(self, optimizer:Optimizer, lr_lambda:Callable[[int], float]):
super().__init__(optimizer)
self.lr_lambda = lr_lambda
self.initial_lr = self.optimizer.lr.numpy()[0]

View File

@@ -1,7 +1,5 @@
from typing import Optional, Union, Dict
import math
from tinygrad import Tensor, dtypes
from tinygrad import Tensor
from tinygrad.helpers import flatten, get_child
import tinygrad.nn as nn
from examples.mlperf.helpers import generate_anchors, BoxCoder
@@ -37,7 +35,7 @@ def decode_bbox(offsets, anchors):
return np.stack([pred_x1, pred_y1, pred_x2, pred_y2], axis=1, dtype=np.float32)
class RetinaNet:
def __init__(self, backbone: ResNet, num_classes=264, num_anchors=9, scales=None, aspect_ratios=None):
def __init__(self, backbone:ResNet, num_classes:int=264, num_anchors:int=9, scales:list[int]|None=None, aspect_ratios:list[float]|None=None):
assert isinstance(backbone, ResNet)
scales = tuple((i, int(i*2**(1/3)), int(i*2**(2/3))) for i in 2**np.arange(5, 10)) if scales is None else scales
aspect_ratios = ((0.5, 1.0, 2.0),) * len(scales) if aspect_ratios is None else aspect_ratios
@@ -133,12 +131,12 @@ class RetinaNet:
return detections
class ClassificationHead:
def __init__(self, in_channels:int, num_anchors:int, num_classes:int, prior_prob:float=0.01):
  """Build the convolution tower and the final logits convolution.

  The tower is four 3x3 `Conv2dNormal` layers (padding 1, `in_channels` in and out), each
  followed by a ReLU callable. `cls_logits` outputs `num_anchors * num_classes` channels
  and is the only layer that receives `prior_prob`.
  """
  self.num_classes = num_classes
  def _conv_relu():
    # one tower stage: channel-preserving 3x3 conv plus its activation
    return Conv2dNormal(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()
  self.conv = flatten([_conv_relu() for _ in range(4)])
  logit_channels = num_anchors * num_classes
  self.cls_logits = Conv2dNormal(in_channels, logit_channels, kernel_size=3, padding=1, prior_prob=prior_prob)
def __call__(self, x:Tensor, labels:Optional[Tensor] = None, matches:Optional[Tensor] = None):
def __call__(self, x:Tensor, labels:Tensor|None=None, matches:Tensor|None=None):
out = [self.cls_logits(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, self.num_classes) for feat in x]
out = out[0].cat(*out[1:], dim=1)
@@ -156,7 +154,7 @@ class ClassificationHead:
return loss
class RegressionHead:
def __init__(self, in_channels, num_anchors, box_coder:Optional[BoxCoder] = None):
def __init__(self, in_channels:int, num_anchors:int, box_coder:BoxCoder|None=None):
self.conv = flatten([(Conv2dNormal(in_channels, in_channels, kernel_size=3, padding=1), lambda x: x.relu()) for _ in range(4)])
self.bbox_reg = Conv2dNormal(in_channels, num_anchors * 4, kernel_size=3, padding=1)
@@ -164,7 +162,7 @@ class RegressionHead:
box_coder = BoxCoder((1.0, 1.0, 1.0, 1.0), apply_to_remove=False)
self.box_coder = box_coder
def __call__(self, x:Tensor, bboxes:Optional[Tensor] = None, matches:Optional[Tensor] = None, anchors:Optional[Tensor] = None):
def __call__(self, x:Tensor, bboxes:Tensor|None=None, matches:Tensor|None=None, anchors:Tensor|None=None):
out = [self.bbox_reg(feat.sequential(self.conv)).permute(0, 2, 3, 1).reshape(feat.shape[0], -1, 4) for feat in x]
out = out[0].cat(*out[1:], dim=1)
@@ -183,11 +181,11 @@ class RegressionHead:
return loss
class RetinaHead:
def __init__(self, in_channels:int, num_anchors:int, num_classes:int):
  """Create the two sub-heads: the classification head first, then the regression head."""
  cls_head = ClassificationHead(in_channels=in_channels, num_anchors=num_anchors, num_classes=num_classes)
  reg_head = RegressionHead(in_channels=in_channels, num_anchors=num_anchors)
  self.classification_head, self.regression_head = cls_head, reg_head
def __call__(self, x:Tensor, **kwargs) -> Union[Tensor, Dict[str, Tensor]]:
def __call__(self, x:Tensor, **kwargs) -> Tensor|dict[str, Tensor]:
if Tensor.training:
return {
"classification_loss": self.classification_head(x, labels=kwargs["labels"], matches=kwargs["matches"]),
@@ -199,7 +197,7 @@ class RetinaHead:
return out
class ResNetFPN:
def __init__(self, resnet, out_channels=256, returned_layers=[2, 3, 4]):
def __init__(self, resnet:ResNet, out_channels:int=256, returned_layers:list[int]=[2, 3, 4]):
self.out_channels = out_channels
self.body = resnet
in_channels_list = [(self.body.in_planes // 8) * 2 ** (i - 1) for i in returned_layers]
@@ -209,7 +207,7 @@ class ResNetFPN:
def compute_grid_sizes(self, input_size):
  """Return `ceil(input_size / 2**k)` for k = 3..7, one row per k.

  `input_size` is converted with `np.array` and is expected to be one-dimensional
  (e.g. a two-element size); the result then has shape `(5, len(input_size))`.
  """
  strides = 2 ** np.arange(3, 8)  # 8, 16, 32, 64, 128
  dims = np.array(input_size)
  # broadcast (1, D) against (5, 1) so that every stride divides every dimension
  return np.ceil(dims[None, :] / strides[:, None])
def __call__(self, x):
def __call__(self, x:Tensor):
out = self.body.bn1(self.body.conv1(x)).relu()
out = out.pad([1,1,1,1]).max_pool2d((3,3), 2)
out = out.sequential(self.body.layer1)
@@ -219,12 +217,12 @@ class ResNetFPN:
return self.fpn([p3, p4, p5])
class ExtraFPNBlock:
def __init__(self, in_channels:int, out_channels:int):
  """Create the two stride-2 3x3 convolutions (`p6`, then `p7`) and record `use_P5`.

  `use_P5` is true exactly when `in_channels == out_channels`.
  """
  self.use_P5 = in_channels == out_channels
  # both convolutions share the same kernel/stride/padding; only the input width differs
  downsample = dict(kernel_size=3, stride=2, padding=1)
  self.p6 = Conv2dKaimingUniform(in_channels, out_channels, **downsample)
  self.p7 = Conv2dKaimingUniform(out_channels, out_channels, **downsample)
def __call__(self, p, c):
def __call__(self, p:Tensor, c:Tensor):
p5, c5 = p[-1], c[-1]
x = p5 if self.use_P5 else c5
p6 = self.p6(x)
@@ -233,14 +231,14 @@ class ExtraFPNBlock:
return p
class FPN:
def __init__(self, in_channels_list:list[int], out_channels:int, extra_blocks:ExtraFPNBlock|None=None):
  """Create one 1x1 `inner_blocks` conv and one 3x3 `layer_blocks` conv per input level.

  in_channels_list: channel count of each input level, in order.
  out_channels: channel count every convolution created here outputs.
  extra_blocks: stored as `self.extra_blocks`; when None, an `ExtraFPNBlock` sized from
    `out_channels` is created instead.
  """
  self.inner_blocks, self.layer_blocks = [], []
  for in_channels in in_channels_list:
    # inner/layer creation stays interleaved so the order of weight initialisation is unchanged
    self.inner_blocks.append(Conv2dKaimingUniform(in_channels, out_channels, kernel_size=1))
    self.layer_blocks.append(Conv2dKaimingUniform(out_channels, out_channels, kernel_size=3, padding=1))
  # Size the default block from out_channels instead of a fixed 256: identical when
  # out_channels is 256, and matches the width of the layer_blocks outputs otherwise.
  # NOTE(review): presumably the extra block consumes those outputs — confirm in __call__.
  self.extra_blocks = ExtraFPNBlock(out_channels, out_channels) if extra_blocks is None else extra_blocks
def __call__(self, x):
def __call__(self, x:Tensor):
last_inner = self.inner_blocks[-1](x[-1])
results = [self.layer_blocks[-1](last_inner)]
for idx in range(len(x) - 2, -1, -1):