mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
ruff check whole examples/mlperf/ (#10979)
This commit is contained in:
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@@ -336,7 +336,7 @@ jobs:
|
||||
run: |
|
||||
pip3 install --upgrade --force-reinstall ruff==0.11.0
|
||||
python3 -m ruff check .
|
||||
python3 -m ruff check examples/mlperf/model_train.py --ignore E501
|
||||
python3 -m ruff check examples/mlperf/ --ignore E501
|
||||
- name: Lint tinygrad with pylint
|
||||
run: python -m pylint tinygrad/
|
||||
- name: Run mypy
|
||||
|
||||
@@ -212,7 +212,7 @@ def get_mlperf_bert_model():
|
||||
from examples.mlperf.initializers import LinearBert, EmbeddingBert, LayerNormBert
|
||||
|
||||
bert.Linear = LinearBert
|
||||
bert.Embedding = EmbeddingBert
|
||||
bert.Embedding = EmbeddingBert
|
||||
bert.LayerNorm = LayerNormBert
|
||||
|
||||
from extra.models.bert import BertForPretraining
|
||||
|
||||
@@ -39,7 +39,7 @@ class LinearBert(nn.Linear):
|
||||
def __init__(self, in_features, out_features, bias=True, std=0.02):
|
||||
self.weight = std * rand_truncn(out_features, in_features, dtype=dtypes.float32)
|
||||
self.bias = Tensor.zeros(out_features, dtype=dtypes.float32) if bias else None
|
||||
|
||||
|
||||
def __call__(self, x:Tensor):
|
||||
return x.cast(dtypes.default_float).linear(self.weight.cast(dtypes.default_float).transpose(), self.bias.cast(dtypes.default_float) if self.bias is not None else None)
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from tinygrad import Tensor, dtypes
|
||||
from tinygrad import dtypes
|
||||
from tinygrad.nn.optim import Optimizer
|
||||
|
||||
from extra.lr_scheduler import LR_Scheduler
|
||||
|
||||
Reference in New Issue
Block a user