mirror of https://github.com/ROCm/ROCm.git
Merge remote-tracking branch 'upstream/main' into triton-mlir-IFU-03122023
.github/workflows/integration-tests.yml (5 changed lines)
@@ -42,7 +42,7 @@ jobs:
       - name: Clear cache
         run: |
-          rm -rf ~/.triton/cache/
+          rm -rf ~/.triton/

       - name: Update path
         run: |
@@ -93,7 +93,6 @@ jobs:
           cd python/test/unit/
           pytest
-
       - name: Run CXX unittests
         run: |
           cd python/
@@ -107,4 +106,4 @@ jobs:
           sudo nvidia-smi -i 0 -pm 1
           sudo nvidia-smi -i 0 --lock-gpu-clocks=1350,1350
           pytest -vs .
-          sudo nvidia-smi -i 0 -rgc
+          sudo nvidia-smi -i 0 -rgc
@@ -1,37 +1,37 @@
 dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
-cuda,AlbertForMaskedLM,4,1.5637,196.1976,36.7633,1.2637
-cuda,AlbertForQuestionAnswering,4,1.5697,193.9925,22.3314,1.3133
-cuda,BartForCausalLM,4,1.4704,88.2529,36.2264,0.9750
-cuda,BertForMaskedLM,16,1.5146,83.3326,42.1532,1.0497
-cuda,BertForQuestionAnswering,16,1.6353,65.9640,30.3219,1.1693
-cuda,BlenderbotSmallForCausalLM,64,1.1694,62.6029,28.2282,0.9127
-cuda,BlenderbotSmallForConditionalGeneration,64,1.3153,116.4360,51.6582,0.9804
-cuda,CamemBert,16,1.4672,93.1047,33.8399,1.0462
-cuda,DebertaForMaskedLM,4,0.8610,94.2568,41.9968,1.0406
-cuda,DebertaForQuestionAnswering,8,0.9550,94.5806,42.3991,1.1506
-cuda,DebertaV2ForMaskedLM,1,0.7829,214.8802,64.2595,0.9772
-cuda,DistilBertForMaskedLM,128,1.2346,80.7496,30.1995,0.9625
-cuda,DistilBertForQuestionAnswering,256,1.3784,86.2125,32.3333,1.1469
-cuda,DistillGPT2,16,1.6558,72.8982,21.3507,1.0639
-cuda,ElectraForCausalLM,32,1.5242,67.9513,57.7771,0.9719
-cuda,ElectraForQuestionAnswering,64,1.9230,67.2341,45.5223,1.1624
-cuda,GPT2ForSequenceClassification,4,2.0511,53.7014,31.2262,1.2305
-cuda,LayoutLMForMaskedLM,16,1.5055,84.8592,40.6248,1.0491
-cuda,LayoutLMForSequenceClassification,16,1.6464,66.9082,35.1418,1.1401
-cuda,MBartForCausalLM,4,1.4704,88.7321,29.1925,0.9831
-cuda,MegatronBertForCausalLM,4,1.1061,136.3322,79.4501,1.0946
-cuda,MegatronBertForQuestionAnswering,8,1.2551,133.9124,75.9488,1.1147
-cuda,MobileBertForMaskedLM,64,0.9569,333.3552,130.7601,1.0135
-cuda,MobileBertForQuestionAnswering,128,0.9634,331.8111,126.0109,0.8400
-cuda,PLBartForCausalLM,8,1.5155,83.1783,24.3849,0.9886
-cuda,PLBartForConditionalGeneration,4,1.4414,93.6038,52.3630,1.0496
-cuda,PegasusForCausalLM,32,1.1225,79.1829,36.5687,0.9737
-cuda,PegasusForConditionalGeneration,32,1.1506,175.4371,59.1006,1.0686
-cuda,RobertaForCausalLM,16,1.5780,83.6565,33.7543,1.0491
-cuda,RobertaForQuestionAnswering,16,1.6336,66.3454,29.9597,1.1698
-cuda,Speech2Text2ForCausalLM,256,1.5464,41.4059,25.6908,0.8768
-cuda,T5ForConditionalGeneration,4,1.2736,96.9787,54.8479,1.1802
-cuda,T5Small,4,1.2861,98.4766,32.1507,1.1802
-cuda,TrOCRForCausalLM,32,1.2573,127.5731,36.5153,0.9584
-cuda,XLNetLMHeadModel,8,1.6924,177.0149,83.7423,1.1026
-cuda,YituTechConvBert,16,1.4142,107.2519,68.8073,1.0363
+cuda,AlbertForMaskedLM,4,1.5511,164.3373,26.8523,1.2647
+cuda,AlbertForQuestionAnswering,4,1.5501,163.5580,25.7983,1.3145
+cuda,BartForCausalLM,4,1.5080,71.7230,32.8907,0.9749
+cuda,BertForMaskedLM,16,1.5350,67.9451,35.3286,1.0494
+cuda,BertForQuestionAnswering,16,1.6735,53.2963,34.3754,1.1710
+cuda,BlenderbotSmallForCausalLM,64,1.2106,46.6466,23.8058,0.9120
+cuda,BlenderbotSmallForConditionalGeneration,64,1.3616,77.3013,55.3546,0.9803
+cuda,CamemBert,16,1.4779,76.1809,35.3883,1.0469
+cuda,DebertaForMaskedLM,4,0.8415,62.3395,35.9657,1.0418
+cuda,DebertaForQuestionAnswering,8,1.0609,67.5151,35.7728,1.1528
+cuda,DebertaV2ForMaskedLM,1,0.6026,134.6517,66.1783,0.9773
+cuda,DistilBertForMaskedLM,128,1.2460,66.9382,18.3089,0.9624
+cuda,DistilBertForQuestionAnswering,256,1.3997,72.4126,18.1956,1.1486
+cuda,DistillGPT2,16,1.6656,60.5455,17.2280,1.0641
+cuda,ElectraForCausalLM,32,1.8299,45.4841,37.0944,0.9717
+cuda,ElectraForQuestionAnswering,64,2.0289,52.6890,35.9632,1.1928
+cuda,GPT2ForSequenceClassification,4,2.2567,38.2969,30.0527,1.2323
+cuda,LayoutLMForMaskedLM,16,1.5423,68.8018,36.5562,1.0495
+cuda,LayoutLMForSequenceClassification,16,1.7058,53.9355,35.2225,1.1659
+cuda,MBartForCausalLM,4,1.4945,71.4649,32.8653,0.9830
+cuda,MegatronBertForCausalLM,4,1.4328,58.4404,70.6226,1.0951
+cuda,MegatronBertForQuestionAnswering,8,1.5886,85.2533,69.1219,1.1152
+cuda,MobileBertForMaskedLM,64,0.9007,131.7379,107.5275,1.0136
+cuda,MobileBertForQuestionAnswering,128,0.8435,167.9066,106.7049,0.8579
+cuda,PLBartForCausalLM,8,1.5261,68.9224,19.5826,0.9887
+cuda,PLBartForConditionalGeneration,4,1.5298,71.2811,45.6902,1.0495
+cuda,PegasusForCausalLM,32,1.2212,57.5436,33.3863,0.9736
+cuda,PegasusForConditionalGeneration,32,1.2822,106.4678,69.8825,1.0689
+cuda,RobertaForCausalLM,16,1.6128,67.5706,34.7355,1.0496
+cuda,RobertaForQuestionAnswering,16,1.6800,53.6267,33.8527,1.1704
+cuda,Speech2Text2ForCausalLM,256,1.8230,32.9145,18.7201,0.8760
+cuda,T5ForConditionalGeneration,4,1.6592,59.5324,39.4406,1.1814
+cuda,T5Small,4,1.6581,59.5930,37.0471,1.1814
+cuda,TrOCRForCausalLM,32,1.2586,106.2633,32.5330,0.9583
+cuda,XLNetLMHeadModel,8,1.8108,142.8795,84.8197,1.1240
+cuda,YituTechConvBert,16,1.5207,81.4595,53.1565,1.0362
.github/workflows/torchinductor/data/timm_models.csv (106 changed lines)
@@ -1,54 +1,54 @@
 dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
-cuda,adv_inception_v3,128,1.5315,126.8003,157.5622,1.0179
-cuda,beit_base_patch16_224,64,1.3308,91.0230,35.2707,0.9891
-cuda,coat_lite_mini,128,2.0396,63.6890,87.8844,1.0199
-cuda,convmixer_768_32,32,1.0443,336.7657,32.3711,0.9999
-cuda,convnext_base,64,1.4920,97.2042,76.9539,1.0388
-cuda,crossvit_9_240,128,1.2019,87.2138,86.0260,1.0032
-cuda,cspdarknet53,64,1.4388,90.4760,88.7691,1.0140
-cuda,deit_base_distilled_patch16_224,64,1.2471,81.1739,45.5359,0.9527
-cuda,dla102,128,1.5027,147.5660,82.0541,1.0331
-cuda,dm_nfnet_f0,128,1.4283,98.4167,67.5993,1.0713
-cuda,dpn107,32,1.2356,100.9084,91.1994,0.9651
-cuda,eca_botnext26ts_256,128,1.5012,86.6883,79.4803,1.0043
-cuda,ese_vovnet19b_dw,128,1.3848,52.7837,47.5720,0.9915
-cuda,fbnetc_100,128,1.5364,64.6275,121.7965,0.9548
-cuda,gernet_l,128,1.1577,76.1686,56.8719,0.9712
-cuda,ghostnet_100,128,1.7220,65.4748,209.2949,1.0223
-cuda,gluon_inception_v3,128,1.5343,126.3718,52.4958,1.0466
-cuda,gmixer_24_224,128,1.6468,84.2051,51.3789,1.1584
-cuda,gmlp_s16_224,128,1.5873,94.6600,58.2645,1.2023
-cuda,hrnet_w18,128,1.3546,263.2673,198.3349,0.9923
-cuda,inception_v3,128,1.5238,126.7710,51.6400,1.0466
-cuda,jx_nest_base,32,1.2384,103.6627,78.7322,0.9607
-cuda,lcnet_050,128,1.7771,21.2319,34.1647,0.9458
-cuda,mixer_b16_224,128,1.2902,108.2678,27.6445,0.9948
-cuda,mixnet_l,128,1.2122,182.1990,118.8907,0.9908
-cuda,mnasnet_100,128,1.6199,47.7919,48.7633,0.9408
-cuda,mobilenetv2_100,128,1.5661,50.4129,42.9730,1.1166
-cuda,mobilenetv3_large_100,128,1.5888,46.3295,51.9647,0.9704
-cuda,mobilevit_s,64,1.3031,82.1195,111.3226,1.0065
-cuda,nfnet_l0,128,1.4895,81.4742,58.5768,0.9691
-cuda,pit_b_224,64,1.4120,97.3046,43.7878,1.0241
-cuda,pnasnet5large,16,1.0523,239.1102,145.2293,1.2797
-cuda,poolformer_m36,64,1.2154,138.0360,93.0292,1.1927
-cuda,regnety_002,128,1.2659,38.8745,68.1799,0.8660
-cuda,repvgg_a2,128,1.2185,73.6415,32.3085,0.9735
-cuda,res2net101_26w_4s,64,1.0443,116.1409,144.6286,0.9491
-cuda,res2net50_14w_8s,128,1.3212,130.1624,102.7642,0.9609
-cuda,res2next50,128,1.2159,157.3657,46.9827,0.9756
-cuda,resmlp_12_224,128,1.2970,54.6888,40.0312,1.0342
-cuda,resnest101e,64,1.4079,134.5610,119.5467,1.0831
-cuda,rexnet_100,128,1.4427,65.3909,222.1865,1.0439
-cuda,selecsls42b,128,1.4015,53.4159,31.3161,0.9731
-cuda,spnasnet_100,128,1.5507,54.3208,34.4102,1.0045
-cuda,swin_base_patch4_window7_224,64,1.5038,115.4018,104.8326,0.9043
-cuda,swsl_resnext101_32x16d,32,0.9981,136.5238,49.8939,0.9833
-cuda,tf_efficientnet_b0,128,1.4894,67.2972,57.0583,1.0725
-cuda,tf_mixnet_l,128,1.2179,189.8781,68.3717,1.0676
-cuda,tinynet_a,128,1.3548,64.3571,117.1189,1.0718
-cuda,tnt_s_patch16_224,128,3.0069,126.5317,67.8712,1.0505
-cuda,twins_pcpvt_base,64,1.2016,154.8390,144.5083,1.0541
-cuda,visformer_small,128,1.1935,87.3201,42.3853,1.0220
-cuda,vit_base_patch16_224,64,1.2207,85.2031,39.3641,0.9551
-cuda,volo_d1_224,0,0.0000
+cuda,adv_inception_v3,128,1.5923,102.5292,51.6032,1.0472
+cuda,beit_base_patch16_224,64,1.3390,75.3027,29.7471,1.0156
+cuda,coat_lite_mini,128,2.0579,53.3689,37.1856,1.0437
+cuda,convmixer_768_32,32,1.0470,275.5328,23.8037,0.9999
+cuda,convnext_base,64,1.5084,80.1811,42.5659,1.0373
+cuda,crossvit_9_240,128,1.5392,37.1806,44.9986,0.9193
+cuda,cspdarknet53,64,1.4721,75.0403,35.2882,1.0547
+cuda,deit_base_distilled_patch16_224,64,1.1432,55.9737,23.4038,0.9816
+cuda,dla102,128,1.5282,123.7284,49.3612,1.0430
+cuda,dm_nfnet_f0,128,1.4354,79.7518,34.8994,1.1038
+cuda,dpn107,32,1.2412,83.8921,58.9111,0.9952
+cuda,eca_botnext26ts_256,128,1.5425,71.2406,28.8920,1.0270
+cuda,ese_vovnet19b_dw,128,1.4647,42.4837,18.0285,1.0135
+cuda,fbnetc_100,128,1.5795,53.8033,33.0222,1.0082
+cuda,gernet_l,128,1.1684,63.4230,26.8687,1.0053
+cuda,ghostnet_100,128,1.7812,54.4211,47.6168,1.0484
+cuda,gluon_inception_v3,128,1.5952,102.5018,50.0857,1.0469
+cuda,gmixer_24_224,128,1.6749,69.2430,42.0841,1.1921
+cuda,gmlp_s16_224,128,1.5886,79.2132,43.0142,1.2343
+cuda,hrnet_w18,128,1.3743,221.5304,134.2573,1.0100
+cuda,inception_v3,128,1.5847,102.8333,49.7648,1.0472
+cuda,jx_nest_base,32,1.3747,71.4190,61.4053,0.9905
+cuda,lcnet_050,128,1.8159,18.0047,18.8249,1.0005
+cuda,mixer_b16_224,128,1.2795,90.9229,21.0438,1.0133
+cuda,mixnet_l,128,1.2273,149.9722,47.7482,1.0129
+cuda,mnasnet_100,128,1.6594,40.0512,26.5165,1.0047
+cuda,mobilenetv2_100,128,1.6085,41.1217,27.4450,1.1731
+cuda,mobilenetv3_large_100,128,1.6610,37.9995,29.8185,1.0052
+cuda,mobilevit_s,64,1.5212,55.4152,53.6475,1.0258
+cuda,nfnet_l0,128,1.4927,65.7078,32.4067,0.9980
+cuda,pit_b_224,64,1.2286,57.9484,26.5321,0.9606
+cuda,pnasnet5large,16,1.0000,198.2494,93.4641,1.3184
+cuda,poolformer_m36,64,1.3486,103.9235,62.3196,1.1942
+cuda,regnety_002,128,1.3030,32.4968,27.2439,1.0014
+cuda,repvgg_a2,128,1.2485,59.7729,26.9209,1.0185
+cuda,res2net101_26w_4s,64,1.0813,94.1773,86.6520,0.9655
+cuda,res2net50_14w_8s,128,1.3251,109.5258,79.9578,0.9830
+cuda,res2next50,128,1.2518,125.5008,43.9754,0.9756
+cuda,resmlp_12_224,128,1.3060,45.2373,19.3709,1.1048
+cuda,resnest101e,64,1.4346,108.1945,78.1993,1.1037
+cuda,rexnet_100,128,1.4637,55.0121,41.2075,1.0862
+cuda,selecsls42b,128,1.4284,44.6645,23.3892,1.0139
+cuda,spnasnet_100,128,1.5908,45.3189,32.0148,1.0048
+cuda,swin_base_patch4_window7_224,64,1.6164,89.5854,75.5848,0.9299
+cuda,swsl_resnext101_32x16d,32,1.0175,110.0041,45.7853,1.0003
+cuda,tf_efficientnet_b0,128,1.5271,55.7361,34.5551,1.1079
+cuda,tf_mixnet_l,128,1.2369,155.9027,48.6695,1.0921
+cuda,tinynet_a,128,1.3792,53.0640,40.6346,1.1108
+cuda,tnt_s_patch16_224,128,3.1078,104.8486,59.6028,1.0660
+cuda,twins_pcpvt_base,64,1.5921,67.4600,84.4977,1.0909
+cuda,visformer_small,128,1.1952,72.8705,23.7303,1.0410
+cuda,vit_base_patch16_224,64,1.1309,56.4866,22.0208,0.9804
+cuda,volo_d1_224,64,1.6868,72.0957,65.3011,0.9729
@@ -1,51 +1,53 @@
 dev,name,batch_size,speedup,abs_latency,compilation_latency,compression_ratio
-cuda,BERT_pytorch,16,1.3446,65.5196,59.4176,1.1679
-cuda,LearningToPaint,96,1.0376,12.7398,35.9770,0.7613
-cuda,Super_SloMo,6,1.3132,73.6570,39.9019,1.2390
-cuda,alexnet,128,1.1653,10.1614,10.5925,0.9408
-cuda,attention_is_all_you_need_pytorch,256,1.2514,82.7756,66.7768,1.1459
-cuda,dcgan,32,0.8947,2.6217,5.9728,1.0082
-cuda,densenet121,4,0.8777,65.7564,123.8633,0.8292
-cuda,drq,1,1.0291,4.9372,7.8125,0.9849
-cuda,fastNLP_Bert,6,1.5073,73.2808,42.6538,1.1547
-cuda,functorch_dp_cifar10,64,1.4043,9.3820,51.7668,0.4986
-cuda,functorch_maml_omniglot,1,1.0998,3.0214,12.6407,0.2181
-cuda,hf_Albert,8,1.3968,56.3755,38.3806,1.2603
-cuda,hf_Bart,4,1.1020,91.5587,53.6117,1.0087
-cuda,hf_Bert,4,1.1458,65.0555,43.0593,1.0261
-cuda,hf_Bert_large,4,1.1683,131.1134,59.7228,1.0909
-cuda,hf_DistilBert,8,1.2007,34.2069,22.8917,1.0228
-cuda,hf_GPT2,4,1.2689,52.0180,31.7956,1.1540
+cuda,BERT_pytorch,16,1.7111,24.2741,35.7065,1.3212
+cuda,LearningToPaint,96,1.0513,10.7557,11.1879,0.9896
+cuda,Super_SloMo,6,1.3267,60.4328,28.2097,1.2392
+cuda,alexnet,128,1.1754,8.3246,5.3319,1.0003
+cuda,attention_is_all_you_need_pytorch,256,1.3416,36.4401,39.5927,1.1774
+cuda,dcgan,32,0.9151,2.6249,3.2964,1.0082
+cuda,densenet121,4,0.9225,51.3747,68.5841,0.9930
+cuda,doctr_det_predictor,0,0.0000
+cuda,doctr_reco_predictor,0,0.0000
+cuda,drq,1,0.9500,3.4884,4.8028,0.9687
+cuda,fastNLP_Bert,6,1.4328,34.7753,35.4863,1.2368
+cuda,functorch_dp_cifar10,64,1.2015,8.1625,12.9040,1.0609
+cuda,functorch_maml_omniglot,1,0.9322,2.5844,3.8640,1.0000
+cuda,hf_Albert,8,2.1228,30.3377,26.8282,1.2676
+cuda,hf_Bart,4,1.2899,39.1935,47.2373,1.0080
+cuda,hf_Bert,4,1.3262,26.1063,35.0281,1.0656
+cuda,hf_Bert_large,4,1.4163,55.1021,67.2825,1.0915
+cuda,hf_DistilBert,8,1.4051,21.7191,18.0399,1.0242
+cuda,hf_GPT2,4,1.6661,26.9039,29.9473,1.1555
 cuda,hf_Longformer,0,0.0000
-cuda,hf_Reformer,4,1.0903,82.9519,28.0343,0.9289
-cuda,hf_T5_large,2,1.3534,332.3302,172.6140,1.1666
-cuda,lennard_jones,1000,0.9952,3.8690,4.8521,1.0000
-cuda,maml_omniglot,32,1.0328,3.3367,8.2772,0.2181
-cuda,mnasnet1_0,32,1.0162,25.4638,69.8684,0.8356
-cuda,mobilenet_v2,96,1.5212,38.4276,100.4918,1.1011
-cuda,nvidia_deeprecommender,256,1.0517,11.1245,7.3804,0.9715
-cuda,phlippe_densenet,128,1.0043,33.4096,108.0736,0.8774
-cuda,phlippe_resnet,128,1.0229,14.0998,21.7420,0.4147
-cuda,pytorch_CycleGAN_and_pix2pix,1,1.3815,9.3944,32.3602,0.6135
-cuda,pytorch_stargan,16,1.1625,14.4103,41.3705,0.8893
-cuda,pytorch_unet,1,1.3638,35.7120,51.2342,0.9525
-cuda,resnet152,32,0.9568,76.3876,70.2073,0.9997
-cuda,resnet18,16,0.9193,12.1360,23.4287,0.6492
-cuda,resnet50,32,1.0230,29.6914,26.1574,1.0010
-cuda,resnext50_32x4d,8,0.8679,25.7775,39.3170,0.8524
-cuda,shufflenet_v2_x1_0,128,1.1374,31.2127,62.0057,0.9590
-cuda,soft_actor_critic,256,0.9754,3.1737,5.5626,0.9998
-cuda,speech_transformer,32,1.1390,94.3465,74.5561,0.8732
-cuda,squeezenet1_1,32,1.1572,9.2585,19.0393,0.9243
-cuda,timm_efficientdet,1,1.3338,95.3918,255.9148,1.0310
-cuda,timm_efficientnet,32,1.1237,34.3466,80.1230,0.9445
-cuda,timm_nfnet,128,1.4441,95.5148,36.3090,1.1050
-cuda,timm_regnet,32,1.0374,65.3419,57.6930,0.9528
-cuda,timm_resnest,32,1.5878,18.2585,54.0304,0.9636
-cuda,timm_vision_transformer,8,1.0850,51.7360,50.3927,0.7429
-cuda,timm_vision_transformer_large,0,0.0000
-cuda,timm_vovnet,32,1.1318,27.3068,27.8668,0.8884
+cuda,hf_Reformer,4,1.1709,64.6979,15.7035,0.9267
+cuda,hf_T5_large,2,1.7215,107.0798,148.8805,1.1684
+cuda,lennard_jones,1000,0.8428,1.8488,3.0609,1.0001
+cuda,maml_omniglot,32,0.9648,2.6869,3.9775,0.9999
+cuda,mnasnet1_0,32,1.0469,21.6251,25.8232,0.9996
+cuda,mobilenet_v2,96,1.5604,31.9572,27.0225,1.1734
+cuda,nvidia_deeprecommender,256,1.0605,9.2080,4.1318,0.9711
+cuda,phlippe_densenet,128,1.0237,27.5988,28.0400,1.0023
+cuda,phlippe_resnet,128,1.0493,10.9751,10.2485,1.0092
+cuda,pytorch_CycleGAN_and_pix2pix,1,1.3724,8.2225,11.9561,1.0219
+cuda,pytorch_stargan,16,1.1835,11.9178,10.0507,1.0868
+cuda,pytorch_unet,1,1.3787,29.7543,13.7711,1.0100
+cuda,resnet152,32,0.9834,63.2446,67.7935,0.9991
+cuda,resnet18,16,0.9451,9.4977,11.7663,0.9948
+cuda,resnet50,32,1.0513,24.5141,24.6629,1.0021
+cuda,resnext50_32x4d,8,0.9216,22.2460,24.3420,0.9984
+cuda,shufflenet_v2_x1_0,128,1.1943,25.4520,28.8611,1.0951
+cuda,soft_actor_critic,256,0.8691,1.9637,3.3716,0.9996
+cuda,speech_transformer,32,1.2718,35.2922,46.9957,1.0897
+cuda,squeezenet1_1,32,1.1302,8.4540,7.9625,1.0771
+cuda,timm_efficientdet,1,1.3370,80.0377,120.1814,1.2713
+cuda,timm_efficientnet,32,1.1874,27.6302,33.9059,1.0971
+cuda,timm_nfnet,128,1.4525,77.3461,34.3270,1.1056
+cuda,timm_regnet,32,1.0644,50.6953,35.7562,1.0000
+cuda,timm_resnest,32,1.6200,14.7763,17.2245,1.0906
+cuda,timm_vision_transformer,32,1.0800,19.4188,22.0255,0.9966
+cuda,timm_vision_transformer_large,32,1.0081,393.1742,127.8083,0.9735
+cuda,timm_vovnet,32,1.1472,22.4727,22.7328,1.0120
 cuda,torchrec_dlrm,0,0.0000
-cuda,tts_angular,64,0.8185,10.2896,5.1774,1.0015
-cuda,vgg16,64,1.2931,61.1714,10.9558,0.9828
-cuda,yolov3,16,1.2202,68.8346,86.5149,1.0437
+cuda,tts_angular,64,0.8974,6.5057,2.5555,0.9973
+cuda,vgg16,64,1.2909,50.7405,6.1510,0.9828
+cuda,yolov3,16,1.2930,54.8069,41.9269,1.0563
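The script changed below reads these result files positionally, so the column order in the CSV header is load-bearing. As a minimal illustration (standard-library Python only; the sample row is copied from the new torchbench results above), this is how one row maps onto the fields the script consumes:

```python
import csv
from io import StringIO

# Column order, per the CSV header: dev, name, batch_size, speedup,
# abs_latency, compilation_latency, compression_ratio.
row = next(csv.reader(StringIO(
    "cuda,hf_Albert,8,2.1228,30.3377,26.8282,1.2676")))

dev, name, batch_size = row[0], row[1], row[2]
speedup = float(row[3])  # speedup over the eager baseline (higher is better)
latency = float(row[4])  # absolute latency (lower is better)
print(dev, name, batch_size, speedup, latency)
```

Rows for failed runs, such as cuda,volo_d1_224,0,0.0000, carry fewer than five columns, which is why the updated parser below skips any row with len(row) < 5.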
@@ -4,7 +4,7 @@ from collections import namedtuple
 
 # Create a named tuple for the output of the benchmark
 BenchmarkOutput = namedtuple(
-    'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup'])
+    'BenchmarkOutput', ['dev', 'name', 'batch_size', 'speedup', 'latency'])
 
 
 def parse_output(file_path: str) -> dict:
@@ -12,39 +12,52 @@ def parse_output(file_path: str) -> dict:
     with open(file_path) as f:
         reader = csv.reader(f)
         for i, row in enumerate(reader):
-            if i == 0:
+            if i == 0 or len(row) < 5:
                 continue
             dev = row[0]
             name = row[1]
             batch_size = row[2]
             speedup = float(row[3])
-            entries[name] = BenchmarkOutput(dev, name, batch_size, speedup)
+            latency = float(row[4])
+            entries[name] = BenchmarkOutput(dev, name, batch_size, speedup, latency)
     return entries
 
 
-def compare(baseline: dict, new: dict, threshold: float) -> bool:
+def compare(baseline: dict, new: dict, threshold: float, geomean_threshold: float) -> bool:
+    baseline_geomean = 1.0
+    new_geomean = 1.0
     for key in new:
         if key not in baseline:
             print(f"New benchmark {key} not found in baseline")
-        baseline_speedup = baseline[key].speedup
-        new_speedup = new[key].speedup
-        if new_speedup < baseline_speedup * (1 - threshold):
+        baseline_latency = baseline[key].latency
+        new_latency = new[key].latency
+        if new_latency < baseline_latency * (1 - threshold):
             print(
-                f"New benchmark {key} is slower than baseline: {new_speedup} vs {baseline_speedup}")
-        elif new_speedup > baseline_speedup * (1 + threshold):
+                f"New benchmark {key} is faster than baseline: {new_latency} vs {baseline_latency}")
+        elif new_latency > baseline_latency * (1 + threshold):
             print(
-                f"New benchmark {key} is faster than baseline: {new_speedup} vs {baseline_speedup}")
+                f"New benchmark {key} is slower than baseline: {new_latency} vs {baseline_latency}")
+        baseline_geomean *= baseline[key].speedup
+        new_geomean *= new[key].speedup
+
+    baseline_geomean = baseline_geomean ** (1 / len(baseline))
+    new_geomean = new_geomean ** (1 / len(new))
+    print(f"Baseline geomean: {baseline_geomean}")
+    print(f"New geomean: {new_geomean}")
+    assert new_geomean > baseline_geomean * (1 - geomean_threshold), \
+        f"New geomean is slower than baseline: {new_geomean} vs {baseline_geomean}"
 
 
 def main():
     parser = argparse.ArgumentParser()
     parser.add_argument('--baseline', required=True)
     parser.add_argument('--new', required=True)
-    parser.add_argument('--threshold', type=float, default=0.02)
+    parser.add_argument('--threshold', type=float, default=0.1)
+    parser.add_argument('--geomean-threshold', type=float, default=0.02)
     args = parser.parse_args()
     baseline = parse_output(args.baseline)
     new = parse_output(args.new)
-    compare(baseline, new, args.threshold)
+    compare(baseline, new, args.threshold, args.geomean_threshold)
 
 
 main()
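Two details of the updated comparison are easy to misread: latency is a lower-is-better metric, so the per-model inequalities point the opposite way from a speedup comparison, and the hard gate is the assertion on the geometric mean of per-model speedups, not the per-model prints. A self-contained sketch of that logic (model names and numbers are made up for illustration; needs Python 3.8+ for math.prod):

```python
import math

# Hypothetical results: (latency_ms, speedup) per model. Latency is
# lower-is-better; speedup over eager mode is higher-is-better.
baseline = {"modelA": (80.0, 1.50), "modelB": (50.0, 1.20)}
new      = {"modelA": (75.0, 1.60), "modelB": (56.0, 1.10)}

threshold = 0.1           # per-model latency tolerance (--threshold default)
geomean_threshold = 0.02  # overall tolerance (--geomean-threshold default)

for key, (new_latency, _) in new.items():
    baseline_latency = baseline[key][0]
    if new_latency < baseline_latency * (1 - threshold):
        print(f"{key} is faster: {new_latency} vs {baseline_latency}")
    elif new_latency > baseline_latency * (1 + threshold):
        print(f"{key} is slower: {new_latency} vs {baseline_latency}")

def geomean(entries):
    # nth root of the product of the per-model speedups
    return math.prod(s for _, s in entries.values()) ** (1 / len(entries))

# Only this check is fatal; the per-model prints are informational.
assert geomean(new) > geomean(baseline) * (1 - geomean_threshold), \
    "geomean speedup regressed beyond tolerance"
```

Note that the per-model messages never fail the run on their own; the script aborts only when the geometric mean of speedups drops by more than geomean_threshold relative to the baseline, which is exactly what the shell loop below relies on via the exit status.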
@@ -27,6 +27,11 @@ cd "$ROOT" || exit
 for model in "${MODELS[@]}"; do
   echo "Checking performance test for $model"
   python3 "$INDUCTOR"/scripts/check_perf.py --new "$TEST_REPORTS_DIR"/"$model".csv --baseline "$INDUCTOR"/data/"$model".csv
+  EXIT_STATUS=$?
+  if [ "$EXIT_STATUS" -ne 0 ]; then
+    echo "Performance test for $model failed"
+    exit "$EXIT_STATUS"
+  fi
 done
 
 # unlock GPU clocks
@@ -1,6 +1,8 @@
-cmake_minimum_required(VERSION 3.6)
+cmake_minimum_required(VERSION 3.18)
 
 if(POLICY CMP0116)
   # Introduced in cmake 3.20
   # https://cmake.org/cmake/help/latest/policy/CMP0116.html
   cmake_policy(SET CMP0116 OLD)
 endif()
@@ -12,6 +14,7 @@ set(CMAKE_INCLUDE_CURRENT_DIR ON)
 
 project(triton)
+include(CTest)
 
 if(NOT WIN32)
   list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
 endif()
@@ -24,7 +27,7 @@ option(TRITON_BUILD_PYTHON_MODULE "Build Python Triton bindings" OFF)
 set(TRITON_USE_ROCM ON)
 
 # Ensure Python3 vars are set correctly
-# used conditionally in this file and by lit tests
+# used conditionally in this file and by lit tests
 
 # Customized release build type with assertions: TritonRelBuildWithAsserts
 set(CMAKE_C_FLAGS_TRITONRELBUILDWITHASSERTS "-O2 -g")
@@ -37,7 +40,7 @@ if(NOT CMAKE_BUILD_TYPE)
 endif()
 
 if(NOT WIN32)
-  find_library(TERMINFO_LIBRARY tinfo)
+  find_library(TERMINFO_LIBRARY tinfo)
 endif()
 
 # Compiler flags
@@ -47,9 +50,9 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/include)
 include_directories(${PYBIND11_INCLUDE_DIR})
 
 if(WIN32)
-  SET(BUILD_SHARED_LIBS OFF)
-  find_package(dlfcn-win32 REQUIRED)
-  set(CMAKE_DL_LIBS dlfcn-win32::dl)
+  SET(BUILD_SHARED_LIBS OFF)
+  find_package(dlfcn-win32 REQUIRED)
+  set(CMAKE_DL_LIBS dlfcn-win32::dl)
 endif()
 
 set(CMAKE_CXX_FLAGS "${CMAKE_C_FLAGS} -D__STDC_FORMAT_MACROS -fPIC -std=gnu++17 -fvisibility=hidden -fvisibility-inlines-hidden")
@@ -62,12 +65,10 @@ if(APPLE)
   set(CMAKE_OSX_DEPLOYMENT_TARGET 11.6)
 endif()
 
-##########
+# #########
 # LLVM
-##########
-if (NOT MLIR_DIR)
+# #########
+if(NOT MLIR_DIR)
   if(NOT LLVM_LIBRARY_DIR)
     if(WIN32)
       find_package(LLVM 13 REQUIRED COMPONENTS nvptx amdgpu)
@@ -83,12 +84,16 @@ if (NOT MLIR_DIR)
     else()
      find_package(LLVM 11 REQUIRED COMPONENTS "nvptx;amdgpu")
    endif()
 
    message(STATUS "Found LLVM ${LLVM_PACKAGE_VERSION}")
 
    # FindLLVM outputs LLVM_LIBRARY_DIRS but we expect LLVM_LIBRARY_DIR here
    set(LLVM_LIBRARY_DIR ${LLVM_LIBRARY_DIRS})
 
    if(APPLE)
      set(CMAKE_OSX_DEPLOYMENT_TARGET "10.14")
    endif()
 
  # sometimes we don't want to use llvm-config, since it may have been downloaded for some specific linux distros
  else()
    set(LLVM_LDFLAGS "-L${LLVM_LIBRARY_DIR}")
@@ -148,37 +153,38 @@ if (NOT MLIR_DIR)
      libLLVMAnalysis.a
    )
  endif()
-  set (MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
+
+  set(MLIR_DIR ${LLVM_LIBRARY_DIR}/cmake/mlir)
 endif()
 
 # Python module
 if(TRITON_BUILD_PYTHON_MODULE)
-  message(STATUS "Adding Python module")
-  set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
-  set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
-  include_directories("." ${PYTHON_SRC_PATH})
-  if (PYTHON_INCLUDE_DIRS)
-    include_directories(${PYTHON_INCLUDE_DIRS})
-  else()
-    find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
-    include_directories(${Python3_INCLUDE_DIRS})
-    link_directories(${Python3_LIBRARY_DIRS})
-    link_libraries(${Python3_LIBRARIES})
-    add_link_options(${Python3_LINK_OPTIONS})
-  endif()
+  message(STATUS "Adding Python module")
+  set(PYTHON_SRC_PATH ${CMAKE_CURRENT_SOURCE_DIR}/python/src)
+  set(PYTHON_SRC ${PYTHON_SRC_PATH}/main.cc ${PYTHON_SRC_PATH}/triton.cc)
+  include_directories("." ${PYTHON_SRC_PATH})
+
+  if(PYTHON_INCLUDE_DIRS)
+    include_directories(${PYTHON_INCLUDE_DIRS})
+  else()
+    find_package(Python3 REQUIRED COMPONENTS Development Interpreter)
+    include_directories(${Python3_INCLUDE_DIRS})
+    link_directories(${Python3_LIBRARY_DIRS})
+    link_libraries(${Python3_LIBRARIES})
+    add_link_options(${Python3_LINK_OPTIONS})
+  endif()
 endif()
 
 # # Triton
 # file(GLOB_RECURSE LIBTRITON_SRC lib/*.cc)
 # if (WIN32 AND TRITON_BUILD_PYTHON_MODULE)
-# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
-# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
-# set_target_properties(triton PROPERTIES PREFIX "lib")
+# Python3_add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
+# set_target_properties(triton PROPERTIES SUFFIX ".pyd")
+# set_target_properties(triton PROPERTIES PREFIX "lib")
 # else()
-# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
+# add_library(triton SHARED ${LIBTRITON_SRC} ${PYTHON_SRC})
 # endif()
 
 # MLIR
 find_package(MLIR REQUIRED CONFIG PATHS ${MLIR_DIR})
@@ -196,14 +202,13 @@ include_directories(${MLIR_INCLUDE_DIRS})
 include_directories(${LLVM_INCLUDE_DIRS})
 include_directories(${PROJECT_SOURCE_DIR}/include)
 include_directories(${PROJECT_BINARY_DIR}/include) # Tablegen'd files
-# link_directories(${LLVM_LIBRARY_DIR})
+
+# link_directories(${LLVM_LIBRARY_DIR})
 add_subdirectory(include)
 add_subdirectory(lib)
 add_subdirectory(bin)
 
 # find_package(PythonLibs REQUIRED)
 
 set(TRITON_SOURCE_DIR "${CMAKE_CURRENT_SOURCE_DIR}")
 set(TRITON_BINARY_DIR "${CMAKE_CURRENT_BINARY_DIR}")
@@ -221,6 +226,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
     TritonHSACO
     ${dialect_libs}
     ${conversion_libs}
+
     # optimizations
     MLIRPass
     MLIRTransforms
@@ -233,6 +239,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
     MLIRROCDLToLLVMIRTranslation
     MLIRIR
   )
+
   if(WIN32)
     target_link_libraries(triton PRIVATE ${LLVM_LIBRARIES} ${CMAKE_DL_LIBS}
       ${TRITON_LIBRARIES}
@@ -246,21 +253,23 @@ if(TRITON_BUILD_PYTHON_MODULE)
       ${TRITON_LIBRARIES}
     )
   endif()
+
   target_link_options(triton PRIVATE ${LLVM_LDFLAGS})
 endif()
 
-if (UNIX AND NOT APPLE)
+if(UNIX AND NOT APPLE)
   set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -Wl,--exclude-libs,ALL")
 endif()
 
 if(TRITON_BUILD_PYTHON_MODULE AND NOT WIN32)
-  set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
-  # Check if the platform is MacOS
-  if(APPLE)
-    set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
-  endif()
-  target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
+  set(CMAKE_SHARED_LIBRARY_SUFFIX ".so")
+
+  # Check if the platform is MacOS
+  if(APPLE)
+    set(PYTHON_LDFLAGS "-undefined dynamic_lookup -flto")
+  endif()
+
+  target_link_libraries(triton ${CUTLASS_LIBRARIES} ${PYTHON_LDFLAGS})
 endif()
 
 add_subdirectory(test)
README.md (15 changed lines)
@@ -25,7 +25,7 @@ You can install the latest stable release of Triton from pip:
 ```bash
 pip install triton
 ```
-Binary wheels are available for CPython 3.6-3.9 and PyPy 3.6-3.7.
+Binary wheels are available for CPython 3.6-3.11 and PyPy 3.7-3.9.
 
 And the latest nightly release:
 
@@ -66,12 +66,11 @@ pip install -e .
 
 # Changelog
 
-Version 1.1 is out! New features include:
+Version 2.0 is out! New features include:
 - Many, many bugfixes
-- More documentation
-- Automatic on-disk caching of compiled binary objects
-- Random Number Generation
-- Faster (up to 2x on A100), cleaner blocksparse ops
 - Performance improvements
+- Backend rewritten to use MLIR
+- Support for kernels that contain back-to-back matmuls (e.g., flash attention)
 
 # Contributing
 
@@ -88,7 +87,3 @@ Supported Platforms:
 Supported Hardware:
 * NVIDIA GPUs (Compute Capability 7.0+)
 * Under development: AMD GPUs, CPUs
-
-# Disclaimer
-
-Triton is a fairly recent project, and it is under active development. We expect it to be pretty useful in a wide variety of cases, but don't be surprised if it's a bit rough around the edges :)
@@ -1,7 +1,3 @@
-add_subdirectory(FileCheck)
-# add_llvm_executable(FileCheck FileCheck/FileCheck.cpp)
-# target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
-
 get_property(dialect_libs GLOBAL PROPERTY MLIR_DIALECT_LIBS)
 get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS)

@@ -1,2 +0,0 @@
-add_llvm_executable(FileCheck FileCheck.cpp)
-target_link_libraries(FileCheck PRIVATE LLVMFileCheck LLVMSupport)
@@ -1,885 +0,0 @@
|
||||
//===- FileCheck.cpp - Check that File's Contents match what is expected --===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// FileCheck does a line-by line check of a file that validates whether it
|
||||
// contains the expected content. This is useful for regression tests etc.
|
||||
//
|
||||
// This program exits with an exit status of 2 on error, exit status of 0 if
|
||||
// the file matched the expected contents, and exit status of 1 if it did not
|
||||
// contain the expected contents.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "llvm/FileCheck/FileCheck.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
#include "llvm/Support/InitLLVM.h"
|
||||
#include "llvm/Support/Process.h"
|
||||
#include "llvm/Support/SourceMgr.h"
|
||||
#include "llvm/Support/WithColor.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include <cmath>
|
||||
#include <map>
|
||||
using namespace llvm;
|
||||
|
||||
static cl::extrahelp FileCheckOptsEnv(
|
||||
"\nOptions are parsed from the environment variable FILECHECK_OPTS and\n"
|
||||
"from the command line.\n");
|
||||
|
||||
static cl::opt<std::string>
|
||||
CheckFilename(cl::Positional, cl::desc("<check-file>"), cl::Optional);
|
||||
|
||||
static cl::opt<std::string>
|
||||
InputFilename("input-file", cl::desc("File to check (defaults to stdin)"),
|
||||
cl::init("-"), cl::value_desc("filename"));
|
||||
|
||||
static cl::list<std::string> CheckPrefixes(
|
||||
"check-prefix",
|
||||
cl::desc("Prefix to use from check file (defaults to 'CHECK')"));
|
||||
static cl::alias CheckPrefixesAlias(
|
||||
"check-prefixes", cl::aliasopt(CheckPrefixes), cl::CommaSeparated,
|
||||
cl::NotHidden,
|
||||
cl::desc(
|
||||
"Alias for -check-prefix permitting multiple comma separated values"));
|
||||
|
||||
static cl::list<std::string> CommentPrefixes(
|
||||
"comment-prefixes", cl::CommaSeparated, cl::Hidden,
|
||||
cl::desc("Comma-separated list of comment prefixes to use from check file\n"
|
||||
"(defaults to 'COM,RUN'). Please avoid using this feature in\n"
|
||||
"LLVM's LIT-based test suites, which should be easier to\n"
|
||||
"maintain if they all follow a consistent comment style. This\n"
|
||||
"feature is meant for non-LIT test suites using FileCheck."));
|
||||
|
||||
static cl::opt<bool> NoCanonicalizeWhiteSpace(
|
||||
"strict-whitespace",
|
||||
cl::desc("Do not treat all horizontal whitespace as equivalent"));
|
||||
|
||||
static cl::opt<bool> IgnoreCase("ignore-case",
|
||||
cl::desc("Use case-insensitive matching"));
|
||||
|
||||
static cl::list<std::string> ImplicitCheckNot(
|
||||
"implicit-check-not",
|
||||
cl::desc("Add an implicit negative check with this pattern to every\n"
|
||||
"positive check. This can be used to ensure that no instances of\n"
|
||||
"this pattern occur which are not matched by a positive pattern"),
|
||||
cl::value_desc("pattern"));
|
||||
|
||||
static cl::list<std::string>
|
||||
GlobalDefines("D", cl::AlwaysPrefix,
|
||||
cl::desc("Define a variable to be used in capture patterns."),
|
||||
cl::value_desc("VAR=VALUE"));
|
||||
|
||||
static cl::opt<bool> AllowEmptyInput(
|
||||
"allow-empty", cl::init(false),
|
||||
cl::desc("Allow the input file to be empty. This is useful when making\n"
|
||||
"checks that some error message does not occur, for example."));
|
||||
|
||||
static cl::opt<bool> AllowUnusedPrefixes(
|
||||
"allow-unused-prefixes", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Allow prefixes to be specified but not appear in the test."));
|
||||
|
||||
static cl::opt<bool> MatchFullLines(
|
||||
"match-full-lines", cl::init(false),
|
||||
cl::desc("Require all positive matches to cover an entire input line.\n"
|
||||
"Allows leading and trailing whitespace if --strict-whitespace\n"
|
||||
"is not also passed."));
|
||||
|
||||
static cl::opt<bool> EnableVarScope(
|
||||
"enable-var-scope", cl::init(false),
|
||||
cl::desc("Enables scope for regex variables. Variables with names that\n"
|
||||
"do not start with '$' will be reset at the beginning of\n"
|
||||
"each CHECK-LABEL block."));
|
||||
|
||||
static cl::opt<bool> AllowDeprecatedDagOverlap(
|
||||
"allow-deprecated-dag-overlap", cl::init(false),
|
||||
cl::desc("Enable overlapping among matches in a group of consecutive\n"
|
||||
"CHECK-DAG directives. This option is deprecated and is only\n"
|
||||
"provided for convenience as old tests are migrated to the new\n"
|
||||
"non-overlapping CHECK-DAG implementation.\n"));
|
||||
|
||||
static cl::opt<bool> Verbose(
|
||||
"v", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print directive pattern matches, or add them to the input dump\n"
|
||||
"if enabled.\n"));
|
||||
|
||||
static cl::opt<bool> VerboseVerbose(
|
||||
"vv", cl::init(false), cl::ZeroOrMore,
|
||||
cl::desc("Print information helpful in diagnosing internal FileCheck\n"
|
||||
"issues, or add it to the input dump if enabled. Implies\n"
|
||||
"-v.\n"));
|
||||
|
||||
// The order of DumpInputValue members affects their precedence, as documented
|
||||
// for -dump-input below.
|
||||
enum DumpInputValue {
|
||||
DumpInputNever,
|
||||
DumpInputFail,
|
||||
DumpInputAlways,
|
||||
DumpInputHelp
|
||||
};
|
||||
|
||||
static cl::list<DumpInputValue> DumpInputs(
|
||||
"dump-input",
|
||||
cl::desc("Dump input to stderr, adding annotations representing\n"
|
||||
"currently enabled diagnostics. When there are multiple\n"
|
||||
"occurrences of this option, the <value> that appears earliest\n"
|
||||
"in the list below has precedence. The default is 'fail'.\n"),
|
||||
cl::value_desc("mode"),
|
||||
cl::values(clEnumValN(DumpInputHelp, "help", "Explain input dump and quit"),
|
||||
clEnumValN(DumpInputAlways, "always", "Always dump input"),
|
||||
clEnumValN(DumpInputFail, "fail", "Dump input on failure"),
|
||||
clEnumValN(DumpInputNever, "never", "Never dump input")));
|
||||
|
||||
// The order of DumpInputFilterValue members affects their precedence, as
|
||||
// documented for -dump-input-filter below.
|
||||
enum DumpInputFilterValue {
|
||||
DumpInputFilterError,
|
||||
DumpInputFilterAnnotation,
|
||||
DumpInputFilterAnnotationFull,
|
||||
DumpInputFilterAll
|
||||
};
|
||||
|
||||
static cl::list<DumpInputFilterValue> DumpInputFilters(
|
||||
"dump-input-filter",
|
||||
cl::desc("In the dump requested by -dump-input, print only input lines of\n"
|
||||
"kind <value> plus any context specified by -dump-input-context.\n"
|
||||
"When there are multiple occurrences of this option, the <value>\n"
|
||||
"that appears earliest in the list below has precedence. The\n"
|
||||
"default is 'error' when -dump-input=fail, and it's 'all' when\n"
|
||||
"-dump-input=always.\n"),
|
||||
cl::values(clEnumValN(DumpInputFilterAll, "all", "All input lines"),
|
||||
clEnumValN(DumpInputFilterAnnotationFull, "annotation-full",
|
||||
"Input lines with annotations"),
|
||||
clEnumValN(DumpInputFilterAnnotation, "annotation",
|
||||
"Input lines with starting points of annotations"),
|
||||
clEnumValN(DumpInputFilterError, "error",
|
||||
"Input lines with starting points of error "
|
||||
"annotations")));
|
||||
|
||||
static cl::list<unsigned> DumpInputContexts(
|
||||
"dump-input-context", cl::value_desc("N"),
|
||||
cl::desc("In the dump requested by -dump-input, print <N> input lines\n"
|
||||
"before and <N> input lines after any lines specified by\n"
|
||||
"-dump-input-filter. When there are multiple occurrences of\n"
|
||||
"this option, the largest specified <N> has precedence. The\n"
|
||||
"default is 5.\n"));
|
||||
|
||||
typedef cl::list<std::string>::const_iterator prefix_iterator;
|
||||
|
||||
static void DumpCommandLine(int argc, char **argv) {
|
||||
errs() << "FileCheck command line: ";
|
||||
for (int I = 0; I < argc; I++)
|
||||
errs() << " " << argv[I];
|
||||
errs() << "\n";
|
||||
}
|
||||
|
||||
struct MarkerStyle {
|
||||
/// The starting char (before tildes) for marking the line.
|
||||
char Lead;
|
||||
/// What color to use for this annotation.
|
||||
raw_ostream::Colors Color;
|
||||
/// A note to follow the marker, or empty string if none.
|
||||
std::string Note;
|
||||
/// Does this marker indicate inclusion by -dump-input-filter=error?
|
||||
bool FiltersAsError;
|
||||
MarkerStyle() {}
|
||||
MarkerStyle(char Lead, raw_ostream::Colors Color,
|
||||
const std::string &Note = "", bool FiltersAsError = false)
|
||||
: Lead(Lead), Color(Color), Note(Note), FiltersAsError(FiltersAsError) {
|
||||
assert((!FiltersAsError || !Note.empty()) &&
|
||||
"expected error diagnostic to have note");
|
||||
}
|
||||
};
|
||||
|
||||
static MarkerStyle GetMarker(FileCheckDiag::MatchType MatchTy) {
|
||||
switch (MatchTy) {
|
||||
case FileCheckDiag::MatchFoundAndExpected:
|
||||
return MarkerStyle('^', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchFoundButExcluded:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: no match expected",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButWrongLine:
|
||||
return MarkerStyle('!', raw_ostream::RED, "error: match on wrong line",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFoundButDiscarded:
|
||||
return MarkerStyle('!', raw_ostream::CYAN,
|
||||
"discard: overlaps earlier match");
|
||||
case FileCheckDiag::MatchFoundErrorNote:
|
||||
// Note should always be overridden within the FileCheckDiag.
|
||||
return MarkerStyle('!', raw_ostream::RED,
|
||||
"error: unknown error after match",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneAndExcluded:
|
||||
return MarkerStyle('X', raw_ostream::GREEN);
|
||||
case FileCheckDiag::MatchNoneButExpected:
|
||||
return MarkerStyle('X', raw_ostream::RED, "error: no match found",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchNoneForInvalidPattern:
|
||||
return MarkerStyle('X', raw_ostream::RED,
|
||||
"error: match failed for invalid pattern",
|
||||
/*FiltersAsError=*/true);
|
||||
case FileCheckDiag::MatchFuzzy:
|
||||
return MarkerStyle('?', raw_ostream::MAGENTA, "possible intended match",
|
||||
/*FiltersAsError=*/true);
|
||||
}
|
||||
llvm_unreachable_internal("unexpected match type");
|
||||
}
|
||||
|
||||
static void DumpInputAnnotationHelp(raw_ostream &OS) {
|
||||
OS << "The following description was requested by -dump-input=help to\n"
|
||||
<< "explain the input dump printed by FileCheck.\n"
|
||||
<< "\n"
|
||||
<< "Related command-line options:\n"
|
||||
<< "\n"
|
||||
<< " - -dump-input=<value> enables or disables the input dump\n"
|
||||
<< " - -dump-input-filter=<value> filters the input lines\n"
|
||||
<< " - -dump-input-context=<N> adjusts the context of filtered lines\n"
|
||||
<< " - -v and -vv add more annotations\n"
|
||||
<< " - -color forces colors to be enabled both in the dump and below\n"
|
||||
<< " - -help documents the above options in more detail\n"
|
||||
<< "\n"
|
||||
<< "These options can also be set via FILECHECK_OPTS. For example, for\n"
|
||||
<< "maximum debugging output on failures:\n"
|
||||
<< "\n"
|
||||
<< " $ FILECHECK_OPTS='-dump-input-filter=all -vv -color' ninja check\n"
|
||||
<< "\n"
|
||||
<< "Input dump annotation format:\n"
|
||||
<< "\n";
|
||||
|
||||
// Labels for input lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "L:";
|
||||
OS << " labels line number L of the input file\n"
|
||||
<< " An extra space is added after each input line to represent"
|
||||
<< " the\n"
|
||||
<< " newline character\n";
|
||||
|
||||
// Labels for annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L";
|
||||
OS << " labels the only match result for either (1) a pattern of type T"
|
||||
<< " from\n"
|
||||
<< " line L of the check file if L is an integer or (2) the"
|
||||
<< " I-th implicit\n"
|
||||
<< " pattern if L is \"imp\" followed by an integer "
|
||||
<< "I (index origin one)\n";
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "T:L'N";
|
||||
OS << " labels the Nth match result for such a pattern\n";
|
||||
|
||||
// Markers on annotation lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "^~~";
|
||||
OS << " marks good match (reported if -v)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "!~~";
|
||||
OS << " marks bad match, such as:\n"
|
||||
<< " - CHECK-NEXT on same line as previous match (error)\n"
|
||||
<< " - CHECK-NOT found (error)\n"
|
||||
<< " - CHECK-DAG overlapping match (discarded, reported if "
|
||||
<< "-vv)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "X~~";
|
||||
OS << " marks search range when no match is found, such as:\n"
|
||||
<< " - CHECK-NEXT not found (error)\n"
|
||||
<< " - CHECK-NOT not found (success, reported if -vv)\n"
|
||||
<< " - CHECK-DAG not found after discarded matches (error)\n"
|
||||
<< " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "?";
|
||||
OS << " marks fuzzy match when no match is found\n";
|
||||
|
||||
// Elided lines.
|
||||
OS << " - ";
|
||||
WithColor(OS, raw_ostream::SAVEDCOLOR, true) << "...";
|
||||
OS << " indicates elided input lines and annotations, as specified by\n"
|
||||
<< " -dump-input-filter and -dump-input-context\n";
|
||||
|
||||
// Colors.
|
||||
OS << " - colors ";
|
||||
WithColor(OS, raw_ostream::GREEN, true) << "success";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::RED, true) << "error";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::MAGENTA, true) << "fuzzy match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, false) << "discarded match";
|
||||
OS << ", ";
|
||||
WithColor(OS, raw_ostream::CYAN, true, true) << "unmatched input";
|
||||
OS << "\n";
|
||||
}
|
||||
|
||||
/// An annotation for a single input line.
|
||||
struct InputAnnotation {
|
||||
/// The index of the match result across all checks
|
||||
unsigned DiagIndex;
|
||||
/// The label for this annotation.
|
||||
std::string Label;
|
||||
/// Is this the initial fragment of a diagnostic that has been broken across
|
||||
/// multiple lines?
|
||||
bool IsFirstLine;
|
||||
/// What input line (one-origin indexing) this annotation marks. This might
|
||||
/// be different from the starting line of the original diagnostic if
|
||||
/// !IsFirstLine.
|
||||
unsigned InputLine;
|
||||
/// The column range (one-origin indexing, open end) in which to mark the
|
||||
/// input line. If InputEndCol is UINT_MAX, treat it as the last column
|
||||
/// before the newline.
|
||||
unsigned InputStartCol, InputEndCol;
|
||||
/// The marker to use.
|
||||
MarkerStyle Marker;
|
||||
/// Whether this annotation represents a good match for an expected pattern.
|
||||
bool FoundAndExpectedMatch;
|
||||
};
|
||||
|
||||
/// Get an abbreviation for the check type.
|
||||
static std::string GetCheckTypeAbbreviation(Check::FileCheckType Ty) {
|
||||
switch (Ty) {
|
||||
case Check::CheckPlain:
|
||||
if (Ty.getCount() > 1)
|
||||
return "count";
|
||||
return "check";
|
||||
case Check::CheckNext:
|
||||
return "next";
|
||||
case Check::CheckSame:
|
||||
return "same";
|
||||
case Check::CheckNot:
|
||||
return "not";
|
||||
case Check::CheckDAG:
|
||||
return "dag";
|
||||
case Check::CheckLabel:
|
||||
return "label";
|
||||
case Check::CheckEmpty:
|
||||
return "empty";
|
||||
case Check::CheckComment:
|
||||
return "com";
|
||||
case Check::CheckEOF:
|
||||
return "eof";
|
||||
case Check::CheckBadNot:
|
||||
return "bad-not";
|
||||
case Check::CheckBadCount:
|
||||
return "bad-count";
|
||||
case Check::CheckMisspelled:
|
||||
return "misspelled";
|
||||
case Check::CheckNone:
|
||||
llvm_unreachable("invalid FileCheckType");
|
||||
}
|
||||
llvm_unreachable("unknown FileCheckType");
|
||||
}
|
||||
|
||||
static void
|
||||
BuildInputAnnotations(const SourceMgr &SM, unsigned CheckFileBufferID,
|
||||
const std::pair<unsigned, unsigned> &ImpPatBufferIDRange,
|
||||
const std::vector<FileCheckDiag> &Diags,
|
||||
std::vector<InputAnnotation> &Annotations,
|
||||
unsigned &LabelWidth) {
|
||||
struct CompareSMLoc {
|
||||
bool operator()(const SMLoc &LHS, const SMLoc &RHS) const {
|
||||
return LHS.getPointer() < RHS.getPointer();
|
||||
}
|
||||
};
|
||||
// How many diagnostics does each pattern have?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagCountPerPattern;
|
||||
for (auto Diag : Diags)
|
||||
++DiagCountPerPattern[Diag.CheckLoc];
|
||||
// How many diagnostics have we seen so far per pattern?
|
||||
std::map<SMLoc, unsigned, CompareSMLoc> DiagIndexPerPattern;
|
||||
// How many total diagnostics have we seen so far?
|
||||
unsigned DiagIndex = 0;
|
||||
// What's the widest label?
|
||||
LabelWidth = 0;
|
||||
for (auto DiagItr = Diags.begin(), DiagEnd = Diags.end(); DiagItr != DiagEnd;
|
||||
++DiagItr) {
|
||||
InputAnnotation A;
|
||||
A.DiagIndex = DiagIndex++;
|
||||
|
||||
// Build label, which uniquely identifies this check result.
|
||||
unsigned CheckBufferID = SM.FindBufferContainingLoc(DiagItr->CheckLoc);
|
||||
auto CheckLineAndCol =
|
||||
SM.getLineAndColumn(DiagItr->CheckLoc, CheckBufferID);
|
||||
llvm::raw_string_ostream Label(A.Label);
|
||||
Label << GetCheckTypeAbbreviation(DiagItr->CheckTy) << ":";
|
||||
if (CheckBufferID == CheckFileBufferID)
|
||||
Label << CheckLineAndCol.first;
|
||||
else if (ImpPatBufferIDRange.first <= CheckBufferID &&
|
||||
CheckBufferID < ImpPatBufferIDRange.second)
|
||||
Label << "imp" << (CheckBufferID - ImpPatBufferIDRange.first + 1);
|
||||
else
|
||||
llvm_unreachable("expected diagnostic's check location to be either in "
|
||||
"the check file or for an implicit pattern");
|
||||
if (DiagCountPerPattern[DiagItr->CheckLoc] > 1)
|
||||
Label << "'" << DiagIndexPerPattern[DiagItr->CheckLoc]++;
|
||||
LabelWidth = std::max((std::string::size_type)LabelWidth, A.Label.size());
|
||||
|
||||
A.Marker = GetMarker(DiagItr->MatchTy);
|
||||
if (!DiagItr->Note.empty()) {
|
||||
A.Marker.Note = DiagItr->Note;
|
||||
// It's less confusing if notes that don't actually have ranges don't have
|
||||
// markers. For example, a marker for 'with "VAR" equal to "5"' would
|
||||
// seem to indicate where "VAR" matches, but the location we actually have
|
||||
// for the marker simply points to the start of the match/search range for
|
||||
// the full pattern of which the substitution is potentially just one
|
||||
// component.
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine &&
|
||||
DiagItr->InputStartCol == DiagItr->InputEndCol)
|
||||
A.Marker.Lead = ' ';
|
||||
}
|
||||
if (DiagItr->MatchTy == FileCheckDiag::MatchFoundErrorNote) {
|
||||
assert(!DiagItr->Note.empty() &&
|
||||
"expected custom note for MatchFoundErrorNote");
|
||||
A.Marker.Note = "error: " + A.Marker.Note;
|
||||
}
|
||||
A.FoundAndExpectedMatch =
|
||||
DiagItr->MatchTy == FileCheckDiag::MatchFoundAndExpected;
|
||||
|
||||
// Compute the mark location, and break annotation into multiple
|
||||
// annotations if it spans multiple lines.
|
||||
A.IsFirstLine = true;
|
||||
A.InputLine = DiagItr->InputStartLine;
|
||||
A.InputStartCol = DiagItr->InputStartCol;
|
||||
if (DiagItr->InputStartLine == DiagItr->InputEndLine) {
|
||||
// Sometimes ranges are empty in order to indicate a specific point, but
|
||||
// that would mean nothing would be marked, so adjust the range to
|
||||
// include the following character.
|
||||
A.InputEndCol =
|
||||
std::max(DiagItr->InputStartCol + 1, DiagItr->InputEndCol);
|
||||
Annotations.push_back(A);
|
||||
} else {
|
||||
assert(DiagItr->InputStartLine < DiagItr->InputEndLine &&
|
||||
"expected input range not to be inverted");
|
||||
A.InputEndCol = UINT_MAX;
|
||||
Annotations.push_back(A);
|
||||
for (unsigned L = DiagItr->InputStartLine + 1, E = DiagItr->InputEndLine;
|
||||
L <= E; ++L) {
|
||||
// If a range ends before the first column on a line, then it has no
|
||||
// characters on that line, so there's nothing to render.
|
||||
if (DiagItr->InputEndCol == 1 && L == E)
|
||||
break;
|
||||
InputAnnotation B;
|
||||
B.DiagIndex = A.DiagIndex;
|
||||
B.Label = A.Label;
|
||||
B.IsFirstLine = false;
|
||||
B.InputLine = L;
|
||||
B.Marker = A.Marker;
|
||||
B.Marker.Lead = '~';
|
||||
B.Marker.Note = "";
|
||||
B.InputStartCol = 1;
|
||||
if (L != E)
|
||||
B.InputEndCol = UINT_MAX;
|
||||
else
|
||||
B.InputEndCol = DiagItr->InputEndCol;
|
||||
B.FoundAndExpectedMatch = A.FoundAndExpectedMatch;
|
||||
Annotations.push_back(B);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static unsigned FindInputLineInFilter(
|
||||
DumpInputFilterValue DumpInputFilter, unsigned CurInputLine,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationBeg,
|
||||
const std::vector<InputAnnotation>::iterator &AnnotationEnd) {
|
||||
if (DumpInputFilter == DumpInputFilterAll)
|
||||
return CurInputLine;
|
||||
for (auto AnnotationItr = AnnotationBeg; AnnotationItr != AnnotationEnd;
|
||||
++AnnotationItr) {
|
||||
switch (DumpInputFilter) {
|
||||
case DumpInputFilterAll:
|
||||
llvm_unreachable("unexpected DumpInputFilterAll");
|
||||
break;
|
||||
case DumpInputFilterAnnotationFull:
|
||||
return AnnotationItr->InputLine;
|
||||
case DumpInputFilterAnnotation:
|
||||
if (AnnotationItr->IsFirstLine)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
case DumpInputFilterError:
|
||||
if (AnnotationItr->IsFirstLine && AnnotationItr->Marker.FiltersAsError)
|
||||
return AnnotationItr->InputLine;
|
||||
break;
|
||||
}
|
||||
}
|
||||
return UINT_MAX;
|
||||
}
|
||||
|
||||
/// To OS, print a vertical ellipsis (right-justified at LabelWidth) if it would
|
||||
/// occupy less lines than ElidedLines, but print ElidedLines otherwise. Either
|
||||
/// way, clear ElidedLines. Thus, if ElidedLines is empty, do nothing.
|
||||
static void DumpEllipsisOrElidedLines(raw_ostream &OS, std::string &ElidedLines,
|
||||
unsigned LabelWidth) {
|
||||
if (ElidedLines.empty())
|
||||
return;
|
||||
unsigned EllipsisLines = 3;
|
||||
if (EllipsisLines < StringRef(ElidedLines).count('\n')) {
|
||||
for (unsigned i = 0; i < EllipsisLines; ++i) {
|
||||
WithColor(OS, raw_ostream::BLACK, /*Bold=*/true)
|
||||
<< right_justify(".", LabelWidth);
|
||||
OS << '\n';
|
||||
}
|
||||
} else
|
||||
OS << ElidedLines;
|
||||
ElidedLines.clear();
|
||||
}

static void DumpAnnotatedInput(raw_ostream &OS, const FileCheckRequest &Req,
                               DumpInputFilterValue DumpInputFilter,
                               unsigned DumpInputContext,
                               StringRef InputFileText,
                               std::vector<InputAnnotation> &Annotations,
                               unsigned LabelWidth) {
  OS << "Input was:\n<<<<<<\n";

  // Sort annotations.
  llvm::sort(Annotations,
             [](const InputAnnotation &A, const InputAnnotation &B) {
               // 1. Sort annotations in the order of the input lines.
               //
               // This makes it easier to find relevant annotations while
               // iterating input lines in the implementation below. FileCheck
               // does not always produce diagnostics in the order of input
               // lines due to, for example, CHECK-DAG and CHECK-NOT.
               if (A.InputLine != B.InputLine)
                 return A.InputLine < B.InputLine;
               // 2. Sort annotations in the temporal order FileCheck produced
               // their associated diagnostics.
               //
               // This sort offers several benefits:
               //
               // A. On a single input line, the order of annotations reflects
               //    the FileCheck logic for processing directives/patterns.
               //    This can be helpful in understanding cases in which the
               //    order of the associated directives/patterns in the check
               //    file or on the command line either (i) does not match the
               //    temporal order in which FileCheck looks for matches for the
               //    directives/patterns (due to, for example, CHECK-LABEL,
               //    CHECK-NOT, or `--implicit-check-not`) or (ii) does match
               //    that order but does not match the order of those
               //    diagnostics along an input line (due to, for example,
               //    CHECK-DAG).
               //
               //    On the other hand, because our presentation format presents
               //    input lines in order, there's no clear way to offer the
               //    same benefit across input lines. For consistency, it might
               //    then seem worthwhile to have annotations on a single line
               //    also sorted in input order (that is, by input column).
               //    However, in practice, this appears to be more confusing
               //    than helpful. Perhaps it's intuitive to expect annotations
               //    to be listed in the temporal order in which they were
               //    produced except in cases the presentation format obviously
               //    and inherently cannot support it (that is, across input
               //    lines).
               //
               // B. When diagnostics' annotations are split among multiple
               //    input lines, the user must track them from one input line
               //    to the next. One property of the sort chosen here is that
               //    it facilitates the user in this regard by ensuring the
               //    following: when comparing any two input lines, a
               //    diagnostic's annotations are sorted in the same position
               //    relative to all other diagnostics' annotations.
               return A.DiagIndex < B.DiagIndex;
             });

  // Compute the width of the label column.
  const unsigned char *InputFilePtr = InputFileText.bytes_begin(),
                      *InputFileEnd = InputFileText.bytes_end();
  unsigned LineCount = InputFileText.count('\n');
  if (InputFileEnd[-1] != '\n')
    ++LineCount;
  unsigned LineNoWidth = std::log10(LineCount) + 1;
  // +3 below adds spaces (1) to the left of the (right-aligned) line numbers
  // on input lines and (2) to the right of the (left-aligned) labels on
  // annotation lines so that input lines and annotation lines are more
  // visually distinct. For example, the spaces on the annotation lines ensure
  // that input line numbers and check directive line numbers never align
  // horizontally. Those line numbers might not even be for the same file.
  // One space would be enough to achieve that, but more makes it even easier
  // to see.
  LabelWidth = std::max(LabelWidth, LineNoWidth) + 3;

  // Print annotated input lines.
  unsigned PrevLineInFilter = 0; // 0 means none so far
  unsigned NextLineInFilter = 0; // 0 means uncomputed, UINT_MAX means none
  std::string ElidedLines;
  raw_string_ostream ElidedLinesOS(ElidedLines);
  ColorMode TheColorMode =
      WithColor(OS).colorsEnabled() ? ColorMode::Enable : ColorMode::Disable;
  if (TheColorMode == ColorMode::Enable)
    ElidedLinesOS.enable_colors(true);
  auto AnnotationItr = Annotations.begin(), AnnotationEnd = Annotations.end();
  for (unsigned Line = 1;
       InputFilePtr != InputFileEnd || AnnotationItr != AnnotationEnd; ++Line) {
    const unsigned char *InputFileLine = InputFilePtr;

    // Compute the previous and next line included by the filter.
    if (NextLineInFilter < Line)
      NextLineInFilter = FindInputLineInFilter(DumpInputFilter, Line,
                                               AnnotationItr, AnnotationEnd);
    assert(NextLineInFilter && "expected NextLineInFilter to be computed");
    if (NextLineInFilter == Line)
      PrevLineInFilter = Line;

    // Elide this input line and its annotations if it's not within the
    // context specified by -dump-input-context of an input line included by
    // -dump-input-filter. However, in case the resulting ellipsis would occupy
    // more lines than the input lines and annotations it elides, buffer the
    // elided lines and annotations so we can print them instead.
    raw_ostream *LineOS = &OS;
    if ((!PrevLineInFilter || PrevLineInFilter + DumpInputContext < Line) &&
        (NextLineInFilter == UINT_MAX ||
         Line + DumpInputContext < NextLineInFilter))
      LineOS = &ElidedLinesOS;
    else {
      LineOS = &OS;
      DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);
    }

    // Print right-aligned line number.
    WithColor(*LineOS, raw_ostream::BLACK, /*Bold=*/true, /*BG=*/false,
              TheColorMode)
        << format_decimal(Line, LabelWidth) << ": ";

    // For the case where -v and colors are enabled, find the annotations for
    // good matches for expected patterns in order to highlight everything
    // else in the line. There are no such annotations if -v is disabled.
    std::vector<InputAnnotation> FoundAndExpectedMatches;
    if (Req.Verbose && TheColorMode == ColorMode::Enable) {
      for (auto I = AnnotationItr; I != AnnotationEnd && I->InputLine == Line;
           ++I) {
        if (I->FoundAndExpectedMatch)
          FoundAndExpectedMatches.push_back(*I);
      }
    }

    // Print numbered line with highlighting where there are no matches for
    // expected patterns.
    bool Newline = false;
    {
      WithColor COS(*LineOS, raw_ostream::SAVEDCOLOR, /*Bold=*/false,
                    /*BG=*/false, TheColorMode);
      bool InMatch = false;
      if (Req.Verbose)
        COS.changeColor(raw_ostream::CYAN, true, true);
      for (unsigned Col = 1; InputFilePtr != InputFileEnd && !Newline; ++Col) {
        bool WasInMatch = InMatch;
        InMatch = false;
        for (auto M : FoundAndExpectedMatches) {
          if (M.InputStartCol <= Col && Col < M.InputEndCol) {
            InMatch = true;
            break;
          }
        }
        if (!WasInMatch && InMatch)
          COS.resetColor();
        else if (WasInMatch && !InMatch)
          COS.changeColor(raw_ostream::CYAN, true, true);
        if (*InputFilePtr == '\n') {
          Newline = true;
          COS << ' ';
        } else
          COS << *InputFilePtr;
        ++InputFilePtr;
      }
    }
    *LineOS << '\n';
    unsigned InputLineWidth = InputFilePtr - InputFileLine;

    // Print any annotations.
    while (AnnotationItr != AnnotationEnd && AnnotationItr->InputLine == Line) {
      WithColor COS(*LineOS, AnnotationItr->Marker.Color, /*Bold=*/true,
                    /*BG=*/false, TheColorMode);
      // The two spaces below are where the ": " appears on input lines.
      COS << left_justify(AnnotationItr->Label, LabelWidth) << "  ";
      unsigned Col;
      for (Col = 1; Col < AnnotationItr->InputStartCol; ++Col)
        COS << ' ';
      COS << AnnotationItr->Marker.Lead;
      // If InputEndCol=UINT_MAX, stop at InputLineWidth.
      for (++Col; Col < AnnotationItr->InputEndCol && Col <= InputLineWidth;
           ++Col)
        COS << '~';
      const std::string &Note = AnnotationItr->Marker.Note;
      if (!Note.empty()) {
        // Put the note at the end of the input line. If we were to instead
        // put the note right after the marker, subsequent annotations for the
        // same input line might appear to mark this note instead of the input
        // line.
        for (; Col <= InputLineWidth; ++Col)
          COS << ' ';
        COS << ' ' << Note;
      }
      COS << '\n';
      ++AnnotationItr;
    }
  }
  DumpEllipsisOrElidedLines(OS, ElidedLinesOS.str(), LabelWidth);

  OS << ">>>>>>\n";
}
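
For orientation, the dump this routine produces looks roughly like the following hand-written sketch (not actual tool output; the check:1 label, the input text, and the marker positions are hypothetical):

    Input was:
    <<<<<<
             1: hello world
    check:1     ^~~~~
             2: some unrelated line
                .
                .
                .
    >>>>>>

Input lines carry right-justified line numbers followed by ": "; annotation lines carry a left-justified label, a lead marker, a run of '~' covering the match range, and an optional note at the end of the line, exactly as constructed above. The column of dots is the ellipsis printed for elided lines.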

int main(int argc, char **argv) {
  // Enable use of ANSI color codes because FileCheck is using them to
  // highlight text.
  llvm::sys::Process::UseANSIEscapeCodes(true);

  InitLLVM X(argc, argv);
  cl::ParseCommandLineOptions(argc, argv, /*Overview*/ "", /*Errs*/ nullptr,
                              "FILECHECK_OPTS");

  // Select -dump-input* values. The -help documentation specifies the default
  // value and which value to choose if an option is specified multiple times.
  // In the latter case, the general rule of thumb is to choose the value that
  // provides the most information.
  DumpInputValue DumpInput =
      DumpInputs.empty()
          ? DumpInputFail
          : *std::max_element(DumpInputs.begin(), DumpInputs.end());
  DumpInputFilterValue DumpInputFilter;
  if (DumpInputFilters.empty())
    DumpInputFilter = DumpInput == DumpInputAlways ? DumpInputFilterAll
                                                   : DumpInputFilterError;
  else
    DumpInputFilter =
        *std::max_element(DumpInputFilters.begin(), DumpInputFilters.end());
  unsigned DumpInputContext = DumpInputContexts.empty()
                                  ? 5
                                  : *std::max_element(DumpInputContexts.begin(),
                                                      DumpInputContexts.end());

  if (DumpInput == DumpInputHelp) {
    DumpInputAnnotationHelp(outs());
    return 0;
  }
  if (CheckFilename.empty()) {
    errs() << "<check-file> not specified\n";
    return 2;
  }

  FileCheckRequest Req;
  append_range(Req.CheckPrefixes, CheckPrefixes);

  append_range(Req.CommentPrefixes, CommentPrefixes);

  append_range(Req.ImplicitCheckNot, ImplicitCheckNot);

  bool GlobalDefineError = false;
  for (StringRef G : GlobalDefines) {
    size_t EqIdx = G.find('=');
    if (EqIdx == std::string::npos) {
      errs() << "Missing equal sign in command-line definition '-D" << G
             << "'\n";
      GlobalDefineError = true;
      continue;
    }
    if (EqIdx == 0) {
      errs() << "Missing variable name in command-line definition '-D" << G
             << "'\n";
      GlobalDefineError = true;
      continue;
    }
    Req.GlobalDefines.push_back(G);
  }
  if (GlobalDefineError)
    return 2;

  Req.AllowEmptyInput = AllowEmptyInput;
  Req.AllowUnusedPrefixes = AllowUnusedPrefixes;
  Req.EnableVarScope = EnableVarScope;
  Req.AllowDeprecatedDagOverlap = AllowDeprecatedDagOverlap;
  Req.Verbose = Verbose;
  Req.VerboseVerbose = VerboseVerbose;
  Req.NoCanonicalizeWhiteSpace = NoCanonicalizeWhiteSpace;
  Req.MatchFullLines = MatchFullLines;
  Req.IgnoreCase = IgnoreCase;

  if (VerboseVerbose)
    Req.Verbose = true;

  FileCheck FC(Req);
  if (!FC.ValidateCheckPrefixes())
    return 2;

  Regex PrefixRE = FC.buildCheckPrefixRegex();
  std::string REError;
  if (!PrefixRE.isValid(REError)) {
    errs() << "Unable to combine check-prefix strings into a prefix regular "
              "expression! This is likely a bug in FileCheck's verification of "
              "the check-prefix strings. Regular expression parsing failed "
              "with the following error: "
           << REError << "\n";
    return 2;
  }

  SourceMgr SM;

  // Read the expected strings from the check file.
  ErrorOr<std::unique_ptr<MemoryBuffer>> CheckFileOrErr =
      MemoryBuffer::getFileOrSTDIN(CheckFilename, /*IsText=*/true);
  if (std::error_code EC = CheckFileOrErr.getError()) {
    errs() << "Could not open check file '" << CheckFilename
           << "': " << EC.message() << '\n';
    return 2;
  }
  MemoryBuffer &CheckFile = *CheckFileOrErr.get();

  SmallString<4096> CheckFileBuffer;
  StringRef CheckFileText = FC.CanonicalizeFile(CheckFile, CheckFileBuffer);

  unsigned CheckFileBufferID =
      SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
                                CheckFileText, CheckFile.getBufferIdentifier()),
                            SMLoc());

  std::pair<unsigned, unsigned> ImpPatBufferIDRange;
  if (FC.readCheckFile(SM, CheckFileText, PrefixRE, &ImpPatBufferIDRange))
    return 2;

  // Open the file to check and add it to SourceMgr.
  ErrorOr<std::unique_ptr<MemoryBuffer>> InputFileOrErr =
      MemoryBuffer::getFileOrSTDIN(InputFilename, /*IsText=*/true);
  if (InputFilename == "-")
    InputFilename = "<stdin>"; // Overwrite for improved diagnostic messages
  if (std::error_code EC = InputFileOrErr.getError()) {
    errs() << "Could not open input file '" << InputFilename
           << "': " << EC.message() << '\n';
    return 2;
  }
  MemoryBuffer &InputFile = *InputFileOrErr.get();

  if (InputFile.getBufferSize() == 0 && !AllowEmptyInput) {
    errs() << "FileCheck error: '" << InputFilename << "' is empty.\n";
    DumpCommandLine(argc, argv);
    return 2;
  }

  SmallString<4096> InputFileBuffer;
  StringRef InputFileText = FC.CanonicalizeFile(InputFile, InputFileBuffer);

  SM.AddNewSourceBuffer(MemoryBuffer::getMemBuffer(
                            InputFileText, InputFile.getBufferIdentifier()),
                        SMLoc());

  std::vector<FileCheckDiag> Diags;
  int ExitCode = FC.checkInput(SM, InputFileText,
                               DumpInput == DumpInputNever ? nullptr : &Diags)
                     ? EXIT_SUCCESS
                     : 1;
  if (DumpInput == DumpInputAlways ||
      (ExitCode == 1 && DumpInput == DumpInputFail)) {
    errs() << "\n"
           << "Input file: " << InputFilename << "\n"
           << "Check file: " << CheckFilename << "\n"
           << "\n"
           << "-dump-input=help explains the following input dump.\n"
           << "\n";
    std::vector<InputAnnotation> Annotations;
    unsigned LabelWidth;
    BuildInputAnnotations(SM, CheckFileBufferID, ImpPatBufferIDRange, Diags,
                          Annotations, LabelWidth);
    DumpAnnotatedInput(errs(), Req, DumpInputFilter, DumpInputContext,
                       InputFileText, Annotations, LabelWidth);
  }

  return ExitCode;
}
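
A usage sketch tying the options above together (flag spellings follow the option names referenced in this file; the file names are placeholders):

    FileCheck --dump-input=fail --dump-input-filter=error \
        --dump-input-context=5 check.txt < input.txt
    FILECHECK_OPTS='-dump-input=always' FileCheck check.txt < input.txt

The first form prints the annotated dump only when a check fails; the second uses the FILECHECK_OPTS environment variable (parsed by cl::ParseCommandLineOptions above) to dump the input whether or not the checks pass.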
@@ -2,6 +2,7 @@
#define TRITON_CONVERSION_PASSES_H

#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.h"


@@ -50,4 +50,18 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
  ];
}

def TritonConvertArithToIndex : Pass<"triton-convert-arith-to-index", "mlir::ModuleOp"> {

  let summary = "Convert arith to index";

  let constructor = "mlir::triton::createTritonConvertArithToIndexPass()";

  let description = [{
    Convert arith operations on index values to the corresponding ops in the
    index dialect. We need this because the SCFToCF conversion currently
    generates arith ops on indices.
  }];

  let dependentDialects = ["mlir::index::IndexDialect"];
}
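
A minimal sketch of wiring the new pass into a pipeline, assuming only the constructor declared above; the PassManager setup around it is illustrative, not part of this commit:

    #include "mlir/Pass/PassManager.h"
    #include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"

    void runArithToIndex(mlir::ModuleOp module, mlir::MLIRContext &context) {
      mlir::PassManager pm(&context);
      // SCF-to-ControlFlow lowering would run earlier in a real pipeline; it
      // is what introduces the arith ops on index values this pass rewrites.
      pm.addPass(mlir::triton::createTritonConvertArithToIndexPass());
      if (mlir::failed(pm.run(module))) {
        // Conversion failed; the caller decides how to report it.
      }
    }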

#endif

20 include/triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h Normal file
@@ -0,0 +1,20 @@
#ifndef TRITON_CONVERSION_ARITH_TO_INDEX_H
#define TRITON_CONVERSION_ARITH_TO_INDEX_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Transforms/DialectConversion.h"
#include <memory>

namespace mlir {

class ModuleOp;
template <typename T> class OperationPass;

namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass();

} // namespace triton
} // namespace mlir

#endif
@@ -72,7 +72,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [SameOperandsAndResultShape,
def TT_FpToFpOp : TT_Op<"fp_to_fp", [SameOperandsAndResultShape,
                                     SameOperandsAndResultEncoding,
                                     Pure,
                                     DeclareOpInterfaceMethods<CastOpInterface>]> {
                                     /*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
  let summary = "Floating point casting for custom types";

  let description = [{
@@ -405,19 +405,31 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
}

//
// Make PrintfOp
// Make PrintOp
//
def TT_PrintfOp : TT_Op<"printf", [MemoryEffects<[MemWrite]>]>,
                  Arguments<(ins StrAttr:$prefix,
                             Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
  let summary = "Device-side printf, as in CUDA for debugging";
def TT_PrintOp : TT_Op<"print", [MemoryEffects<[MemWrite]>]>,
                 Arguments<(ins StrAttr:$prefix, Variadic<AnyTypeOf<[TT_Type]>>:$args)> {
  let summary = "Device-side print, as in CUDA for debugging";
  let description = [{
    `tt.printf` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
    `tt.print` takes a literal string prefix and an arbitrary number of scalar or tensor arguments that should be printed.
    format are generated automatically from the arguments.
  }];
  let assemblyFormat = [{
    $prefix attr-dict ($args^ `:` type($args))?
    $prefix attr-dict `:` ($args^ `:` type($args))?
  }];
}

//
// Make AssertOp
//
def TT_AssertOp : TT_Op<"assert", [MemoryEffects<[MemWrite]>]> {
  let summary = "Device-side assert, as in CUDA for correctness checking";
  let description = [{
    `tt.assert` takes a condition tensor, a message string, a file string, a function string, and a line number.
    If the condition is false, the message is printed, and the program is aborted.
  }];
  let arguments = (ins TT_Tensor:$condition, StrAttr:$message, StrAttr:$file, StrAttr:$func, I32Attr:$line);
  let assemblyFormat = "$condition `,` $message `,` $file `,` $func `,` $line attr-dict `:` type($condition)";
}

#endif // Triton_OPS

@@ -14,9 +14,7 @@ class TritonTypeDef<string name, string _mnemonic>
}

// Floating-point Type
def F8 : TritonTypeDef<"Float8", "f8">;

def TT_Float : AnyTypeOf<[F8, F16, BF16, F32, F64], "floating-point">;
def TT_Float : AnyTypeOf<[F8E4M3FN, F8E5M2, F16, BF16, F32, F64], "floating-point">;
def TT_FloatTensor : TensorOf<[TT_Float]>;
def TT_FloatLike : AnyTypeOf<[TT_Float, TT_FloatTensor]>;


@@ -35,7 +35,7 @@ Right now, Triton implements two classes of layouts: shared, and distributed.
}];

code extraBaseClassDeclaration = [{
  unsigned getElemsPerThread(ArrayRef<int64_t> shape) const;
  unsigned getElemsPerThread(ArrayRef<int64_t> shape, Type eltTy) const;
  ::mlir::LogicalResult verifyLayoutForArg(::mlir::Operation* op, unsigned argNo) const;
}];
}
@@ -382,44 +382,63 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
  );

  let builders = [
    // Specially for MMAV1(Volta)
    AttrBuilder<(ins "int":$versionMajor,
                     "int":$numWarps,
                     "ArrayRef<int64_t>":$shapeC,
                     "bool":$isARow,
                     "bool":$isBRow,
                     "bool":$isAVec4,
                     "bool":$isBVec4,
                     "int":$id), [{
      assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
      SmallVector<unsigned> wpt({static_cast<unsigned>(numWarps), 1});
      int versionMinor = 0;
      // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
      int versionMinor = (isARow * (1<<0)) |\
                         (isBRow * (1<<1)) |\
                         (isAVec4 * (1<<2)) |\
                         (isBVec4 * (1<<3));

      assert(id < (1<<numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
      for (int i = 0; i < numBitsToHoldMmaV1ID; ++i)
        versionMinor |= static_cast<bool>((1<<i) & id) * (1<<(4+i));

      // TODO: Share code with
      // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
      // rep,spw and fpw.
      SmallVector<unsigned> wpt({1, 1});
      SmallVector<unsigned> wpt_nm1;

      SmallVector<int, 2> rep(2), spw(2);
      std::array<int, 3> fpw{{2, 2, 1}};
      int packSize0 = (isARow || isAVec4) ? 1 : 2;
      rep[0] = 2 * packSize0;
      spw[0] = fpw[0] * 4 * rep[0];

      int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
      rep[1] = 2 * packSize1;
      spw[1] = fpw[1] * 4 * rep[1];

      do {
        wpt_nm1 = wpt;
        if (wpt[0] * wpt[1] < numWarps)
          wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shapeC[0] / spw[0]);
        if (wpt[0] * wpt[1] < numWarps)
          wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shapeC[1] / spw[1]);
      } while (wpt_nm1 != wpt);

      return $_get(context, versionMajor, versionMinor, wpt);
    }]>,

    // Specially for MMAV1(Volta)
    AttrBuilder<(ins "int":$versionMajor,
                     "ArrayRef<unsigned>":$warpsPerCTA,
                     "int":$numWarps,
                     "ArrayRef<int64_t>":$shapeA,
                     "ArrayRef<int64_t>":$shapeB,
                     "ArrayRef<int64_t>":$shapeC,
                     "bool":$isARow,
                     "bool":$isBRow,
                     "int":$id), [{
      assert(versionMajor == 1 && "This builder is specially for versionMajor==1");
      bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
      bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
      // 4-bits to encode 4 booleans: [isARow, isBRow, isAVec4, isBVec4]
      // 3-bits to encode the MMA ID to make each unique
      int versionMinor = (isARow * (1<<0)) |\
                         (isBRow * (1<<1)) |\
                         (isAVec4 * (1<<2)) |\
                         (isBVec4 * (1<<3));

      assert(id < (1<<numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
      for (int i = 0; i < numBitsToHoldMmaV1ID; ++i)
        versionMinor |= static_cast<bool>((1<<i) & id) * (1<<(4+i));

      return $_get(context, versionMajor, versionMinor, warpsPerCTA);
      return get(context, versionMajor, numWarps, shapeC, isARow, isBRow, isAVec4, isBVec4, id);
    }]>
  ];
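
Both builders pack the layout flags and the MMA ID into versionMinor the same way; the bit layout can be exercised in isolation with a standalone C++ sketch (numBitsToHoldMmaV1ID is assumed to be 3, per the "3-bits to encode the MMA ID" comment above):

    #include <cassert>

    constexpr int numBitsToHoldMmaV1ID = 3; // assumed from the 3-bit ID comment

    int encodeVersionMinor(bool isARow, bool isBRow, bool isAVec4, bool isBVec4,
                           int id) {
      assert(id < (1 << numBitsToHoldMmaV1ID) && "MMAv1 ID exceeds the maximum");
      int vm = (isARow << 0) | (isBRow << 1) | (isAVec4 << 2) | (isBVec4 << 3);
      return vm | (id << 4); // ID occupies bits [4, 4 + numBitsToHoldMmaV1ID)
    }

    bool decodeIsARow(int versionMinor) { return versionMinor & 1; }
    int decodeId(int versionMinor) {
      return (versionMinor >> 4) & ((1 << numBitsToHoldMmaV1ID) - 1);
    }

Shifting the ID left by 4 is equivalent to the bit-by-bit loop in the builders, which ORs bit i of id into bit 4+i of versionMinor.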

@@ -489,25 +508,21 @@ section 9.7.13.4.1 for more details.
  let parameters = (
    ins
    "unsigned":$opIdx,
    "Attribute":$parent,
    "Attribute":$isMMAv1Row
    "Attribute":$parent
  );

  let builders = [
    AttrBuilder<(ins "unsigned":$opIdx,
                     "Attribute":$parent), [{
      Attribute isMMAv1Row;
      if(parent.isa<MmaEncodingAttr>() &&
         parent.cast<MmaEncodingAttr>().isVolta()){
        isMMAv1Row = BoolAttr::get(context, true);
      }
      return $_get(context, opIdx, parent, isMMAv1Row);
    }]>

  ];

  let hasCustomAssemblyFormat = 1;
  let extraClassDeclaration = extraBaseClassDeclaration;
  let extraClassDeclaration = extraBaseClassDeclaration # [{
    bool getMMAv1IsRow() const;
    bool getMMAv1IsVec4() const;
    SmallVector<int> getMMAv1Rep() const;
    SmallVector<int> getMMAv1ShapePerWarp() const;
    int getMMAv1Vec() const;
    int getMMAv1NumOuter(ArrayRef<int64_t> shape) const;
  }];
}
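
The getMMAv1* accessors are only declared here. One plausible implementation of getMMAv1IsRow, assuming it decodes the parent MmaEncodingAttr's versionMinor bits as laid out by the builders above — a guess for illustration, not the commit's code:

    bool DotOperandEncodingAttr::getMMAv1IsRow() const {
      auto mma = getParent().cast<MmaEncodingAttr>();
      // Bit 0 of versionMinor is isARow, bit 1 is isBRow; opIdx selects
      // operand A (0) or B (1), so it doubles as the bit index.
      return (mma.getVersionMinor() >> getOpIdx()) & 1;
    }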


@@ -9,6 +9,8 @@ include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
include "mlir/IR/OpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
include "mlir/Interfaces/DestinationStyleOpInterface.td"
include "mlir/Interfaces/ViewLikeInterface.td"

def ResultsAreSharedEncoding: NativeOpTrait<"ResultsAreSharedEncoding">;

@@ -105,6 +107,69 @@ def TTG_SelectOp : TTG_Op<"select", [Pure, Elementwise,
}


def TTG_ExtractSliceOp : TTG_Op<"extract_slice",
                                [AttrSizedOperandSegments,
                                 ResultsAreSharedEncoding,
                                 Pure,
                                 OffsetSizeAndStrideOpInterface]> {
  let summary = "extract slice operation";
  let description = [{
    Same as tensor.extract_slice, but with int32 indices. We reimplement
    ExtractSliceOp with int32 indices because:
    - we want to enforce int32 indexing on GPUs since Triton tensors fit in SRAM
    - we still want to use indexWidth = 64 when lowering to LLVM because our loops can have
      64-bit induction variables and scf.for uses indexType for bounds/ivs
  }];

  let arguments = (ins
    AnyRankedTensor:$source,
    Variadic<I32>:$offsets,
    Variadic<I32>:$sizes,
    Variadic<I32>:$strides,
    DenseI64ArrayAttr:$static_offsets,
    DenseI64ArrayAttr:$static_sizes,
    DenseI64ArrayAttr:$static_strides
  );
  let results = (outs AnyRankedTensor:$result);

  let builders = [
    // Build an ExtractSliceOp with mixed static and dynamic entries and custom
    // result type. If the type passed is nullptr, it is inferred.
    OpBuilder<(ins "RankedTensorType":$resultType, "Value":$source,
      "ArrayRef<OpFoldResult>":$offsets, "ArrayRef<OpFoldResult>":$sizes,
      "ArrayRef<OpFoldResult>":$strides,
      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>,
  ];

  let extraClassDeclaration = [{
    /// Return the number of leading operands before the `offsets`, `sizes`
    /// and `strides` operands.
    static unsigned getOffsetSizeAndStrideStartOperandIndex() { return 1; }

    /// Returns the type of the base tensor operand.
    RankedTensorType getSourceType() {
      return getSource().getType().cast<RankedTensorType>();
    }

    std::array<unsigned, 3> getArrayAttrMaxRanks() {
      unsigned rank = getSourceType().getRank();
      return {rank, rank, rank};
    }
  }];

  let assemblyFormat = [{
    $source ``
    custom<DynamicIndexList>($offsets, $static_offsets)
    custom<DynamicIndexList>($sizes, $static_sizes)
    custom<DynamicIndexList>($strides, $static_strides)
    attr-dict `:` type($source) `to` type($result)
  }];
}

//

def TTG_InsertSliceAsyncOp : TTG_Op<"insert_slice_async",
                                    [AttrSizedOperandSegments,
                                     ResultsAreSharedEncoding,

@@ -23,7 +23,7 @@ std::unique_ptr<Pass> createTritonGPURemoveLayoutConversionsPass();

std::unique_ptr<Pass> createTritonGPUVerifier();

std::unique_ptr<Pass> createTritonGPUFuseTranspositionsPass();
std::unique_ptr<Pass> createTritonGPUOptimizeDotOperandsPass();

std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass();


@@ -60,7 +60,7 @@ def TritonGPUAccelerateMatmul : Pass<"tritongpu-accelerate-matmul", "mlir::Modul

}

def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::ModuleOp"> {
def TritonGPUOptimizeDotOperands : Pass<"tritongpu-optimize-dot-operands", "mlir::ModuleOp"> {
  let summary = "fuse transpositions";

  let description = [{
@@ -68,7 +68,7 @@ def TritonGPUFuseTranspositions : Pass<"tritongpu-fuse-transposition", "mlir::Mo
    hardware-accelerated transpositions.
  }];

  let constructor = "mlir::createTritonGPUFuseTranspositionsPass()";
  let constructor = "mlir::createTritonGPUOptimizeDotOperandsPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect",
                           "mlir::triton::TritonDialect"];
@@ -86,6 +86,7 @@ def TritonGPUCoalesce: Pass<"tritongpu-coalesce", "mlir::ModuleOp"> {
  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}


def TritonGPURemoveLayoutConversions : Pass<"tritongpu-remove-layout-conversions", "mlir::ModuleOp"> {
  let summary = "remove superfluous layout conversions";

@@ -122,16 +123,4 @@ def TritonGPUDecomposeConversions: Pass<"tritongpu-decompose-conversions", "mlir
                           "mlir::triton::TritonDialect"];
}

def UpdateMmaForVolta : Pass<"tritongpu-update-mma-for-volta", "mlir::ModuleOp"> {
  let summary = "Update mma encodings for Volta";

  let description = [{
    This helps to update the mma encodings for Volta.
  }];

  let constructor = "mlir::createTritonGPUUpdateMmaForVoltaPass()";

  let dependentDialects = ["mlir::triton::gpu::TritonGPUDialect"];
}

#endif

@@ -27,7 +27,7 @@ void SharedMemoryAliasAnalysis::visitOperation(
    // These ops may allocate a new shared memory buffer.
    auto result = op->getResult(0);
    // XXX(Keren): the following ops are always aliasing for now
    if (isa<tensor::ExtractSliceOp, triton::TransOp>(op)) {
    if (isa<triton::gpu::ExtractSliceOp, triton::TransOp>(op)) {
      // extract_slice %src
      // trans %src
      aliasInfo = AliasInfo(operands[0]->getValue());

@@ -846,10 +846,14 @@ AxisInfoAnalysis::AxisInfoAnalysis(DataFlowSolver &solver)
void AxisInfoAnalysis::visitOperation(
    Operation *op, ArrayRef<const dataflow::Lattice<AxisInfo> *> operands,
    ArrayRef<dataflow::Lattice<AxisInfo> *> results) {
  // TODO: For sure not the right way to do this
  // but why is scf.if not initialized otherwise?
  for (auto op : operands)
    if (op->getValue().getRank() == 0)
      setToEntryState((dataflow::Lattice<AxisInfo> *)op);
  AxisInfo curr = visitors.apply(op, operands);
  if (curr.getRank() == 0) {
  if (curr.getRank() == 0)
    return setAllToEntryStates(results);
  }
  // override with hint
  auto newContiguity = curr.getContiguity();
  auto newDivisibility = curr.getDivisibility();

@@ -78,8 +78,8 @@ void MembarAnalysis::visitTerminator(Operation *op,

void MembarAnalysis::update(Operation *op, BlockInfo *blockInfo,
                            OpBuilder *builder) {
  if (isa<tensor::ExtractSliceOp>(op) || isa<triton::gpu::AllocTensorOp>(op) ||
      isa<triton::TransOp>(op)) {
  if (isa<triton::gpu::ExtractSliceOp>(op) ||
      isa<triton::gpu::AllocTensorOp>(op) || isa<triton::TransOp>(op)) {
    // alloc is an allocation op without memory write.
    // FIXME(Keren): extract_slice is always alias for now
    return;

@@ -117,7 +117,7 @@ bool maybeSharedAllocationOp(Operation *op) {
}

bool maybeAliasOp(Operation *op) {
  return isa<tensor::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
  return isa<triton::gpu::ExtractSliceOp>(op) || isa<triton::TransOp>(op) ||
         isa<triton::gpu::InsertSliceAsyncOp>(op) ||
         isa<tensor::InsertSliceOp>(op);
}

92 lib/Conversion/TritonGPUToLLVM/ArithToIndexPass.cpp Normal file
@@ -0,0 +1,92 @@
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
#include "mlir/Pass/Pass.h"
#include "triton/Dialect/Triton/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"

namespace {
class TritonArithToIndexConversionTarget : public mlir::ConversionTarget {
public:
  // Note: despite the name, this returns true when the op has neither an
  // index-typed result nor an index-typed operand, i.e. when the op is
  // already legal and needs no conversion.
  static bool hasIndexResultOrOperand(Operation *op) {
    if (!op)
      return false;
    bool hasRetIndex = llvm::find_if(op->getResultTypes(), [](Type type) {
                         return type.isIndex();
                       }) != op->getResultTypes().end();
    bool hasArgIndex = llvm::find_if(op->getOperandTypes(), [](Type type) {
                         return type.isIndex();
                       }) != op->getOperandTypes().end();
    return !hasRetIndex && !hasArgIndex;
  }

  explicit TritonArithToIndexConversionTarget(MLIRContext &ctx)
      : ConversionTarget(ctx) {
    addLegalDialect<index::IndexDialect>();
    addDynamicallyLegalDialect<arith::ArithDialect>(hasIndexResultOrOperand);
  }
};

template <class SrcOp, class DstOp>
LogicalResult replaceArithWithIndex(SrcOp op, PatternRewriter &rewriter) {
  // if (!hasIndexResultOrOperand(&*op))
  //   return failure();
  rewriter.replaceOpWithNewOp<DstOp>(op, op->getResultTypes(),
                                     op->getOperands(), op->getAttrs());
  return success();
}

LogicalResult replaceArithCmpWithIndexCmp(arith::CmpIOp op,
                                          PatternRewriter &rewriter) {
  // if (!hasIndexResultOrOperand(&*op))
  //   return failure();
  rewriter.replaceOpWithNewOp<index::CmpOp>(
      op, op.getResult().getType(), (index::IndexCmpPredicate)op.getPredicate(),
      op.getOperand(0), op.getOperand(1));
  return success();
}

class ArithToIndex : public TritonConvertArithToIndexBase<ArithToIndex> {
public:
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
    TritonArithToIndexConversionTarget target(*context);
    RewritePatternSet patterns(context);
    patterns.add(replaceArithWithIndex<arith::IndexCastOp, index::CastSOp>);
    patterns.add(replaceArithWithIndex<arith::ConstantOp, index::ConstantOp>);
    patterns.add(replaceArithWithIndex<arith::AddIOp, index::AddOp>);
    patterns.add(replaceArithCmpWithIndexCmp);
    if (failed(applyPartialConversion(mod, target, std::move(patterns)))) {
      return signalPassFailure();
    }
  }
};
} // namespace

namespace mlir {
namespace triton {

std::unique_ptr<OperationPass<ModuleOp>> createTritonConvertArithToIndexPass() {
  return std::make_unique<::ArithToIndex>();
}

} // namespace triton
} // namespace mlir
@@ -1,17 +1,16 @@
add_mlir_conversion_library(TritonGPUToLLVM
  TritonGPUToLLVM.cpp
  GCNAsmFormat.cpp
  PTXAsmFormat.cpp
  TritonGPUToLLVMPass.cpp
  ArithToIndexPass.cpp
  ConvertLayoutOpToLLVM.cpp
  DotOpToLLVM.cpp
  ElementwiseOpToLLVM.cpp
  LoadStoreOpToLLVM.cpp
  TritonGPUToLLVM.cpp
  TritonGPUToLLVMPass.cpp
  GCNAsmFormat.cpp
  PTXAsmFormat.cpp
  ReduceOpToLLVM.cpp
  Utility.cpp
  TypeConverter.cpp
  ViewOpToLLVM.cpp
  DotOpHelpers.cpp

@@ -4,10 +4,8 @@

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStridesFromShapeAndOrder;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getContigPerThread;
@@ -56,28 +54,32 @@ public:

private:
  SmallVector<Value> getMultiDimOffset(Attribute layout, Location loc,
                                       ConversionPatternRewriter &rewriter,
                                       unsigned elemId, ArrayRef<int64_t> shape,
                                       unsigned elemId, RankedTensorType type,
                                       ArrayRef<unsigned> multiDimCTAInRepId,
                                       ArrayRef<unsigned> shapePerCTA) const {
    auto shape = type.getShape();
    unsigned rank = shape.size();
    if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
      auto multiDimOffsetFirstElem =
          emitBaseIndexForLayout(loc, rewriter, blockedLayout, shape);
          emitBaseIndexForLayout(loc, rewriter, blockedLayout, type);
      SmallVector<Value> multiDimOffset(rank);
      SmallVector<unsigned> multiDimElemId = getMultiDimIndex<unsigned>(
          elemId, getSizePerThread(layout), getOrder(layout));
      for (unsigned d = 0; d < rank; ++d) {
        multiDimOffset[d] = add(multiDimOffsetFirstElem[d],
                                idx_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
                                i32_val(multiDimCTAInRepId[d] * shapePerCTA[d] +
                                        multiDimElemId[d]));
      }
      return multiDimOffset;
    }
    if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
      unsigned dim = sliceLayout.getDim();
      auto parentEncoding = sliceLayout.getParent();
      auto parentShape = sliceLayout.paddedShape(shape);
      auto parentTy = RankedTensorType::get(parentShape, type.getElementType(),
                                            parentEncoding);
      auto multiDimOffsetParent =
          getMultiDimOffset(sliceLayout.getParent(), loc, rewriter, elemId,
                            sliceLayout.paddedShape(shape),
          getMultiDimOffset(parentEncoding, loc, rewriter, elemId, parentTy,
                            sliceLayout.paddedShape(multiDimCTAInRepId),
                            sliceLayout.paddedShape(shapePerCTA));
      SmallVector<Value> multiDimOffset(rank);
@@ -94,24 +96,24 @@ private:
      SmallVector<Value> mmaRowIdx(2);
      Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
      Value warpSize = idx_val(64);
      Value warpSize = i32_val(64);
#else
      Value warpSize = idx_val(32);
      Value warpSize = i32_val(32);
#endif
      Value laneId = urem(threadId, warpSize);
      Value warpId = udiv(threadId, warpSize);
      // TODO: fix the bug in MMAEncodingAttr document
      SmallVector<Value> multiDimWarpId(2);
      multiDimWarpId[0] = urem(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
      multiDimWarpId[1] = udiv(warpId, idx_val(mmaLayout.getWarpsPerCTA()[0]));
      Value _1 = idx_val(1);
      Value _2 = idx_val(2);
      Value _4 = idx_val(4);
      Value _8 = idx_val(8);
      Value _16 = idx_val(16);
      multiDimWarpId[0] = urem(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
      multiDimWarpId[1] = udiv(warpId, i32_val(mmaLayout.getWarpsPerCTA()[0]));
      Value _1 = i32_val(1);
      Value _2 = i32_val(2);
      Value _4 = i32_val(4);
      Value _8 = i32_val(8);
      Value _16 = i32_val(16);
      if (mmaLayout.isAmpere()) {
        multiDimWarpId[0] = urem(multiDimWarpId[0], idx_val(shape[0] / 16));
        multiDimWarpId[1] = urem(multiDimWarpId[1], idx_val(shape[1] / 8));
        multiDimWarpId[0] = urem(multiDimWarpId[0], i32_val(shape[0] / 16));
        multiDimWarpId[1] = urem(multiDimWarpId[1], i32_val(shape[1] / 8));
        Value mmaGrpId = udiv(laneId, _4);
        Value mmaGrpIdP8 = add(mmaGrpId, _8);
        Value mmaThreadIdInGrp = urem(laneId, _4);
@@ -135,15 +137,15 @@ private:
        multiDimOffset[0] = elemId < 2 ? mmaRowIdx[0] : mmaRowIdx[1];
        multiDimOffset[1] = elemId % 2 == 0 ? mmaColIdx[0] : mmaColIdx[1];
        multiDimOffset[0] = add(
            multiDimOffset[0], idx_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
            multiDimOffset[0], i32_val(multiDimCTAInRepId[0] * shapePerCTA[0]));
        multiDimOffset[1] = add(
            multiDimOffset[1], idx_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
            multiDimOffset[1], i32_val(multiDimCTAInRepId[1] * shapePerCTA[1]));
      } else if (mmaLayout.isVolta()) {
        auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
        auto [isARow, isBRow, isAVec4, isBVec4, _] =
            mmaLayout.decodeVoltaLayoutStates();
        auto coords = DotOpMmaV1ConversionHelper::getMNCoords(
            threadId, rewriter, mmaLayout.getWarpsPerCTA(), shape, isARow,
            isBRow, isAVec4, isBVec4);
            threadId, rewriter, mmaLayout.getWarpsPerCTA(), mmaLayout, shape,
            isARow, isBRow, isAVec4, isBVec4);
        return DotOpMmaV1ConversionHelper::getCoord(elemId, coords);
      } else {
        llvm_unreachable("Unexpected MMALayout version");
@@ -199,7 +201,7 @@ private:
    // of performance issue observed.
    for (unsigned elemId = 0; elemId < accumSizePerThread; elemId += vec) {
      SmallVector<Value> multiDimOffset =
          getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
          getMultiDimOffset(layout, loc, rewriter, elemId, type,
                            multiDimCTAInRepId, shapePerCTA);
      Value offset =
          linearize(rewriter, loc, multiDimOffset, paddedRepShape, outOrd);
@@ -216,13 +218,13 @@ private:
          currVal = zext(llvmElemTy, currVal);
        else if (isPtr)
          currVal = ptrtoint(llvmElemTy, currVal);
        valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
        valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
      }
      store(valVec, ptr);
    } else {
      Value valVec = load(ptr);
      for (unsigned v = 0; v < vec; ++v) {
        Value currVal = extract_element(llvmElemTy, valVec, idx_val(v));
        Value currVal = extract_element(llvmElemTy, valVec, i32_val(v));
        if (isInt1)
          currVal = icmp_ne(currVal,
                            rewriter.create<LLVM::ConstantOp>(
@@ -289,18 +291,15 @@ private:
      // TODO[Superjomn]: Move the coordinate computation out of loop, it is
      // duplicate in Volta.
      SmallVector<Value> multiDimOffset =
          getMultiDimOffset(layout, loc, rewriter, elemId, type.getShape(),
          getMultiDimOffset(layout, loc, rewriter, elemId, type,
                            multiDimCTAInRepId, shapePerCTA);
      coord2val[elemId] = std::make_pair(multiDimOffset, vals[elemId]);
    }

    if (needTrans) {
      auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
          mma.decodeVoltaLayoutStates();
      DotOpMmaV1ConversionHelper helper(mma);
      // do transpose
      int numM = helper.getElemsM(mma.getWarpsPerCTA()[0], shape[0], isARow,
                                  isAVec4);
      auto aEncoding = DotOperandEncodingAttr::get(mma.getContext(), 0, mma);
      int numM = aEncoding.getMMAv1NumOuter(shape);
      int numN = accumSizePerThread / numM;

      for (int r = 0; r < numM; r++) {
@@ -326,13 +325,13 @@ private:
          Value valVec = undef(vecTy);
          for (unsigned v = 0; v < vec; ++v) {
            auto currVal = coord2valT[elemId + v].second;
            valVec = insert_element(vecTy, valVec, currVal, idx_val(v));
            valVec = insert_element(vecTy, valVec, currVal, i32_val(v));
          }
          store(valVec, ptr);
        } else {
          Value valVec = load(ptr);
          for (unsigned v = 0; v < vec; ++v) {
            Value currVal = extract_element(elemTy, valVec, idx_val(v));
            Value currVal = extract_element(elemTy, valVec, i32_val(v));
            vals[elemId + v] = currVal;
          }
        }
@@ -397,7 +396,8 @@ private:
    // Potentially we need to store for multiple CTAs in this replication
    auto accumNumReplicates = product<unsigned>(numReplicates);
    // unsigned elems = getElemsPerThread(srcTy);
    auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
    auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                     rewriter, srcTy);
    unsigned inVec = 0;
    unsigned outVec = 0;
    auto paddedRepShape = getScratchConfigForCvtLayout(op, inVec, outVec);
@@ -448,7 +448,8 @@ private:
    SmallVector<Type> types(outElems, llvmElemTy);
    auto *ctx = llvmElemTy.getContext();
    Type structTy = struct_ty(types);
    Value result = getStructFromElements(loc, outVals, rewriter, structTy);
    Value result =
        getTypeConverter()->packLLElements(loc, outVals, rewriter, dstTy);
    rewriter.replaceOp(op, result);

    return success();
@@ -480,7 +481,7 @@ private:

    auto dstStrides =
        getStridesFromShapeAndOrder(dstShape, outOrd, loc, rewriter);
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
    storeDistributedToShared(src, adaptor.getSrc(), dstStrides, srcIndices, dst,
                             smemBase, elemTy, loc, rewriter);
    auto smemObj =
@@ -525,10 +526,10 @@ private:
      auto thread = getThreadId(rewriter, loc);
      if (dotOpLayout.getOpIdx() == 0) { // $a
        res = helper.loadA(src, adaptor.getSrc(), blockedLayout, thread, loc,
                           rewriter);
                           getTypeConverter(), rewriter);
      } else { // $b
        res = helper.loadB(src, adaptor.getSrc(), blockedLayout, thread, loc,
                           rewriter);
                           getTypeConverter(), rewriter);
      }
    } else {
      assert(false && "Unsupported dot operand layout found");
@@ -551,22 +552,42 @@ private:
    auto dstDotLayout = dstLayout.cast<DotOperandEncodingAttr>();
    if (isMmaToDotShortcut(srcMmaLayout, dstDotLayout)) {
      // get source values
      auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
      auto vals = getTypeConverter()->unpackLLElements(loc, adaptor.getSrc(),
                                                       rewriter, srcTy);
      unsigned elems = getElemsPerThread(srcTy);
      Type elemTy =
          this->getTypeConverter()->convertType(srcTy.getElementType());
      // for the destination type, we need to pack values together
      // so they can be consumed by tensor core operations
      unsigned vecSize =
          std::max<unsigned>(32 / elemTy.getIntOrFloatBitWidth(), 1);
      Type vecTy = vec_ty(elemTy, vecSize);
      SmallVector<Type> types(elems / vecSize, vecTy);
      SmallVector<Value> vecVals;
      for (unsigned i = 0; i < elems; i += vecSize) {
        Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
        for (unsigned j = 0; j < vecSize; j++)
          packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
        vecVals.push_back(packed);
      SmallVector<Type> types;
      // For some reasons, LLVM's NVPTX backend inserts unnecessary (?) integer
      // instructions to pack & unpack sub-word integers. A workaround is to
      // store the results of ldmatrix in i32
      auto elemSize = elemTy.getIntOrFloatBitWidth();
      if (auto intTy = elemTy.dyn_cast<IntegerType>() && elemSize <= 16) {
        auto fold = 32 / elemSize;
        for (unsigned i = 0; i < elems; i += fold) {
          Value val = i32_val(0);
          for (unsigned j = 0; j < fold; j++) {
            auto ext =
                shl(i32_ty, zext(i32_ty, vals[i + j]), i32_val(elemSize * j));
            val = or_(i32_ty, val, ext);
          }
          vecVals.push_back(val);
        }
        elems = elems / (32 / elemSize);
        types = SmallVector<Type>(elems, i32_ty);
      } else {
        unsigned vecSize = std::max<unsigned>(32 / elemSize, 1);
        Type vecTy = vec_ty(elemTy, vecSize);
        types = SmallVector<Type>(elems / vecSize, vecTy);
        for (unsigned i = 0; i < elems; i += vecSize) {
          Value packed = rewriter.create<LLVM::UndefOp>(loc, vecTy);
          for (unsigned j = 0; j < vecSize; j++)
            packed = insert_element(vecTy, packed, vals[i + j], i32_val(j));
          vecVals.push_back(packed);
        }
      }

      // This needs to be ordered the same way that
@@ -582,12 +603,8 @@ private:
        reorderedVals.push_back(vecVals[i + 3]);
      }

      // return composeValuesToDotOperandLayoutStruct(ha, numRepM, numRepK);

      Type structTy =
          LLVM::LLVMStructType::getLiteral(this->getContext(), types);
      Value view =
          getStructFromElements(loc, reorderedVals, rewriter, structTy);
      Value view = getTypeConverter()->packLLElements(loc, reorderedVals,
                                                      rewriter, dstTy);
      rewriter.replaceOp(op, view);
      return success();
    }
@@ -622,8 +639,7 @@ private:
      }
    } else if (!isOuter && mmaLayout.isVolta() && isHMMA) { // tensor core v1
      DotOpMmaV1ConversionHelper helper(mmaLayout);
      bool isMMAv1Row =
          dotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
      bool isMMAv1Row = dotOperandLayout.getMMAv1IsRow();
      auto srcSharedLayout = src.getType()
                                 .cast<RankedTensorType>()
                                 .getEncoding()
@@ -640,12 +656,12 @@ private:
        // TODO[Superjomn]: transA is not available here.
        bool transA = false;
        res = helper.loadA(src, smemObj, getThreadId(rewriter, loc), loc,
                           rewriter);
                           getTypeConverter(), rewriter, dst.getType());
      } else if (dotOperandLayout.getOpIdx() == 1) { // operand $b
        // TODO[Superjomn]: transB is not available here.
        bool transB = false;
        res = helper.loadB(src, smemObj, getThreadId(rewriter, loc), loc,
                           rewriter);
                           getTypeConverter(), rewriter, dst.getType());
      }
    } else {
      assert(false && "Unsupported mma layout found");
@@ -655,7 +671,7 @@ private:
};

void populateConvertLayoutOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -9,7 +9,7 @@ using namespace mlir::triton;
using ::mlir::triton::gpu::DotOperandEncodingAttr;

void populateConvertLayoutOpToLLVMPatterns(
    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -1,65 +1,13 @@
#include "DotOpHelpers.h"
#include "TypeConverter.h"

namespace mlir {
namespace LLVM {

int DotOpMmaV1ConversionHelper::numElemsPerThreadA(ArrayRef<int64_t> shape,
                                                   bool isARow, bool isAVec4,
                                                   int vec) const {
  int numM = getNumM(shape[0], isARow, isAVec4);
  int NK = shape[1];
  // Here we mimic the logic in loadA, the result cannot be calculated
  // directly.
  llvm::DenseSet<std::pair<int, int>> visited;
  auto ld = [&](int m, int k) {
    visited.insert({m, k});
    if (vec > 4) {
      if (isARow)
        visited.insert({m, k + 4});
      else
        visited.insert({m + 1, k});
    }
  };

  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      if (!visited.count({m, k}))
        ld(m, k);

  return visited.size() * 2;
}

int DotOpMmaV1ConversionHelper::numElemsPerThreadB(ArrayRef<int64_t> shape,
                                                   bool isBRow, bool isBVec4,
                                                   int vec) const {
  unsigned numN = getNumN(shape[1], isBRow, isBVec4);
  int NK = shape[0];
  // Here we mimic the logic in loadA, the result cannot be calculated
  // directly.
  llvm::DenseSet<std::pair<int, int>> visited;
  int elemsPerLd = vec > 4 ? 4 : 2;
  auto ld = [&](int n, int k) {
    visited.insert({n, k});
    if (vec > 4) {
      if (isBRow)
        visited.insert({n + 1, k});
      else
        visited.insert({n, k + 4});
    }
  };

  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned n = 0; n < numN / 2; ++n) {
      if (!visited.count({n, k}))
        ld(n, k);
    }

  return visited.size() * 2;
}

Value DotOpMmaV1ConversionHelper::loadA(
    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
    ConversionPatternRewriter &rewriter) const {
    TritonGPUToLLVMTypeConverter *typeConverter,
    ConversionPatternRewriter &rewriter, Type resultTy) const {
  auto *ctx = rewriter.getContext();
  auto tensorTy = tensor.getType().cast<RankedTensorType>();
  auto sharedLayout = tensorTy.getEncoding().cast<SharedEncodingAttr>();

@@ -70,12 +18,12 @@ Value DotOpMmaV1ConversionHelper::loadA(
  Value smemBase = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);

  bool isARow = order[0] != 0;
  auto [isARow_, _0, isAVec4, _1, _2] = mmaLayout.decodeVoltaLayoutStates();

  AParam param(isARow_, isAVec4);

  auto resultEncoding = resultTy.cast<RankedTensorType>()
                            .getEncoding()
                            .cast<DotOperandEncodingAttr>();
  auto [offsetAM, offsetAK, _3, _4] = computeOffsets(
      thread, isARow, false, fpw, param.spw, param.rep, rewriter, loc);
      thread, isARow, false, fpw, resultEncoding.getMMAv1ShapePerWarp(),
      resultEncoding.getMMAv1Rep(), rewriter, loc);

  int vecA = sharedLayout.getVec();

@@ -152,7 +100,9 @@ Value DotOpMmaV1ConversionHelper::loadA(
    }
  };

  unsigned numM = getNumM(shape[0], isARow, isAVec4);
  bool isARow_ = resultEncoding.getMMAv1IsRow();
  bool isAVec4 = resultEncoding.getMMAv1IsVec4();
  unsigned numM = resultEncoding.getMMAv1NumOuter(shape);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned m = 0; m < numM / 2; ++m)
      if (!has.count({m, k}))
@@ -165,14 +115,14 @@ Value DotOpMmaV1ConversionHelper::loadA(
    elems.push_back(item.second.second);
  }

  Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
  Value res = getStructFromElements(loc, elems, rewriter, resTy);
  Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
  return res;
}

Value DotOpMmaV1ConversionHelper::loadB(
    Value tensor, const SharedMemoryObject &smemObj, Value thread, Location loc,
    ConversionPatternRewriter &rewriter) const {
    TritonGPUToLLVMTypeConverter *typeConverter,
    ConversionPatternRewriter &rewriter, Type resultTy) const {
  // smem
  auto strides = smemObj.strides;

@@ -186,10 +136,9 @@ Value DotOpMmaV1ConversionHelper::loadB(
  Value smem = smemObj.getBaseBeforeSwizzle(order[0], loc, rewriter);
  bool isBRow = order[0] != 0; // is row-major in shared memory layout
  // isBRow_ indicates whether B is row-major in DotOperand layout
  auto [_0, isBRow_, _1, isBVec4, _2] = mmaLayout.decodeVoltaLayoutStates();
  assert(isBRow == isBRow_ && "B need smem isRow");

  BParam param(isBRow_, isBVec4);
  auto resultEncoding = resultTy.cast<RankedTensorType>()
                            .getEncoding()
                            .cast<DotOperandEncodingAttr>();

  int vecB = sharedLayout.getVec();
  Value strideBN = isBRow ? i32_val(1) : strides[1];
@@ -200,7 +149,8 @@ Value DotOpMmaV1ConversionHelper::loadB(
  int strideRepK = 1;

  auto [_3, _4, offsetBN, offsetBK] = computeOffsets(
      thread, false, isBRow, fpw, param.spw, param.rep, rewriter, loc);
      thread, false, isBRow, fpw, resultEncoding.getMMAv1ShapePerWarp(),
      resultEncoding.getMMAv1Rep(), rewriter, loc);

  // swizzling
  int perPhaseB = sharedLayout.getPerPhase();
@@ -266,7 +216,10 @@ Value DotOpMmaV1ConversionHelper::loadB(
    }
  };

  unsigned numN = getNumN(shape[1], isBRow, isBVec4);
  bool isBRow_ = resultEncoding.getMMAv1IsRow();
  assert(isBRow == isBRow_ && "B need smem isRow");
  bool isBVec4 = resultEncoding.getMMAv1IsVec4();
  unsigned numN = resultEncoding.getMMAv1NumOuter(shape);
  for (unsigned k = 0; k < NK; k += 4)
    for (unsigned n = 0; n < numN / 2; ++n) {
      if (!hbs.count({n, k}))
@@ -279,8 +232,7 @@ Value DotOpMmaV1ConversionHelper::loadB(
    elems.push_back(item.second.second);
  }

  Type resTy = struct_ty(SmallVector<Type>(elems.size(), elemX2Ty));
  Value res = getStructFromElements(loc, elems, rewriter, resTy);
  Value res = typeConverter->packLLElements(loc, elems, rewriter, resultTy);
  return res;
}
|
||||
|
||||
@@ -350,10 +302,11 @@ DotOpMmaV1ConversionHelper::computeOffsets(Value threadId, bool isARow,
|
||||
|
||||
DotOpMmaV1ConversionHelper::ValueTable
|
||||
DotOpMmaV1ConversionHelper::extractLoadedOperand(
|
||||
Value llStruct, int NK, ConversionPatternRewriter &rewriter) const {
|
||||
Value llStruct, int NK, ConversionPatternRewriter &rewriter,
|
||||
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const {
|
||||
ValueTable rcds;
|
||||
SmallVector<Value> elems =
|
||||
getElementsFromStruct(llStruct.getLoc(), llStruct, rewriter);
|
||||
SmallVector<Value> elems = typeConverter->unpackLLElements(
|
||||
llStruct.getLoc(), llStruct, rewriter, type);
|
||||
|
||||
int offset = 0;
|
||||
for (int i = 0; offset < elems.size(); ++i) {
|
||||
@@ -366,10 +319,12 @@ DotOpMmaV1ConversionHelper::extractLoadedOperand(
|
||||
return rcds;
|
||||
}

// TODO: Mostly a duplicate of TritonGPUToLLVMBase::emitBaseIndexforMMaLayoutV1
SmallVector<DotOpMmaV1ConversionHelper::CoordTy>
DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
ConversionPatternRewriter &rewriter,
ArrayRef<unsigned int> wpt,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4,
bool isBVec4) {
@@ -388,11 +343,17 @@ DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);

DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});

Value lane = urem(thread, warpSize);
@@ -468,23 +429,6 @@ DotOpMmaV1ConversionHelper::getMNCoords(Value thread,
return coords; // {M,N} in row-major
}

void DotOpMmaV1ConversionHelper::AParam::build(bool isARow) {
int packSize0 = (isARow || isAVec4) ? 1 : 2;
int repM = 2 * packSize0;
int repK = 1;
int spwM = fpw[0] * 4 * repM;
rep.assign({repM, 0, repK});
spw.assign({spwM, 0, 1});
vec = 2 * rep[0];
}

void DotOpMmaV1ConversionHelper::BParam::build(bool isBRow) {
int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
rep.assign({0, 2 * packSize1, 1});
spw.assign({0, fpw[1] * 4 * rep[1], 1});
vec = 2 * rep[1];
}
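
A worked example of these parameters, with hypothetical values: for fpw[1] = 2, isBRow = true, isBVec4 = false, BParam::build gives packSize1 = 2, rep[1] = 4, spw[1] = 2 * 4 * 4 = 32, and vec = 8. The getNumN helper declared in the header then yields numN = rep[1] * N / (spw[1] * wpt[1]) = 4 * 64 / (32 * 1) = 8 for N = 64 and wpt[1] = 1.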

std::tuple<int, int>
DotOpMmaV2ConversionHelper::getRepMN(const RankedTensorType &tensorTy) {
auto mmaLayout = tensorTy.getEncoding().cast<MmaEncodingAttr>();
@@ -1077,6 +1021,7 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
helper.deduceMmaType(op);

auto aTensorTy = a.getType().cast<RankedTensorType>();
auto bTensorTy = b.getType().cast<RankedTensorType>();
auto dTensorTy = d.getType().cast<RankedTensorType>();

SmallVector<int64_t> aShape(aTensorTy.getShape().begin(),
@@ -1090,10 +1035,10 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
int numRepK = getNumRepK(aTensorTy, aShape[1]);

ValueTable ha =
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK);
getValuesFromDotOperandLayoutStruct(loadedA, numRepM, numRepK, aTensorTy);
ValueTable hb = getValuesFromDotOperandLayoutStruct(
loadedB, std::max(numRepN / 2, 1), numRepK);
auto fc = getElementsFromStruct(loc, loadedC, rewriter);
loadedB, std::max(numRepN / 2, 1), numRepK, bTensorTy);
auto fc = typeConverter->unpackLLElements(loc, loadedC, rewriter, dTensorTy);

auto callMma = [&](unsigned m, unsigned n, unsigned k) {
unsigned colsPerThread = numRepN * 2;
@@ -1139,7 +1084,7 @@ LogicalResult MMA16816ConversionHelper::convertDot(Value a, Value b, Value c,
// replace with new packed result
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(fc.size(), resElemTy));
Value res = getStructFromElements(loc, fc, rewriter, structTy);
Value res = typeConverter->packLLElements(loc, fc, rewriter, structTy);
rewriter.replaceOp(op, res);

return success();
@@ -1218,14 +1163,14 @@ Value MMA16816ConversionHelper::composeValuesToDotOperandLayoutStruct(
Type elemTy = elems[0].getType();
Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(elems.size(), elemTy));
auto result = getStructFromElements(loc, elems, rewriter, structTy);
auto result = typeConverter->packLLElements(loc, elems, rewriter, structTy);
return result;
}
MMA16816ConversionHelper::ValueTable
MMA16816ConversionHelper::getValuesFromDotOperandLayoutStruct(Value value,
int n0,
int n1) const {
auto elems = getElementsFromStruct(loc, value, rewriter);
int n0, int n1,
Type type) const {
auto elems = typeConverter->unpackLLElements(loc, value, rewriter, type);

int offset{};
ValueTable vals;
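
The ValueTable filled after this point is keyed by tile coordinates; a minimal sketch of the reshape, assuming one value per (i, j) coordinate (the real loop advances offset by however many registers each mma tile occupies):

// Sketch only: flat struct elements -> {i, j} table.
for (int i = 0; i < n0; ++i)
  for (int j = 0; j < n1; ++j)
    vals[{i, j}] = elems[offset++];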
@@ -1257,6 +1202,7 @@ SmallVector<Value> DotOpFMAConversionHelper::getThreadIds(
}
Value DotOpFMAConversionHelper::loadA(
Value A, Value llA, BlockedEncodingAttr dLayout, Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const {
auto aTensorTy = A.getType().cast<RankedTensorType>();
auto aLayout = aTensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1316,10 +1262,11 @@ Value DotOpFMAConversionHelper::loadA(
vas.emplace_back(va);
}

return getStructFromValueTable(vas, rewriter, loc, elemTy);
return getStructFromValueTable(vas, rewriter, loc, typeConverter, elemTy);
}
Value DotOpFMAConversionHelper::loadB(
Value B, Value llB, BlockedEncodingAttr dLayout, Value thread, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const {
auto bTensorTy = B.getType().cast<RankedTensorType>();
auto bLayout = bTensorTy.getEncoding().cast<SharedEncodingAttr>();
@@ -1379,14 +1326,15 @@ Value DotOpFMAConversionHelper::loadB(
vbs.emplace_back(vb);
}

return getStructFromValueTable(vbs, rewriter, loc, elemTy);
return getStructFromValueTable(vbs, rewriter, loc, typeConverter, elemTy);
}
DotOpFMAConversionHelper::ValueTable
DotOpFMAConversionHelper::getValueTableFromStruct(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc) const {
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const {
ValueTable res;
auto elems = getElementsFromStruct(loc, val, rewriter);
auto elems = typeConverter->unpackLLElements(loc, val, rewriter, type);
int index = 0;
for (unsigned k = 0; k < K; ++k) {
for (unsigned m = 0; m < n0; m += shapePerCTA)
@@ -1398,7 +1346,7 @@ DotOpFMAConversionHelper::getValueTableFromStruct(
}
Value DotOpFMAConversionHelper::getStructFromValueTable(
ArrayRef<Value> vals, ConversionPatternRewriter &rewriter, Location loc,
Type elemTy) const {
TritonGPUToLLVMTypeConverter *typeConverter, Type elemTy) const {
SmallVector<Type> elemTypes(vals.size(), elemTy);
SmallVector<Value> elems;
elems.reserve(vals.size());
@@ -1407,7 +1355,7 @@ Value DotOpFMAConversionHelper::getStructFromValueTable(
}

Type structTy = struct_ty(elemTypes);
return getStructFromElements(loc, elems, rewriter, structTy);
return typeConverter->packLLElements(loc, elems, rewriter, structTy);
}
int DotOpFMAConversionHelper::getNumElemsPerThread(
ArrayRef<int64_t> shape, DotOperandEncodingAttr dotOpLayout) {

@@ -26,6 +26,8 @@

#include "Utility.h"

class TritonGPUToLLVMTypeConverter;

namespace mlir {
namespace LLVM {
using namespace mlir::triton;
@@ -45,45 +47,6 @@ struct DotOpMmaV1ConversionHelper {
explicit DotOpMmaV1ConversionHelper(MmaEncodingAttr mmaLayout)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()) {}

// Help to share some variables across multiple functions for A.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct AParam {
SmallVector<int> rep;
SmallVector<int> spw;
bool isAVec4{};
int vec{}; // This can only be used in DotOp, not in
// loadA/loadB/TypeConverter

AParam(bool isARow, bool isAVec4) : isAVec4(isAVec4) { build(isARow); }

private:
void build(bool isARow);
};

// Help to share some variables across multiple functions for B.
// TODO[Superjomn]: refactor and restrict this to only use in DotOp
// conversion.
struct BParam {
SmallVector<int> rep;
SmallVector<int> spw;
bool isBVec4{};
int vec{}; // This can only be used in DotOp, not in
// loadA/loadB/TypeConverter

BParam(bool isBRow, bool isBVec4) : isBVec4(isBVec4) { build(isBRow); }

private:
void build(bool isBRow);
};

int getRepM(int M) const {
return std::max<int>(M / (wpt[0] * instrShape[0]), 1);
}
int getRepN(int N) const {
return std::max<int>(N / (wpt[1] * instrShape[1]), 1);
}

static ArrayRef<unsigned> getMmaInstrShape() { return instrShape; }

static Type getMmaRetType(TensorType operand) {
@@ -100,35 +63,15 @@ struct DotOpMmaV1ConversionHelper {
return struct_ty(SmallVector<Type>{vecTy});
}

// Get the number of fp16x2 elements for $a.
unsigned getNumM(int M, bool isARow, bool isAVec4) const {
AParam param(isARow, isAVec4);

unsigned numM = param.rep[0] * M / (param.spw[0] * wpt[0]);
return numM;
}

// Get the number of fp16x2 elements for $b.
unsigned getNumN(int N, bool isBRow, bool isBVec4) const {
BParam param(isBRow, isBVec4);

unsigned numN = param.rep[1] * N / (param.spw[1] * wpt[1]);
return numN;
}
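
For the $a side the arithmetic is symmetric: with hypothetical values fpw[0] = 2, isARow = false, isAVec4 = false, AParam::build gives packSize0 = 2, rep[0] = 4, spw[0] = 2 * 4 * 4 = 32, so getNumM returns 4 * 64 / (32 * 2) = 4 for M = 64 and wpt[0] = 2.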

int numElemsPerThreadA(ArrayRef<int64_t> shape, bool isARow, bool isAVec4,
int vec) const;

int numElemsPerThreadB(ArrayRef<int64_t> shape, bool isBRow, bool isBVec4,
int vec) const;

// Loading $a from smem to registers, returns a LLVM::Struct.
Value loadA(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *converter,
ConversionPatternRewriter &rewriter, Type resultTy) const;

// Loading $b from smem to registers, returns a LLVM::Struct.
Value loadB(Value tensor, const SharedMemoryObject &smemObj, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *converter,
ConversionPatternRewriter &rewriter, Type resultTy) const;

static ArrayRef<unsigned> getOrder() { return mmaOrder; }

@@ -145,25 +88,17 @@ struct DotOpMmaV1ConversionHelper {
ConversionPatternRewriter &rewriter, Location loc) const;

// Extract values belonging to $a or $b from an LLVMStruct; the shape is n0xn1.
DotOpMmaV1ConversionHelper::ValueTable
extractLoadedOperand(Value llStruct, int NK,
ConversionPatternRewriter &rewriter) const;

// Get the number of elements of this thread in M axis. The N axis could be
// further deduced with the accSize / elemsM. \param wpt: the wpt in M axis
// \param M: the shape in M axis
int getElemsM(int wpt, int M, bool isARow, bool isAVec4) {
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int shapePerCTAM = param.spw[0] * wpt;
return M / shapePerCTAM * param.rep[0];
}
DotOpMmaV1ConversionHelper::ValueTable extractLoadedOperand(
Value llStruct, int NK, ConversionPatternRewriter &rewriter,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const;

using CoordTy = SmallVector<Value>;
// Get the coordinates (m,n) of the elements emitted by a thread in the accumulator.
static SmallVector<CoordTy>
getMNCoords(Value thread, ConversionPatternRewriter &rewriter,
ArrayRef<unsigned> wpt, ArrayRef<int64_t> shape, bool isARow,
bool isBRow, bool isAVec4, bool isBVec4);
ArrayRef<unsigned> wpt, const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape, bool isARow, bool isBRow, bool isAVec4,
bool isBVec4);

// \param elemId the offset of the element in a thread
static CoordTy getCoord(int elemId, ArrayRef<CoordTy> coords) {
@@ -411,7 +346,7 @@ struct MMA16816ConversionHelper {

DotOpMmaV2ConversionHelper helper;
ConversionPatternRewriter &rewriter;
TypeConverter *typeConverter;
TritonGPUToLLVMTypeConverter *typeConverter;
Location loc;
MLIRContext *ctx{};

@@ -420,7 +355,8 @@ struct MMA16816ConversionHelper {
// dotOperand: type of either one operand of dotOp.
MMA16816ConversionHelper(Type dotOperand, MmaEncodingAttr mmaLayout,
Value thread, ConversionPatternRewriter &rewriter,
TypeConverter *typeConverter, Location loc)
TritonGPUToLLVMTypeConverter *typeConverter,
Location loc)
: mmaLayout(mmaLayout), wpt(mmaLayout.getWarpsPerCTA()), thread(thread),
helper(mmaLayout), rewriter(rewriter), typeConverter(typeConverter),
loc(loc), ctx(mmaLayout.getContext()) {
@@ -564,8 +500,8 @@ private:
Value composeValuesToDotOperandLayoutStruct(const ValueTable &vals, int n0,
int n1) const;

ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0,
int n1) const;
ValueTable getValuesFromDotOperandLayoutStruct(Value value, int n0, int n1,
Type type) const;
};

// Helper for conversion of FMA DotOp.
@@ -584,19 +520,23 @@ struct DotOpFMAConversionHelper {
ConversionPatternRewriter &rewriter, Location loc) const;

Value loadA(Value A, Value llA, BlockedEncodingAttr dLayout, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const;

Value loadB(Value B, Value llB, BlockedEncodingAttr dLayout, Value thread,
Location loc, ConversionPatternRewriter &rewriter) const;
Location loc, TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter) const;

ValueTable getValueTableFromStruct(Value val, int K, int n0, int shapePerCTA,
int sizePerThread,
ConversionPatternRewriter &rewriter,
Location loc) const;
ValueTable getValueTableFromStruct(
Value val, int K, int n0, int shapePerCTA, int sizePerThread,
ConversionPatternRewriter &rewriter, Location loc,
TritonGPUToLLVMTypeConverter *typeConverter, Type type) const;

Value getStructFromValueTable(ArrayRef<Value> vals,
ConversionPatternRewriter &rewriter,
Location loc, Type elemTy) const;
Location loc,
TritonGPUToLLVMTypeConverter *typeConverter,
Type elemTy) const;

// get number of elements per thread for $a or $b.
static int getNumElemsPerThread(ArrayRef<int64_t> shape,

@@ -7,8 +7,6 @@ using namespace mlir::triton;

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
@@ -116,28 +114,30 @@ private:
auto AShape = ATensorTy.getShape();
auto BShape = BTensorTy.getShape();

bool isARow = ALayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
bool isBRow = BLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
bool isARow = ALayout.getMMAv1IsRow();
bool isBRow = BLayout.getMMAv1IsRow();
auto [isARow_, isBRow_, isAVec4_, isBVec4_, _] =
mmaLayout.decodeVoltaLayoutStates();
assert(isARow == isARow_);
assert(isBRow == isBRow_);

DotOpMmaV1ConversionHelper helper(mmaLayout);

unsigned numM = helper.getNumM(AShape[0], isARow, isAVec4_);
unsigned numN = helper.getNumN(BShape[1], isBRow, isBVec4_);
unsigned numM = ALayout.getMMAv1NumOuter(AShape);
unsigned numN = BLayout.getMMAv1NumOuter(BShape);
unsigned NK = AShape[1];

auto has = helper.extractLoadedOperand(adaptor.getA(), NK, rewriter);
auto hbs = helper.extractLoadedOperand(adaptor.getB(), NK, rewriter);
auto has = helper.extractLoadedOperand(adaptor.getA(), NK, rewriter,
getTypeConverter(), ATensorTy);
auto hbs = helper.extractLoadedOperand(adaptor.getB(), NK, rewriter,
getTypeConverter(), BTensorTy);

// Initialize accumulators with external values. The acc holds the
// accumulator values shared between the MMA instructions inside a
// DotOp; we call this ordering of the values the accumulator-internal
// order.
SmallVector<Value> acc =
getElementsFromStruct(loc, adaptor.getC(), rewriter);
SmallVector<Value> acc = getTypeConverter()->unpackLLElements(
loc, adaptor.getC(), rewriter, DTensorTy);
size_t resSize = acc.size();

// The resVals holds the final result of the DotOp.
@@ -209,9 +209,8 @@ private:
resVals[i] = acc[i];
}

Type structTy = LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(resSize, type::f32Ty(ctx)));
Value res = getStructFromElements(loc, resVals, rewriter, structTy);
Value res =
getTypeConverter()->packLLElements(loc, resVals, rewriter, DTensorTy);
rewriter.replaceOp(op, res);
return success();
}
@@ -236,7 +235,8 @@ private:
BlockedEncodingAttr dLayout =
dTensorTy.getEncoding().cast<BlockedEncodingAttr>();
auto order = dLayout.getOrder();
auto cc = getElementsFromStruct(loc, adaptor.getC(), rewriter);
auto cc = getTypeConverter()->unpackLLElements(loc, adaptor.getC(),
rewriter, dTensorTy);

DotOpFMAConversionHelper helper(dLayout);
Value llA = adaptor.getA();
@@ -259,9 +259,11 @@ private:
order[0] == 0 ? sizePerThread[order[1]] : sizePerThread[order[0]];

auto has = helper.getValueTableFromStruct(llA, K, M, mShapePerCTA,
mSizePerThread, rewriter, loc);
mSizePerThread, rewriter, loc,
getTypeConverter(), aTensorTy);
auto hbs = helper.getValueTableFromStruct(llB, K, N, nShapePerCTA,
nSizePerThread, rewriter, loc);
nSizePerThread, rewriter, loc,
getTypeConverter(), bTensorTy);

SmallVector<Value> ret = cc;
bool isCRow = order[0] == 1;
@@ -281,16 +283,15 @@ private:
}
}

auto res = getStructFromElements(
loc, ret, rewriter,
struct_ty(SmallVector<Type>(ret.size(), ret[0].getType())));
auto res =
getTypeConverter()->packLLElements(loc, ret, rewriter, dTensorTy);
rewriter.replaceOp(op, res);

return success();
}
};
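
For context, the accumulation loop that consumes the two value tables before the final pack looks roughly like this (a sketch under assumed index math, not the exact loop from the file; LLVM::FMulAddOp is the llvm.intr.fmuladd intrinsic):

// Sketch: ret[idx] += A[m,k] * B[n,k]; the linearization of idx follows
// the blocked layout's order (isCRow).
for (unsigned k = 0; k < K; ++k)
  for (unsigned m = 0; m < M; ++m)
    for (unsigned n = 0; n < N; ++n) {
      unsigned idx = isCRow ? m * N + n : n * M + m;
      ret[idx] = rewriter.create<LLVM::FMulAddOp>(loc, has[{m, k}],
                                                  hbs[{n, k}], ret[idx]);
    }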

void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

@@ -6,7 +6,7 @@
using namespace mlir;
using namespace mlir::triton;

void populateDotOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateDotOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

@@ -2,9 +2,6 @@

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;

struct FpToFpOpConversion
@@ -12,10 +9,14 @@ struct FpToFpOpConversion
using ConvertTritonGPUOpToLLVMPattern<
triton::FpToFpOp>::ConvertTritonGPUOpToLLVMPattern;

/* ------------------ */
// FP8 -> FP16
/* ------------------ */

static SmallVector<Value>
convertFp8x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto ctx = rewriter.getContext();
#ifdef USE_ROCM
auto fp8x4VecTy = vec_ty(i8_ty, 4);
@@ -63,23 +64,12 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto &ptxOp = *builder.create(ptxAsm);

auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);
ptxOp({o0, o1, i}, /*onlyAttachMLIRArgs=*/true);

auto fp16x2VecTy = vec_ty(f16_ty, 2);
auto fp16x2x2StructTy =
@@ -96,72 +86,41 @@ struct FpToFpOpConversion
}

static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

#ifdef USE_ROCM
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );

auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);

return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
PTXBuilder builder;
convertFp8E4M3x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 1; \n"
"shr.b32 b1, b1, 1; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
auto &call = *builder.create(ptxAsm);

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}

static SmallVector<Value>
convertFp8E5M2x4ToFp16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
"prmt.b32 $0, 0, $2, 0x5140; \n\t"
"prmt.b32 $1, 0, $2, 0x7362; \n\t"
"}";
return convertFp8x4ToFp16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
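
The E5M2 cases are cheap because an fp8e5m2 value is bit-identical to the top byte of the fp16 value with the same sign and exponent and a truncated mantissa, so the prmt.b32 above only scatters bytes. A scalar sketch of the widening step:

#include <cstdint>

// fp8e5m2 -> fp16: the fp8 byte becomes the high byte; the low byte
// (the truncated mantissa bits) is zero.
static inline uint16_t fp8e5m2_to_fp16_bits(uint8_t x) {
  return static_cast<uint16_t>(x) << 8;
}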

/* ------------------ */
// FP8 -> BF16
/* ------------------ */
static SmallVector<Value>
convertFp8x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto ctx = rewriter.getContext();
#ifdef USE_ROCM
auto fp8x4VecTy = vec_ty(i8_ty, 4);
@@ -178,7 +137,7 @@ struct FpToFpOpConversion
a1 = insert_element(fp8x4VecTy, a1, int_val(8,0), i32_val(2));
a1 = insert_element(fp8x4VecTy, a1, v3, i32_val(3));
a1 = bitcast(a1, i32_ty);

Value sign0 = and_(i32_ty, a0, i32_val(0x80008000));
Value sign1 = and_(i32_ty, a1, i32_val(0x80008000));
Value nosign0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
@@ -211,27 +170,12 @@ struct FpToFpOpConversion
fp8x4Vec = bitcast(fp8x4Vec, i32_ty);

PTXBuilder builder;
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
auto &call = *builder.create(ptxAsm);
auto &ptxOp = *builder.create(ptxAsm);

auto *o0 = builder.newOperand("=r");
auto *o1 = builder.newOperand("=r");
auto *i = builder.newOperand(fp8x4Vec, "r");
call({o0, o1, i}, /* onlyAttachMLIRArgs */ true);
ptxOp({o0, o1, i}, /* onlyAttachMLIRArgs */ true);

auto bf16x2VecTy = vec_ty(i16_ty, 2);
auto bf16x2x2StructTy =
@@ -247,10 +191,166 @@ struct FpToFpOpConversion
#endif
}

static SmallVector<Value>
convertFp8E4M3x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, sign<2>, nosign<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5040; \n"
"prmt.b32 a1, 0, $2, 0x7060; \n"
"and.b32 sign0, a0, 0x80008000; \n"
"and.b32 sign1, a1, 0x80008000; \n"
"and.b32 nosign0, a0, 0x7fff7fff; \n"
"and.b32 nosign1, a1, 0x7fff7fff; \n"
"shr.b32 nosign0, nosign0, 4; \n"
"shr.b32 nosign1, nosign1, 4; \n"
"add.u32 nosign0, nosign0, 0x38003800; \n"
"add.u32 nosign1, nosign1, 0x38003800; \n"
"or.b32 $0, sign0, nosign0; \n"
"or.b32 $1, sign1, nosign1; \n"
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};

static SmallVector<Value>
convertFp8E5M2x4ToBf16x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"prmt.b32 a0, 0, $2, 0x5140; \n"
"prmt.b32 a1, 0, $2, 0x7362; \n"
"lop3.b32 b0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 b1, a1, 0x7fff7fff, 0, 0xc0; \n"
"shr.b32 b0, b0, 3; \n"
"shr.b32 b1, b1, 3; \n"
"add.u32 b0, b0, 0x30003000; \n"
"add.u32 b1, b1, 0x30003000; \n"
"lop3.b32 $0, b0, 0x80008000, a0, 0xf8; \n"
"lop3.b32 $1, b1, 0x80008000, a1, 0xf8; \n"
"}";
return convertFp8x4ToBf16x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};

/* ------------------ */
// FP16 -> FP8
/* ------------------ */

static SmallVector<Value>
convertFp16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto fp16x2VecTy = vec_ty(f16_ty, 2);
Value fp16x2Vec0 = undef(fp16x2VecTy);
Value fp16x2Vec1 = undef(fp16x2VecTy);
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v0, i32_val(0));
fp16x2Vec0 = insert_element(fp16x2VecTy, fp16x2Vec0, v1, i32_val(1));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v2, i32_val(0));
fp16x2Vec1 = insert_element(fp16x2VecTy, fp16x2Vec1, v3, i32_val(1));
fp16x2Vec0 = bitcast(fp16x2Vec0, i32_ty);
fp16x2Vec1 = bitcast(fp16x2Vec1, i32_ty);

#ifdef USE_ROCM
Value a0 = shl(i32_ty, fp16x2Vec0, i32_val(1));
Value a1 = shl(i32_ty, fp16x2Vec1, i32_val(1));
a0 = and_(i32_ty, a0, i32_val(0x7fff7fff));
a1 = and_(i32_ty, a1, i32_val(0x7fff7fff));
a0 = add(i32_ty, a0, i32_val(0x00800080));
a1 = add(i32_ty, a1, i32_val(0x00800080));
Value b0 = or_( i32_ty, and_(i32_ty, fp16x2Vec0, i32_val(0x80008000)), a0 );
Value b1 = or_( i32_ty, and_(i32_ty, fp16x2Vec1, i32_val(0x80008000)), a1 );

auto fp8x4VecTy = vec_ty(i8_ty, 4);
b0 = bitcast(b0, fp8x4VecTy);
b1 = bitcast(b1, fp8x4VecTy);

return {extract_element(i8_ty, b0, i32_val(1)),
extract_element(i8_ty, b0, i32_val(3)),
extract_element(i8_ty, b1, i32_val(1)),
extract_element(i8_ty, b1, i32_val(3))
};
#else
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(fp16x2Vec0, "r");
auto *i1 = builder.newOperand(fp16x2Vec1, "r");
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
}

static SmallVector<Value>
convertFp16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .b32 a<2>, b<2>; \n"
"shl.b32 a0, $1, 1; \n"
"shl.b32 a1, $2, 1; \n"
"lop3.b32 a0, a0, 0x7fff7fff, 0, 0xc0; \n"
"lop3.b32 a1, a1, 0x7fff7fff, 0, 0xc0; \n"
"add.u32 a0, a0, 0x00800080; \n"
"add.u32 a1, a1, 0x00800080; \n"
"lop3.b32 b0, $1, 0x80008000, a0, 0xea; \n"
"lop3.b32 b1, $2, 0x80008000, a1, 0xea; \n"
"prmt.b32 $0, b0, b1, 0x7531; \n"
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}
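
A per-element scalar view of the E4M3 narrowing above, mirroring the PTX (including its lack of NaN/overflow special-casing) and ignoring the packed-register details:

#include <cstdint>

// fp16 -> fp8 E4M3-style bits, as in the PTX: shift out one exponent
// bit, round to nearest by adding half of the discarded ulp, reattach
// the sign, and keep the high byte.
static inline uint8_t fp16_to_fp8e4m3_bits(uint16_t h) {
  uint16_t a = static_cast<uint16_t>((h << 1) & 0x7fff); // shl.b32 + mask
  a = static_cast<uint16_t>(a + 0x0080);                 // add.u32 (round)
  uint16_t b = static_cast<uint16_t>((h & 0x8000) | a);  // lop3 0xea
  return static_cast<uint8_t>(b >> 8);                   // prmt 0x7531
}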

static SmallVector<Value>
convertFp16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
"prmt.b32 $0, $1, $2, 0x7531; \n\t"
"}";
return convertFp16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
}

/* ------------------ */
// FP32 -> FP8
/* ------------------ */

static SmallVector<Value>
convertFp32x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8E4M3x4(loc, rewriter, c0, c1, c2, c3);
}

static SmallVector<Value>
convertFp32x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8E5M2x4(loc, rewriter, c0, c1, c2, c3);
}

/* ------------------ */
// BF16 -> FP8
/* ------------------ */

static SmallVector<Value>
convertBf16x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
const char *ptxAsm, const Value &v0, const Value &v1,
const Value &v2, const Value &v3) {
auto bf16x2VecTy = vec_ty(i16_ty, 2);
Value bf16x2Vec0 = undef(bf16x2VecTy);
Value bf16x2Vec1 = undef(bf16x2VecTy);
@@ -318,6 +418,26 @@ struct FpToFpOpConversion

#else
PTXBuilder builder;
auto &ptxOp = *builder.create(ptxAsm);

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
ptxOp({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
}

static SmallVector<Value>
convertBf16x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto *ptxAsm = "{ \n"
".reg .u32 sign, sign<2>, nosign, nosign<2>; \n"
".reg .u32 fp8_min, fp8_max, rn_, zero; \n"
@@ -354,49 +474,49 @@ struct FpToFpOpConversion
"prmt.b32 nosign, nosign0, nosign1, 0x6420; \n"
"or.b32 $0, nosign, sign; \n"
"}";
auto &call = *builder.create(ptxAsm);
return convertBf16x4ToFp8x4(loc, rewriter, ptxAsm, v0, v1, v2, v3);
};

auto *o = builder.newOperand("=r");
auto *i0 = builder.newOperand(bf16x2Vec0, "r");
auto *i1 = builder.newOperand(bf16x2Vec1, "r");
call({o, i0, i1}, /*onlyAttachMLIRArgs=*/true);
// TODO:
// static SmallVector<Value>
// convertBf16x4ToFp8E5M2x4(Location loc, ConversionPatternRewriter &rewriter,
// const Value &v0, const Value &v1, const Value &v2,
// const Value &v3) {
// }

auto fp8x4VecTy = vec_ty(i8_ty, 4);
auto fp8x4Vec = builder.launch(rewriter, loc, fp8x4VecTy, false);
return {extract_element(i8_ty, fp8x4Vec, i32_val(0)),
extract_element(i8_ty, fp8x4Vec, i32_val(1)),
extract_element(i8_ty, fp8x4Vec, i32_val(2)),
extract_element(i8_ty, fp8x4Vec, i32_val(3))};
#endif
/* ------------------ */
// FP8 -> FP32
/* ------------------ */

static SmallVector<Value>
convertFp8E4M3x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E4M3x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {convertFp16ToFp32(loc, rewriter, fp16Values[0]),
convertFp16ToFp32(loc, rewriter, fp16Values[1]),
convertFp16ToFp32(loc, rewriter, fp16Values[2]),
convertFp16ToFp32(loc, rewriter, fp16Values[3])};
}

static SmallVector<Value>
convertFp8x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
convertFp8E5M2x4ToFp32x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E5M2x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[2]),
rewriter.create<LLVM::FPExtOp>(loc, f32_ty, fp16Values[3])};
}

static SmallVector<Value>
convertFp32x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
}
//

static SmallVector<Value>
convertFp8x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
convertFp8E4M3x4ToFp64x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto fp16Values = convertFp8E4M3x4ToFp16x4(loc, rewriter, v0, v1, v2, v3);
return {rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[0]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[1]),
rewriter.create<LLVM::FPExtOp>(loc, f64_ty, fp16Values[2]),
@@ -404,14 +524,14 @@ struct FpToFpOpConversion
}

static SmallVector<Value>
convertFp64x4ToFp8x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
convertFp64x4ToFp8E4M3x4(Location loc, ConversionPatternRewriter &rewriter,
const Value &v0, const Value &v1, const Value &v2,
const Value &v3) {
auto c0 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v0);
auto c1 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v1);
auto c2 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v2);
auto c3 = rewriter.create<LLVM::FPTruncOp>(loc, f16_ty, v3);
return convertFp16x4ToFp8x4(loc, rewriter, c0, c1, c2, c3);
return convertFp16x4ToFp8E4M3x4(loc, rewriter, c0, c1, c2, c3);
}

static Value convertBf16ToFp32(Location loc,
@@ -424,7 +544,7 @@ struct FpToFpOpConversion
return(bitcast(shifted, f32_ty));
#else
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f32.bf16");
auto &cvt = *builder.create("cvt.f32.bf16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
@@ -432,6 +552,17 @@ struct FpToFpOpConversion
#endif
}

static Value convertFp16ToFp32(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.f32.f16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(v, "h");
cvt(res, operand);
return builder.launch(rewriter, loc, f32_ty, false);
}

static Value convertFp32ToBf16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
@@ -452,6 +583,17 @@ struct FpToFpOpConversion
#endif
}

static Value convertFp32ToFp16(Location loc,
ConversionPatternRewriter &rewriter,
const Value &v) {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.rn.f16.f32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(v, "r");
cvt(res, operand);
return builder.launch(rewriter, loc, f16_ty, false);
}
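
A scalar view of the bf16 <-> f32 paths above: bf16 is just the top 16 bits of an f32, so widening is a 16-bit shift (the shifted/bitcast in the ROCm branch), while the PTX narrowing path uses cvt.rn for round-to-nearest instead of plain truncation:

#include <cstdint>

static inline uint32_t bf16_to_f32_bits(uint16_t b) {
  return static_cast<uint32_t>(b) << 16; // exact widening
}
static inline uint16_t f32_to_bf16_bits_truncate(uint32_t f) {
  return static_cast<uint16_t>(f >> 16); // truncation; cvt.rn also rounds
}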

LogicalResult
matchAndRewrite(triton::FpToFpOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
@@ -463,58 +605,68 @@ struct FpToFpOpConversion
auto loc = op->getLoc();
auto elems = getElemsPerThread(dstTensorType);
SmallVector<Value> resultVals;
bool isSrcFP8 =
srcEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();
bool isDstFP8 =
dstEltType.isa<mlir::Float8E4M3FNType, mlir::Float8E5M2Type>();

// Select convertor
if (srcEltType.isa<triton::Float8Type>() ||
dstEltType.isa<triton::Float8Type>()) {
std::function<SmallVector<Value>(Location, ConversionPatternRewriter &,
const Value &, const Value &,
const Value &, const Value &)>
convertor;
if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF16()) {
convertor = convertFp8x4ToFp16x4;
} else if (srcEltType.isF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isBF16()) {
convertor = convertFp8x4ToBf16x4;
} else if (srcEltType.isBF16() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertBf16x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF32()) {
convertor = convertFp8x4ToFp32x4;
} else if (srcEltType.isF32() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp32x4ToFp8x4;
} else if (srcEltType.isa<triton::Float8Type>() && dstEltType.isF64()) {
convertor = convertFp8x4ToFp64x4;
} else if (srcEltType.isF64() && dstEltType.isa<triton::Float8Type>()) {
convertor = convertFp64x4ToFp8x4;
} else {
assert(false && "unsupported fp8 casting");
}
typedef std::function<SmallVector<Value>(
Location, ConversionPatternRewriter &, const Value &, const Value &,
const Value &, const Value &)>
ConvertorT;

// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getElementsFromStruct(loc, adaptor.getFrom(), rewriter);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}
} else if (srcEltType.isBF16() && dstEltType.isF32()) {
resultVals.emplace_back(
convertBf16ToFp32(loc, rewriter, adaptor.getFrom()));
} else if (srcEltType.isF32() && dstEltType.isBF16()) {
resultVals.emplace_back(
convertFp32ToBf16(loc, rewriter, adaptor.getFrom()));
} else {
assert(false && "unsupported type casting");
auto F8E4M3TyID = TypeID::get<mlir::Float8E4M3FNType>();
auto F8E5M2TyID = TypeID::get<mlir::Float8E5M2Type>();
auto F16TyID = TypeID::get<mlir::Float16Type>();
auto BF16TyID = TypeID::get<mlir::BFloat16Type>();
auto F32TyID = TypeID::get<mlir::Float32Type>();
auto F64TyID = TypeID::get<mlir::Float64Type>();
DenseMap<std::pair<TypeID, TypeID>, ConvertorT> convertorMap = {
// F8 -> F16
{{F8E4M3TyID, F16TyID}, convertFp8E4M3x4ToFp16x4},
{{F8E5M2TyID, F16TyID}, convertFp8E5M2x4ToFp16x4},
// F16 -> F8
{{F16TyID, F8E4M3TyID}, convertFp16x4ToFp8E4M3x4},
{{F16TyID, F8E5M2TyID}, convertFp16x4ToFp8E5M2x4},
// F8 -> BF16
{{F8E4M3TyID, BF16TyID}, convertFp8E4M3x4ToBf16x4},
{{F8E5M2TyID, BF16TyID}, convertFp8E5M2x4ToBf16x4},
// BF16 -> F8
{{BF16TyID, F8E4M3TyID}, convertBf16x4ToFp8E4M3x4},
// TODO:
// {{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4},
// F8 -> F32
{{F8E4M3TyID, F32TyID}, convertFp8E4M3x4ToFp32x4},
{{F8E5M2TyID, F32TyID}, convertFp8E5M2x4ToFp32x4},
// F32 -> F8
{{F32TyID, F8E4M3TyID}, convertFp32x4ToFp8E4M3x4},
{{F32TyID, F8E5M2TyID}, convertFp32x4ToFp8E5M2x4},
};

std::pair<TypeID, TypeID> key = {srcEltType.getTypeID(),
dstEltType.getTypeID()};
if (convertorMap.count(key) == 0) {
llvm::errs() << "Unsupported conversion from " << srcEltType << " to "
<< dstEltType << "\n";
llvm_unreachable("");
}
auto convertor = convertorMap.lookup(key);

// Vectorized casting
assert(elems % 4 == 0 &&
"FP8 casting only support tensors with 4-aligned sizes");
auto elements = getTypeConverter()->unpackLLElements(
loc, adaptor.getFrom(), rewriter, srcTensorType);
for (size_t i = 0; i < elems; i += 4) {
auto converted = convertor(loc, rewriter, elements[i], elements[i + 1],
elements[i + 2], elements[i + 3]);
resultVals.append(converted);
}

assert(resultVals.size() == elems);
auto convertedDstTensorType =
this->getTypeConverter()->convertType(dstTensorType);
auto result = getStructFromElements(loc, resultVals, rewriter,
convertedDstTensorType);
auto result = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
dstTensorType);
rewriter.replaceOp(op, result);
return success();
}
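
The table-driven dispatch makes the supported pairs explicit: each cast is one map entry plus one converter. Hypothetically, once the commented-out convertBf16x4ToFp8E5M2x4 is implemented, enabling it is a single extra entry:

// Hypothetical: fills in the TODO entry of the dispatch table.
convertorMap.insert({{BF16TyID, F8E5M2TyID}, convertBf16x4ToFp8E5M2x4});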
@@ -536,8 +688,8 @@ class ElementwiseOpConversionBase
public:
using OpAdaptor = typename SourceOp::Adaptor;

explicit ElementwiseOpConversionBase(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ElementwiseOpConversionBase(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

LogicalResult
@@ -553,7 +705,7 @@ public:
Type structTy = this->getTypeConverter()->convertType(resultTy);

auto *concreteThis = static_cast<const ConcreteT *>(this);
auto operands = getOperands(rewriter, adaptor, elems, loc);
auto operands = getOperands(rewriter, adaptor, resultTy, elems, loc);
SmallVector<Value> resultVals(elems);
for (unsigned i = 0; i < elems; ++i) {
resultVals[i] = concreteThis->createDestOp(op, adaptor, rewriter, elemTy,
@@ -561,7 +713,8 @@ public:
if (!bool(resultVals[i]))
return failure();
}
Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
Value view = this->getTypeConverter()->packLLElements(loc, resultVals,
rewriter, resultTy);
rewriter.replaceOp(op, view);

return success();
@@ -570,10 +723,11 @@ public:
protected:
SmallVector<SmallVector<Value>>
getOperands(ConversionPatternRewriter &rewriter, OpAdaptor adaptor,
const unsigned elems, Location loc) const {
Type operandTy, const unsigned elems, Location loc) const {
SmallVector<SmallVector<Value>> operands(elems);
for (auto operand : adaptor.getOperands()) {
auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
auto sub_operands = this->getTypeConverter()->unpackLLElements(
loc, operand, rewriter, operandTy);
for (size_t i = 0; i < elems; ++i) {
operands[i].push_back(sub_operands[i]);
}
@@ -994,15 +1148,14 @@ struct ExpOpConversionApprox
}
};

void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
RewritePatternSet &patterns,
int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation,
Value smem, PatternBenefit benefit) {
void populateElementwiseOpToLLVMPatterns(
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem, PatternBenefit benefit) {
#define POPULATE_TERNARY_OP(SRC_OP, DST_OP) \
patterns.add<ElementwiseOpConversion<SRC_OP, DST_OP>>(typeConverter, benefit);
POPULATE_TERNARY_OP(triton::gpu::SelectOp, LLVM::SelectOp)
POPULATE_TERNARY_OP(arith::SelectOp, LLVM::SelectOp)
#undef POPULATE_TERNARY_OP

#define POPULATE_BINARY_OP(SRC_OP, DST_OP) \
@@ -1056,9 +1209,160 @@ void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
patterns.add<FpToFpOpConversion>(typeConverter, benefit);

patterns.add<ExtElemwiseOpConversion>(typeConverter, benefit);
// ExpOpConversionApprox will try using ex2.approx if the input type is FP32.
// For FP64 input type, ExpOpConversionApprox will return failure and
// ExpOpConversionApprox will try using ex2.approx if the input type is
// FP32. For FP64 input type, ExpOpConversionApprox will return failure and
// ElementwiseOpConversion<math::ExpOp, math::ExpOp> defined below will call
// __nv_expf for higher-precision calculation
patterns.add<ExpOpConversionApprox>(typeConverter, benefit);
}

struct FPExtOpConversion
: ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::FPExtOp, FPExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static bool isLegalOp(LLVM::FPExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF32() && srcTy.isF16()) {
return false;
}
return true;
}

Value createDestOp(LLVM::FPExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return FpToFpOpConversion::convertFp16ToFp32(loc, rewriter, operands[0]);
}
};

struct FPTruncOpConversion
: ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion> {
using Base =
ElementwiseOpConversionBase<LLVM::FPTruncOp, FPTruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static bool isLegalOp(LLVM::FPTruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isF16() && srcTy.isF32()) {
return false;
}
return true;
}

Value createDestOp(LLVM::FPTruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
return FpToFpOpConversion::convertFp32ToFp16(loc, rewriter, operands[0]);
}
};

struct TruncOpConversion
: ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::TruncOp, TruncOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static bool isLegalOp(LLVM::TruncOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(16) && srcTy.isInteger(32)) {
return false;
}
return true;
}

Value createDestOp(LLVM::TruncOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u16.u32");
auto res = builder.newOperand("=h");
auto operand = builder.newOperand(operands[0], "r");
cvt(res, operand);
return builder.launch(rewriter, loc, i16_ty, false);
}
};

struct SExtOpConversion
: ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::SExtOp, SExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static bool isLegalOp(LLVM::SExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}

Value createDestOp(LLVM::SExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.s32.s16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0], "h");
cvt(res, operand);
return builder.launch(rewriter, loc, i32_ty, false);
}
};

struct ZExtOpConversion
: ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion> {
using Base = ElementwiseOpConversionBase<LLVM::ZExtOp, ZExtOpConversion>;
using Base::Base;
using Adaptor = typename Base::OpAdaptor;

static bool isLegalOp(LLVM::ZExtOp op) {
auto retTy = op.getResult().getType();
auto srcTy = op.getOperand().getType();
if (retTy.isInteger(32) && srcTy.isInteger(16)) {
return false;
}
return true;
}

Value createDestOp(LLVM::ZExtOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter, Type elemTy,
ValueRange operands, Location loc) const {
PTXBuilder builder;
auto &cvt = *builder.create("cvt.u32.u16");
auto res = builder.newOperand("=r");
auto operand = builder.newOperand(operands[0], "h");
cvt(res, operand);
return builder.launch(rewriter, loc, i32_ty, false);
}
};

bool isLegalElementwiseOp(Operation *op) {
if (isa<LLVM::FPExtOp>(op)) {
return FPExtOpConversion::isLegalOp(cast<LLVM::FPExtOp>(op));
} else if (isa<LLVM::FPTruncOp>(op)) {
return FPTruncOpConversion::isLegalOp(cast<LLVM::FPTruncOp>(op));
} else if (isa<LLVM::TruncOp>(op)) {
return TruncOpConversion::isLegalOp(cast<LLVM::TruncOp>(op));
} else if (isa<LLVM::SExtOp>(op)) {
return SExtOpConversion::isLegalOp(cast<LLVM::SExtOp>(op));
} else if (isa<LLVM::ZExtOp>(op)) {
return ZExtOpConversion::isLegalOp(cast<LLVM::ZExtOp>(op));
}
return true;
}
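
isLegalElementwiseOp is the hook that makes the PTX-specialized patterns above fire: the listed LLVM ops are reported illegal exactly for the shapes those patterns handle, so the dialect-conversion driver has to rewrite them. A sketch of the intended wiring, assuming a ConversionTarget named target (the target setup is not part of this excerpt):

// Sketch: mark the narrowing/widening ops dynamically illegal so the
// PTX-based conversions registered above are applied to them.
target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
    [](Operation *op) { return isLegalElementwiseOp(op); });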
|
||||
|
||||
void populateElementwiseOpToPTXPatterns(
|
||||
TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
|
||||
PatternBenefit benefit) {
|
||||
patterns.add<FPExtOpConversion>(typeConverter, benefit);
|
||||
patterns.add<FPTruncOpConversion>(typeConverter, benefit);
|
||||
patterns.add<TruncOpConversion>(typeConverter, benefit);
|
||||
patterns.add<SExtOpConversion>(typeConverter, benefit);
|
||||
patterns.add<ZExtOpConversion>(typeConverter, benefit);
|
||||
}
|
||||
|
||||
@@ -6,11 +6,15 @@
using namespace mlir;
using namespace mlir::triton;

-void populateElementwiseOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
-                                         RewritePatternSet &patterns,
-                                         int numWarps,
-                                         AxisInfoAnalysis &axisInfoAnalysis,
-                                         const Allocation *allocation,
-                                         Value smem, PatternBenefit benefit);
+void populateElementwiseOpToLLVMPatterns(
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
+    const Allocation *allocation, Value smem, PatternBenefit benefit);
+
+bool isLegalElementwiseOp(Operation *op);
+
+void populateElementwiseOpToPTXPatterns(
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    PatternBenefit benefit);

#endif

@@ -7,9 +7,7 @@
using namespace mlir;
using namespace mlir::triton;

-using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;

@@ -18,19 +16,6 @@ struct LoadStoreConversionBase
  explicit LoadStoreConversionBase(AxisInfoAnalysis &axisAnalysisPass)
      : axisAnalysisPass(axisAnalysisPass) {}

-  // Get corresponding LLVM element values of \param value.
-  static SmallVector<Value> getLLVMElems(Value value, Value llValue,
-                                         ConversionPatternRewriter &rewriter,
-                                         Location loc) {
-    if (!value)
-      return {};
-    if (!llValue.getType().isa<LLVM::LLVMStructType>())
-      return {llValue};
-    // Here, we assume that all inputs should have a blockedLayout
-    auto valueVals = getElementsFromStruct(loc, llValue, rewriter);
-    return valueVals;
-  }
-
  unsigned getContiguity(Value ptr) const {
    auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
    if (!tensorTy)

@@ -62,7 +47,7 @@ struct LoadOpConversion
  using ConvertTritonGPUOpToLLVMPattern<
      triton::LoadOp>::ConvertTritonGPUOpToLLVMPattern;

-  LoadOpConversion(LLVMTypeConverter &converter,
+  LoadOpConversion(TritonGPUToLLVMTypeConverter &converter,
                   AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

@@ -92,13 +77,15 @@ struct LoadOpConversion
      vec = std::min<size_t>(vec, getMaskAlignment(mask));

    // Get the LLVM values for pointers
-    auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
+    auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
+                                                         ptr.getType());
    assert(ptrElems.size() == numElems);

    // Get the LLVM values for mask
    SmallVector<Value> maskElems;
    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
+      maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
+                                                       mask.getType());
      assert(maskElems.size() == numElems);
    }

@@ -114,7 +101,11 @@ struct LoadOpConversion
      otherIsSplatConstInt = true;
      splatVal = constAttr.getSplatValue<APInt>().getSExtValue();
    }
-    auto otherElems = getLLVMElems(other, llOther, rewriter, loc);
+    SmallVector<Value> otherElems;
+    if (other) {
+      otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
+                                                        other.getType());
+    }

    // vectorized iteration through all the pointer/mask/other elements
    const int valueElemNbits =

@@ -283,8 +274,8 @@ struct LoadOpConversion
    } // end vec

    Type llvmResultStructTy = getTypeConverter()->convertType(valueTy);
-    Value resultStruct =
-        getStructFromElements(loc, loadedVals, rewriter, llvmResultStructTy);
+    Value resultStruct = getTypeConverter()->packLLElements(
+        loc, loadedVals, rewriter, llvmResultStructTy);
    rewriter.replaceOp(op, {resultStruct});
    return success();
  }
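Note: the recurring substitution in these hunks, replacing the free functions `getLLVMElems`/`getElementsFromStruct` and `getStructFromElements` with `unpackLLElements`/`packLLElements` on the type converter, centralizes how a tensor's per-thread values are split out of and repacked into the LLVM struct that carries them. A conceptual sketch of the packing side, assuming the `insert_val` helper used elsewhere in this patch (the real implementation lives on `TritonGPUToLLVMTypeConverter` and also handles the scalar, non-struct case):

    // Pack one SSA value per tensor element owned by this thread back into
    // an LLVM struct; unpacking is the inverse, one extract per field.
    Value packSketch(Location loc, ArrayRef<Value> elems, Type structTy,
                     ConversionPatternRewriter &rewriter) {
      Value s = rewriter.create<LLVM::UndefOp>(loc, structTy);
      for (const auto &v : llvm::enumerate(elems))
        s = insert_val(structTy, s, v.value(), v.index());
      return s;
    }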
@@ -296,7 +287,7 @@ struct StoreOpConversion
  using ConvertTritonGPUOpToLLVMPattern<
      triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;

-  StoreOpConversion(LLVMTypeConverter &converter,
+  StoreOpConversion(TritonGPUToLLVMTypeConverter &converter,
                    AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
        LoadStoreConversionBase(axisAnalysisPass) {}

@@ -305,7 +296,6 @@ struct StoreOpConversion
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value ptr = op.getPtr();
-    Value mask = op.getMask();
    Value value = op.getValue();

    Value llPtr = adaptor.getPtr();

@@ -320,28 +310,40 @@ struct StoreOpConversion
        typeConverter->convertType(getElementTypeOrSelf(valueTy));

    unsigned vec = getVectorSize(ptr);
-    unsigned numElems = getElemsPerThread(ptr.getType());
+    unsigned elemsPerThread = getElemsPerThread(ptr.getType());

-    auto ptrElems = getLLVMElems(ptr, llPtr, rewriter, loc);
-    auto valueElems = getLLVMElems(value, llValue, rewriter, loc);
+    auto ptrElems = getTypeConverter()->unpackLLElements(loc, llPtr, rewriter,
+                                                         ptr.getType());
+    auto valueElems = getTypeConverter()->unpackLLElements(
+        loc, llValue, rewriter, value.getType());
    assert(ptrElems.size() == valueElems.size());

    // Determine the vectorization size
    SmallVector<Value> maskElems;
    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
+      Value mask = op.getMask();
+      maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
+                                                       mask.getType());
      assert(valueElems.size() == maskElems.size());

      unsigned maskAlign = getMaskAlignment(mask);
      vec = std::min(vec, maskAlign);
    }

+    // numElements = 1 for scalar
+    auto tensorTy = valueTy.dyn_cast<RankedTensorType>();
+    auto numElems = tensorTy ? tensorTy.getNumElements() : 1;
+    Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    mask = and_(mask,
+                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

    const size_t dtsize =
        std::max<int>(1, valueElemTy.getIntOrFloatBitWidth() / 8);
    const size_t valueElemNbits = dtsize * 8;

-    const int numVecs = numElems / vec;
-    for (size_t vecStart = 0; vecStart < numElems; vecStart += vec) {
+    const int numVecs = elemsPerThread / vec;
+    for (size_t vecStart = 0; vecStart < elemsPerThread; vecStart += vec) {
      // TODO: optimization when ptr is AddPtr with constant offset
      size_t in_off = 0;
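Note: the new `mask` computed above guards threads whose slots lie entirely past the end of the tensor, so a scalar store no longer requires every thread to issue it. The predicate reduces to this scalar check (illustrative only):

    // Thread t owns elements [t * elemsPerThread, (t + 1) * elemsPerThread);
    // if even its first element is out of range, the whole thread is masked.
    bool threadInBounds(int tid, int elemsPerThread, int numElems) {
      return tid * elemsPerThread < numElems;
    }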
@@ -350,7 +352,7 @@ struct StoreOpConversion
        const size_t width = std::min(totalWidth, maxWordWidth);
        const size_t nWords = std::max<size_t>(1, totalWidth / width);
        const size_t wordNElems = width / valueElemNbits;
-        assert(wordNElems * nWords * numVecs == numElems);
+        assert(wordNElems * nWords * numVecs == elemsPerThread);

        // TODO(Superjomn) Add cache policy fields to StoreOp.
        // TODO(Superjomn) Deal with cache policy here.

@@ -368,13 +370,13 @@ struct StoreOpConversion
          assert(elemOffset < valueElems.size());
          Value elem = valueElems[elemOffset];
          if (elem.getType().isInteger(1))
-            elem = rewriter.create<LLVM::SExtOp>(loc, type::i8Ty(ctx), elem);
+            elem = sext(i8_ty, elem);
          elem = bitcast(elem, valueElemTy);
          llWord = insert_element(wordTy, llWord, elem, i32_val(elemIdx));
        }
        llWord = bitcast(llWord, valArgTy);
#ifdef USE_ROCM
-        Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
+        Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;
        rewriter.create<scf::IfOp>(loc, maskVal,
          [&](OpBuilder &builder, Location loc){
            auto storeOp = builder.create<LLVM::StoreOp>(loc, llWord, ptrElems[vecStart + wordIdx * wordNElems]);

@@ -393,7 +395,7 @@ struct StoreOpConversion
        PTXBuilder ptxBuilder;
        auto *asmArgList = ptxBuilder.newListOperand(asmArgs);

-        Value maskVal = llMask ? maskElems[vecStart] : int_val(1, 1);
+        Value maskVal = llMask ? and_(mask, maskElems[vecStart]) : mask;

        auto *asmAddr =
            ptxBuilder.newAddrOperand(ptrElems[vecStart], "l", in_off);

@@ -422,7 +424,7 @@ struct AtomicCASOpConversion
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;

-  AtomicCASOpConversion(LLVMTypeConverter &converter,
+  AtomicCASOpConversion(TritonGPUToLLVMTypeConverter &converter,
                        const Allocation *allocation, Value smem,
                        AxisInfoAnalysis &axisAnalysisPass,
                        PatternBenefit benefit)

@@ -442,13 +444,16 @@ struct AtomicCASOpConversion
    Value llCmp = adaptor.getCmp();
    Value llVal = adaptor.getVal();

-    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
-    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
-    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+    auto ptrElements = getTypeConverter()->unpackLLElements(
+        loc, llPtr, rewriter, op.getPtr().getType());
+    auto cmpElements = getTypeConverter()->unpackLLElements(
+        loc, llCmp, rewriter, op.getCmp().getType());
+    auto valElements = getTypeConverter()->unpackLLElements(
+        loc, llVal, rewriter, op.getVal().getType());

-    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
-        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+        TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
                 : op.getResult().getType();
    auto tid = tid_val();
    Value pred = icmp_eq(tid, i32_val(0));

@@ -508,14 +513,17 @@ struct AtomicCASOpConversion
    Value llCmp = adaptor.getCmp();
    Value llVal = adaptor.getVal();

-    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
-    auto cmpElements = getElementsFromStruct(loc, llCmp, rewriter);
-    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
+    auto ptrElements = getTypeConverter()->unpackLLElements(
+        loc, llPtr, rewriter, op.getPtr().getType());
+    auto cmpElements = getTypeConverter()->unpackLLElements(
+        loc, llCmp, rewriter, op.getCmp().getType());
+    auto valElements = getTypeConverter()->unpackLLElements(
+        loc, llVal, rewriter, op.getVal().getType());

-    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto TensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
-        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
-                : op.getResult().getType();
+        TensorTy ? getTypeConverter()->convertType(TensorTy.getElementType())
+                 : op.getResult().getType();
    auto tid = tid_val();
    Value pred = icmp_eq(tid, i32_val(0));
    PTXBuilder ptxBuilderMemfence;

@@ -565,7 +573,7 @@ struct AtomicRMWOpConversion
  using ConvertTritonGPUOpToLLVMPattern<
      triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;

-  AtomicRMWOpConversion(LLVMTypeConverter &converter,
+  AtomicRMWOpConversion(TritonGPUToLLVMTypeConverter &converter,
                        const Allocation *allocation, Value smem,
                        AxisInfoAnalysis &axisAnalysisPass,
                        PatternBenefit benefit)

@@ -617,31 +625,36 @@ struct AtomicRMWOpConversion
    Value llVal = adaptor.getVal();
    Value llMask = adaptor.getMask();

-    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
-    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
-    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
+    auto valElements = getTypeConverter()->unpackLLElements(
+        loc, llVal, rewriter, val.getType());
+    auto ptrElements = getTypeConverter()->unpackLLElements(
+        loc, llPtr, rewriter, ptr.getType());
+    SmallVector<Value> maskElements;
+    if (llMask)
+      maskElements = getTypeConverter()->unpackLLElements(
+          loc, llMask, rewriter, op.getMask().getType());

    Value opResult = op.getResult();
-    auto valueTy = opResult.getType().dyn_cast<RankedTensorType>();
+    auto tensorTy = opResult.getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
-        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
+        tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
                 : opResult.getType();
    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
    auto elemsPerThread = getElemsPerThread(val.getType());
-    // vec = 1 for scalar
+    // vec = 1, numElements = 1 for scalar
    auto vec = getVectorSize(ptr);
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
+    int numElems = 1;
    // tensor
-    if (valueTy) {
+    if (tensorTy) {
      auto valTy = val.getType().cast<RankedTensorType>();
      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-      // mask
-      auto shape = valueTy.getShape();
-      auto numElements = product(shape);
-      mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
-                                 i32_val(numElements)));
+      numElems = tensorTy.getNumElements();
    }
+    Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    mask = and_(mask,
+                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

    auto vecTy = vec_ty(valueElemTy, vec);
    SmallVector<Value> resultVals(elemsPerThread);

@@ -654,11 +667,7 @@ struct AtomicRMWOpConversion
      }

      Value rmwPtr = ptrElements[i];
-      Value rmwMask = maskElements[i];
-      rmwMask = and_(rmwMask, mask);
-      if (!valueTy) {
-        rmwMask = and_(rmwMask, icmp_eq(tid, i32_val(0)));
-      }
+      Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

      Value undefVal = undef(valueElemTy);
      // Build blocks to bypass the atomic instruction for ~rmwMask.
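Note: the per-element predicate now folds the bounds mask with the optional user mask instead of unconditionally indexing `maskElements`, which would be empty when the op carries no mask. In scalar terms (sketch only):

    // Predicate for atomic lane i, given the bounds mask computed above.
    bool rmwPredicate(bool boundsMask, bool hasUserMask, bool userMaskElem) {
      return hasUserMask ? (boundsMask && userMaskElem) : boundsMask;
    }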
@@ -683,11 +692,11 @@ struct AtomicRMWOpConversion

      rewriter.setInsertionPointToStart(endBlock);
      Value retVal = endBlock->getArgument(0);
-      if (valueTy) {
+      if (tensorTy) {
        for (int ii = 0; ii < vec; ++ii) {
          resultVals[i + ii] =
              vec == 1 ? retVal
-                       : extract_element(valueElemTy, retVal, idx_val(ii));
+                       : extract_element(valueElemTy, retVal, i32_val(ii));
        }
      } else {
        Value atomPtr = getSharedMemoryBase(loc, rewriter, op.getOperation());

@@ -697,10 +706,10 @@ struct AtomicRMWOpConversion
        rewriter.replaceOp(op, {ret});
      }
    }
-    if (valueTy) {
-      Type structTy = getTypeConverter()->convertType(valueTy);
-      Value resultStruct =
-          getStructFromElements(loc, resultVals, rewriter, structTy);
+    if (tensorTy) {
+      Type structTy = getTypeConverter()->convertType(tensorTy);
+      Value resultStruct = getTypeConverter()->packLLElements(
+          loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, {resultStruct});
    }
    return success();

@@ -715,37 +724,43 @@ struct AtomicRMWOpConversion
    MLIRContext *ctx = rewriter.getContext();

    auto atomicRmwAttr = op.getAtomicRmwOp();
-    Value ptr = op.getPtr();

    Value val = op.getVal();
+    Value ptr = op.getPtr();

    Value llPtr = adaptor.getPtr();
    Value llVal = adaptor.getVal();
    Value llMask = adaptor.getMask();

-    auto valElements = getElementsFromStruct(loc, llVal, rewriter);
-    auto ptrElements = getElementsFromStruct(loc, llPtr, rewriter);
-    auto maskElements = getElementsFromStruct(loc, llMask, rewriter);
+    auto valElements = getTypeConverter()->unpackLLElements(
+        loc, llVal, rewriter, val.getType());
+    auto ptrElements = getTypeConverter()->unpackLLElements(
+        loc, llPtr, rewriter, ptr.getType());
+    SmallVector<Value> maskElements;
+    if (llMask)
+      maskElements = getTypeConverter()->unpackLLElements(
+          loc, llMask, rewriter, op.getMask().getType());

-    auto valueTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto tensorTy = op.getResult().getType().dyn_cast<RankedTensorType>();
    Type valueElemTy =
-        valueTy ? getTypeConverter()->convertType(valueTy.getElementType())
-                : op.getResult().getType();
+        tensorTy ? getTypeConverter()->convertType(tensorTy.getElementType())
+                 : op.getResult().getType();
    const size_t valueElemNbits = valueElemTy.getIntOrFloatBitWidth();
    auto elemsPerThread = getElemsPerThread(val.getType());
-    // vec = 1 for scalar
+    // vec = 1, numElements = 1 for scalar
    auto vec = getVectorSize(ptr);
-    Value mask = int_val(1, 1);
-    auto tid = tid_val();
+    int numElems = 1;
    // tensor
-    if (valueTy) {
+    if (tensorTy) {
      auto valTy = val.getType().cast<RankedTensorType>();
      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
-      // mask
-      auto shape = valueTy.getShape();
-      auto numElements = product(shape);
-      mask = and_(mask, icmp_slt(mul(tid, i32_val(elemsPerThread)),
-                                 i32_val(numElements)));
+      numElems = tensorTy.getNumElements();
    }
+    Value mask = int_val(1, 1);
+    auto tid = tid_val();
+    mask = and_(mask,
+                icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));

    auto vecTy = vec_ty(valueElemTy, vec);
    SmallVector<Value> resultVals(elemsPerThread);

@@ -758,8 +773,7 @@ struct AtomicRMWOpConversion
      }

      Value rmwPtr = ptrElements[i];
-      Value rmwMask = maskElements[i];
-      rmwMask = and_(rmwMask, mask);
+      Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
      std::string sTy;
      PTXBuilder ptxBuilderAtomicRMW;
      std::string tyId = valueElemNbits * vec == 64

@@ -812,13 +826,13 @@ struct AtomicRMWOpConversion
        return failure();
      }
      atom.o(rmwOp).o(sTy);
-      if (valueTy) {
+      if (tensorTy) {
        atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
        auto retType = vec == 1 ? valueElemTy : vecTy;
        auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
        for (int ii = 0; ii < vec; ++ii) {
          resultVals[i + ii] =
-              vec == 1 ? ret : extract_element(valueElemTy, ret, idx_val(ii));
+              vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
        }
      } else {
        PTXBuilder ptxBuilderMemfence;

@@ -838,10 +852,10 @@ struct AtomicRMWOpConversion
        rewriter.replaceOp(op, {ret});
      }
    }
-    if (valueTy) {
-      Type structTy = getTypeConverter()->convertType(valueTy);
-      Value resultStruct =
-          getStructFromElements(loc, resultVals, rewriter, structTy);
+    if (tensorTy) {
+      Type structTy = getTypeConverter()->convertType(tensorTy);
+      Value resultStruct = getTypeConverter()->packLLElements(
+          loc, resultVals, rewriter, structTy);
      rewriter.replaceOp(op, {resultStruct});
    }
    return success();

@@ -904,7 +918,7 @@ struct InsertSliceOpConversion
    auto smemBase = gep(elemPtrTy, smemObj.base, offset);

    auto llSrc = adaptor.getSource();
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
    storeDistributedToShared(src, llSrc, srcStrides, srcIndices, dst, smemBase,
                             elemTy, loc, rewriter);
    // Barrier is not necessary.

@@ -922,7 +936,8 @@ struct InsertSliceAsyncOpConversion
      triton::gpu::InsertSliceAsyncOp>::ConvertTritonGPUOpToLLVMPattern;

  InsertSliceAsyncOpConversion(
-      LLVMTypeConverter &converter, const Allocation *allocation, Value smem,
+      TritonGPUToLLVMTypeConverter &converter, const Allocation *allocation,
+      Value smem,
      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
      AxisInfoAnalysis &axisAnalysisPass, PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::gpu::InsertSliceAsyncOp>(

@@ -958,7 +973,8 @@ struct InsertSliceAsyncOpConversion
    Value llIndex = adaptor.getIndex();

    // %src
-    auto srcElems = getLLVMElems(src, llSrc, rewriter, loc);
+    auto srcElems = getTypeConverter()->unpackLLElements(loc, llSrc, rewriter,
+                                                         src.getType());

    // %dst
    auto dstTy = dst.getType().cast<RankedTensorType>();

@@ -984,7 +1000,8 @@ struct InsertSliceAsyncOpConversion
    // %mask
    SmallVector<Value> maskElems;
    if (llMask) {
-      maskElems = getLLVMElems(mask, llMask, rewriter, loc);
+      maskElems = getTypeConverter()->unpackLLElements(loc, llMask, rewriter,
+                                                       mask.getType());
      assert(srcElems.size() == maskElems.size());
    }

@@ -995,7 +1012,8 @@ struct InsertSliceAsyncOpConversion
    // It's not necessary for now because the pipeline pass will skip
    // generating insert_slice_async if the load op has any "other" tensor.
    // assert(false && "insert_slice_async: Other value not supported yet");
-    otherElems = getLLVMElems(other, llOther, rewriter, loc);
+    otherElems = getTypeConverter()->unpackLLElements(loc, llOther, rewriter,
+                                                      other.getType());
    assert(srcElems.size() == otherElems.size());
  }

@@ -1025,7 +1043,7 @@ struct InsertSliceAsyncOpConversion
    // single vector read into multiple ones
    auto numVecCols = std::max<unsigned>(inVec / outVec, 1);

-    auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcShape);
+    auto srcIndices = emitIndices(loc, rewriter, srcBlockedLayout, srcTy);

    for (unsigned elemIdx = 0; elemIdx < numElems; elemIdx += minVec) {
      // 16 * 8 = 128bits

@@ -1075,7 +1093,7 @@ struct InsertSliceAsyncOpConversion
};

void populateLoadStoreOpToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;

void populateLoadStoreOpToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -3,8 +3,6 @@
using namespace mlir;
using namespace mlir::triton;

-using ::mlir::LLVM::getElementsFromStruct;
-using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::shflSync;
using ::mlir::LLVM::storeShared;
using ::mlir::triton::gpu::getElemsPerThread;

@@ -160,11 +158,12 @@ private:
    indexSmemBase = bitcast(indexSmemBase, indexPtrTy);

    unsigned srcElems = getElemsPerThread(srcTy);
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
-    auto srcValues = getElementsFromStruct(loc, adaptor.getOperand(), rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
+    auto srcValues = getTypeConverter()->unpackLLElements(
+        loc, adaptor.getOperand(), rewriter, srcTy);

    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForLayout(srcLayout, srcShape);
+        emitOffsetForLayout(srcLayout, srcTy);

    std::map<SmallVector<unsigned>, Value> accs;
    std::map<SmallVector<unsigned>, Value> accIndices;

@@ -247,8 +246,7 @@ private:
    auto resultShape = resultTy.getShape();

    unsigned resultElems = getElemsPerThread(resultTy);
-    auto resultIndices =
-        emitIndices(loc, rewriter, resultLayout, resultShape);
+    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
    assert(resultIndices.size() == resultElems);

    SmallVector<Value> resultVals(resultElems);

@@ -260,12 +258,8 @@ private:
        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
        resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
      }
-
-      SmallVector<Type> resultTypes(resultElems,
-                                    withIndex ? llvmIndexTy : llvmElemTy);
-      Type structTy =
-          LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
-      Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
+      Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
+                                                     resultTy);
      rewriter.replaceOp(op, ret);
    } else {
      // 0d-tensor -> scalar

@@ -310,11 +304,12 @@ private:
    unsigned sizeInterWarps = helper.getInterWarpSize();

    unsigned srcElems = getElemsPerThread(srcTy);
-    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcShape);
-    auto srcValues = getElementsFromStruct(loc, adaptor.getOperand(), rewriter);
+    auto srcIndices = emitIndices(loc, rewriter, srcLayout, srcTy);
+    auto srcValues = getTypeConverter()->unpackLLElements(
+        loc, adaptor.getOperand(), rewriter, srcTy);

    SmallVector<SmallVector<unsigned>> offset =
-        emitOffsetForLayout(srcLayout, srcShape);
+        emitOffsetForLayout(srcLayout, srcTy);

    std::map<SmallVector<unsigned>, Value> accs;
    std::map<SmallVector<unsigned>, Value> accIndices;

@@ -457,8 +452,7 @@ private:
    auto resultLayout = resultTy.getEncoding().cast<SliceEncodingAttr>();
    auto resultShape = resultTy.getShape();
    unsigned resultElems = getElemsPerThread(resultTy);
-    auto resultIndices =
-        emitIndices(loc, rewriter, resultLayout, resultShape);
+    auto resultIndices = emitIndices(loc, rewriter, resultLayout, resultTy);
    assert(resultIndices.size() == resultElems);

    SmallVector<Value> resultVals(resultElems);

@@ -471,12 +465,8 @@ private:
        Value indexReadPtr = gep(indexPtrTy, indexSmemBase, readOffset);
        resultVals[i] = withIndex ? load(indexReadPtr) : load(readPtr);
      }
-
-      SmallVector<Type> resultTypes(resultElems,
-                                    withIndex ? llvmIndexTy : llvmElemTy);
-      Type structTy =
-          LLVM::LLVMStructType::getLiteral(this->getContext(), resultTypes);
-      Value ret = getStructFromElements(loc, resultVals, rewriter, structTy);
+      Value ret = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
+                                                     resultTy);
      rewriter.replaceOp(op, ret);
    } else {
      // 0d-tensor -> scalar

@@ -489,7 +479,7 @@ private:
};

void populateReduceOpToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;

void populateReduceOpToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
@@ -1,13 +1,12 @@
#include "TritonGPUToLLVM.h"
#include "DotOpHelpers.h"
#include "Utility.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

using namespace mlir;
using namespace mlir::triton;

-using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
-using ::mlir::LLVM::getStructFromElements;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::SharedEncodingAttr;

@@ -66,9 +65,10 @@ struct BroadcastOpConversion

    assert(rank == resultTy.getRank());
    auto order = triton::gpu::getOrder(srcLayout);
-    auto srcOffsets = emitOffsetForLayout(srcLayout, srcShape);
-    auto resultOffsets = emitOffsetForLayout(resultLayout, resultShape);
-    SmallVector<Value> srcVals = getElementsFromStruct(loc, src, rewriter);
+    auto srcOffsets = emitOffsetForLayout(srcLayout, srcTy);
+    auto resultOffsets = emitOffsetForLayout(resultLayout, resultTy);
+    SmallVector<Value> srcVals =
+        getTypeConverter()->unpackLLElements(loc, src, rewriter, srcTy);

    DenseMap<SmallVector<unsigned>, Value, SmallVectorKeyInfo> srcValues;
    for (size_t i = 0; i < srcOffsets.size(); i++) {

@@ -84,27 +84,26 @@ struct BroadcastOpConversion
      resultVals.push_back(srcValues.lookup(offset));
    }

-    auto llvmStructTy = getTypeConverter()->convertType(resultTy);

    Value resultStruct =
-        getStructFromElements(loc, resultVals, rewriter, llvmStructTy);
+        getTypeConverter()->packLLElements(loc, resultVals, rewriter, resultTy);
    rewriter.replaceOp(op, {resultStruct});
    return success();
  }
};

-struct PrintfOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<triton::PrintfOp> {
+struct PrintOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::PrintOp> {
  using ConvertTritonGPUOpToLLVMPattern<
-      triton::PrintfOp>::ConvertTritonGPUOpToLLVMPattern;
+      triton::PrintOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
-  matchAndRewrite(triton::PrintfOp op, OpAdaptor adaptor,
+  matchAndRewrite(triton::PrintOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op->getLoc();
    SmallVector<Value, 16> operands;
-    for (auto operand : adaptor.getOperands()) {
-      auto sub_operands = getElementsFromStruct(loc, operand, rewriter);
+    for (size_t i = 0; i < op.getNumOperands(); i++) {
+      auto sub_operands = getTypeConverter()->unpackLLElements(
+          loc, adaptor.getOperands()[i], rewriter, op.getOperand(i).getType());
      for (auto elem : sub_operands) {
        operands.push_back(elem);
      }

@@ -170,22 +169,20 @@ struct PrintOpConversion
    auto type = value.getType();
    Value newOp = value;
    Type newType = type;
+    auto loc = UnknownLoc::get(context);

    bool bUnsigned = type.isUnsignedInteger();
    if (type.isIntOrIndex() && type.getIntOrFloatBitWidth() < 32) {
      if (bUnsigned) {
        newType = ui32_ty;
-        newOp = rewriter.create<LLVM::ZExtOp>(UnknownLoc::get(context), newType,
-                                              value);
+        newOp = zext(newType, value);
      } else {
        newType = i32_ty;
-        newOp = rewriter.create<LLVM::SExtOp>(UnknownLoc::get(context), newType,
-                                              value);
+        newOp = sext(newType, value);
      }
    } else if (type.isBF16() || type.isF16() || type.isF32()) {
      newType = f64_ty;
-      newOp = rewriter.create<LLVM::FPExtOp>(UnknownLoc::get(context), newType,
-                                             value);
+      newOp = fpext(newType, value);
    }

    return {newType, newOp};

@@ -193,51 +190,24 @@ struct PrintOpConversion

  static void llPrintf(StringRef msg, ValueRange args,
                       ConversionPatternRewriter &rewriter) {
-    static const char formatStringPrefix[] = "printfFormat_";
    assert(!msg.empty() && "printf with empty string not support");
    Type int8Ptr = ptr_ty(i8_ty);

-    auto *context = rewriter.getContext();
+    auto *ctx = rewriter.getContext();
    auto moduleOp =
        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
    auto funcOp = getVprintfDeclaration(rewriter);
+    auto loc = UnknownLoc::get(ctx);

-    Value one = rewriter.create<LLVM::ConstantOp>(
-        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(1));
-    Value zero = rewriter.create<LLVM::ConstantOp>(
-        UnknownLoc::get(context), i32_ty, rewriter.getI32IntegerAttr(0));
+    Value one = i32_val(1);
+    Value zero = i32_val(0);

-    unsigned stringNumber = 0;
-    SmallString<16> stringConstName;
-    do {
-      stringConstName.clear();
-      (formatStringPrefix + Twine(stringNumber++)).toStringRef(stringConstName);
-    } while (moduleOp.lookupSymbol(stringConstName));
-
-    llvm::SmallString<64> formatString(msg);
-    formatString.push_back('\n');
-    formatString.push_back('\0');
-    size_t formatStringSize = formatString.size_in_bytes();
-    auto globalType = LLVM::LLVMArrayType::get(i8_ty, formatStringSize);
-
-    LLVM::GlobalOp global;
-    {
-      ConversionPatternRewriter::InsertionGuard guard(rewriter);
-      rewriter.setInsertionPointToStart(moduleOp.getBody());
-      global = rewriter.create<LLVM::GlobalOp>(
-          UnknownLoc::get(context), globalType,
-          /*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
-          rewriter.getStringAttr(formatString));
-    }
-
-    Value globalPtr =
-        rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(context), global);
-    Value stringStart = rewriter.create<LLVM::GEPOp>(
-        UnknownLoc::get(context), int8Ptr, globalPtr,
-        SmallVector<Value>({zero, zero}));
-
-    Value bufferPtr =
-        rewriter.create<LLVM::NullOp>(UnknownLoc::get(context), int8Ptr);
+    llvm::SmallString<64> msgNewline(msg);
+    msgNewline.push_back('\n');
+    msgNewline.push_back('\0');
+    Value prefixString =
+        LLVM::addStringToModule(loc, rewriter, "printfFormat_", msgNewline);
+    Value bufferPtr = null(int8Ptr);

    SmallVector<Value, 16> newArgs;
    if (args.size() >= 1) {

@@ -250,27 +220,121 @@ struct PrintOpConversion
        newArgs.push_back(newArg);
      }

-      Type structTy = LLVM::LLVMStructType::getLiteral(context, argTypes);
-      auto allocated = rewriter.create<LLVM::AllocaOp>(UnknownLoc::get(context),
-                                                       ptr_ty(structTy), one,
-                                                       /*alignment=*/0);
+      Type structTy = LLVM::LLVMStructType::getLiteral(ctx, argTypes);
+      auto allocated =
+          rewriter.create<LLVM::AllocaOp>(loc, ptr_ty(structTy), one,
+                                          /*alignment=*/0);

      for (const auto &entry : llvm::enumerate(newArgs)) {
-        auto index = rewriter.create<LLVM::ConstantOp>(
-            UnknownLoc::get(context), i32_ty,
-            rewriter.getI32IntegerAttr(entry.index()));
-        auto fieldPtr = rewriter.create<LLVM::GEPOp>(
-            UnknownLoc::get(context), ptr_ty(argTypes[entry.index()]),
-            allocated, ArrayRef<Value>{zero, index});
-        rewriter.create<LLVM::StoreOp>(UnknownLoc::get(context), entry.value(),
-                                       fieldPtr);
+        auto index = i32_val(entry.index());
+        auto fieldPtr = gep(ptr_ty(argTypes[entry.index()]), allocated,
+                            ArrayRef<Value>{zero, index});
+        store(entry.value(), fieldPtr);
      }
-      bufferPtr = rewriter.create<LLVM::BitcastOp>(UnknownLoc::get(context),
-                                                   int8Ptr, allocated);
+      bufferPtr = bitcast(allocated, int8Ptr);
    }

-    SmallVector<Value> operands{stringStart, bufferPtr};
-    rewriter.create<LLVM::CallOp>(UnknownLoc::get(context), funcOp, operands);
+    SmallVector<Value> operands{prefixString, bufferPtr};
+    call(funcOp, operands);
  }
};

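Note: `llPrintf` packs every (promoted) argument into one stack struct and passes its address to the device-side `vprintf`. The declaration below is CUDA's real device contract; the packing demo is only a sketch of what the generated IR amounts to:

    extern "C" __device__ int vprintf(const char *format, void *valist);

    __device__ void demo(int i, double d) {
      // Mirrors the LLVM::AllocaOp plus per-field store sequence above.
      struct { int i; double d; } buf{i, d};
      vprintf("i=%d d=%f\n", &buf);
    }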
+struct AssertOpConversion
+    : public ConvertTritonGPUOpToLLVMPattern<triton::AssertOp> {
+  using ConvertTritonGPUOpToLLVMPattern<
+      triton::AssertOp>::ConvertTritonGPUOpToLLVMPattern;
+
+  LogicalResult
+  matchAndRewrite(triton::AssertOp op, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto ctx = rewriter.getContext();
+    auto elems = getTypeConverter()->unpackLLElements(
+        loc, adaptor.getCondition(), rewriter, op.getCondition().getType());
+    auto elemTy = elems[0].getType();
+    Value condition = int_val(elemTy.getIntOrFloatBitWidth(), 0);
+    for (auto elem : elems) {
+      if (elemTy.isSignedInteger() || elemTy.isSignlessInteger()) {
+        condition =
+            or_(condition,
+                icmp_eq(elem, rewriter.create<LLVM::ConstantOp>(
+                                  loc, elemTy, rewriter.getZeroAttr(elemTy))));
+      } else {
+        assert(false && "Unsupported type for assert");
+        return failure();
+      }
+    }
+    llAssert(op, condition, adaptor.getMessage(), adaptor.getFile(),
+             adaptor.getFunc(), adaptor.getLine(), rewriter);
+    rewriter.eraseOp(op);
+    return success();
+  }
+
+  // op: the op at which the assert is inserted. Unlike printf, we need to
+  // know about the op to split the block.
+  static void llAssert(Operation *op, Value condition, StringRef message,
+                       StringRef file, StringRef func, int line,
+                       ConversionPatternRewriter &rewriter) {
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    auto ctx = rewriter.getContext();
+    auto loc = op->getLoc();
+
+    // #block1
+    // if (condition) {
+    //   #block2
+    //   __assertfail(message);
+    // }
+    // #block3
+    Block *prevBlock = op->getBlock();
+    Block *ifBlock = rewriter.splitBlock(prevBlock, op->getIterator());
+    rewriter.setInsertionPointToStart(ifBlock);
+
+    auto funcOp = getAssertfailDeclaration(rewriter);
+    auto moduleOp =
+        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+    Value messageString =
+        LLVM::addStringToModule(loc, rewriter, "assertMessage_", message);
+    Value fileString =
+        LLVM::addStringToModule(loc, rewriter, "assertFile_", file);
+    Value funcString =
+        LLVM::addStringToModule(loc, rewriter, "assertFunc_", func);
+    Value lineNumber = i32_val(line);
+    Value charSize = int_val(sizeof(size_t) * 8, sizeof(char));
+
+    SmallVector<Value> operands = {messageString, fileString, lineNumber,
+                                   funcString, charSize};
+    auto ret = call(funcOp, operands);
+
+    // Split a block after the call.
+    Block *thenBlock = rewriter.splitBlock(ifBlock, op->getIterator());
+    rewriter.setInsertionPointToEnd(ifBlock);
+    rewriter.create<cf::BranchOp>(loc, thenBlock);
+    rewriter.setInsertionPointToEnd(prevBlock);
+    rewriter.create<cf::CondBranchOp>(loc, condition, ifBlock, thenBlock);
+  }
+
+  static LLVM::LLVMFuncOp
+  getAssertfailDeclaration(ConversionPatternRewriter &rewriter) {
+    auto moduleOp =
+        rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
+    StringRef funcName("__assertfail");
+    Operation *funcOp = moduleOp.lookupSymbol(funcName);
+    if (funcOp)
+      return cast<LLVM::LLVMFuncOp>(*funcOp);
+
+    // void __assert_fail(const char * assertion, const char * file, unsigned
+    // int line, const char * function);
+    auto *ctx = rewriter.getContext();
+    SmallVector<Type> argsType{ptr_ty(i8_ty), ptr_ty(i8_ty), i32_ty,
+                               ptr_ty(i8_ty),
+                               rewriter.getIntegerType(sizeof(size_t) * 8)};
+    auto funcType = LLVM::LLVMFunctionType::get(void_ty(ctx), argsType);
+
+    ConversionPatternRewriter::InsertionGuard guard(rewriter);
+    rewriter.setInsertionPointToStart(moduleOp.getBody());
+
+    return rewriter.create<LLVM::LLVMFuncOp>(UnknownLoc::get(ctx), funcName,
+                                             funcType);
+  }
+};
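Note: `llAssert` splits the current block so the `__assertfail` call sits behind a conditional branch; `condition` is true when any element of the assert operand equals zero. The materialized control flow is equivalent to this sketch (names as in the code above):

    // #block1 ends with: cond_br condition, #block2, #block3
    if (condition) {                                       // #block2
      __assertfail(message, file, line, func, /*charSize=*/1);
    }
    // #block3: execution resumes here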
@@ -278,7 +342,7 @@ struct MakeRangeOpConversion
    : public ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp> {

  MakeRangeOpConversion(
-      LLVMTypeConverter &converter,
+      TritonGPUToLLVMTypeConverter &converter,
      ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,
      PatternBenefit benefit)
      : ConvertTritonGPUOpToLLVMPattern<triton::MakeRangeOp>(

@@ -289,14 +353,14 @@ struct MakeRangeOpConversion
  matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
-    auto rankedTy = op.getResult().getType().dyn_cast<RankedTensorType>();
+    auto rankedTy = op.getResult().getType().cast<RankedTensorType>();
    auto shape = rankedTy.getShape();
    auto layout = rankedTy.getEncoding();

    auto elemTy = rankedTy.getElementType();
    assert(elemTy.isInteger(32));
    Value start = createIndexAttrConstant(rewriter, loc, elemTy, op.getStart());
-    auto idxs = emitIndices(loc, rewriter, layout, shape);
+    auto idxs = emitIndices(loc, rewriter, layout, rankedTy);
    unsigned elems = idxs.size();
    SmallVector<Value> retVals(elems);
    // TODO: slice layout has more elements than expected.

@@ -306,9 +370,8 @@ struct MakeRangeOpConversion
      assert(multiDim.value().size() == 1);
      retVals[multiDim.index()] = add(multiDim.value()[0], start);
    }
-    SmallVector<Type> types(elems, elemTy);
-    Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-    Value result = getStructFromElements(loc, retVals, rewriter, structTy);
+    Value result =
+        getTypeConverter()->packLLElements(loc, retVals, rewriter, rankedTy);
    rewriter.replaceOp(op, result);
    return success();
  }

@@ -325,11 +388,9 @@ struct GetProgramIdOpConversion
    Location loc = op->getLoc();
    assert(op.getAxis() < 3);

-    Value blockId = rewriter.create<::mlir::gpu::BlockIdOp>(
-        loc, rewriter.getIndexType(), dims[op.getAxis()]);
-    auto llvmIndexTy = getTypeConverter()->getIndexType();
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TypeRange{llvmIndexTy}, ValueRange{blockId});
+    Value blockId =
+        rewriter.create<::mlir::gpu::BlockIdOp>(loc, dims[op.getAxis()]);
+    rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);
    return success();
  }

@@ -349,11 +410,10 @@ struct GetNumProgramsOpConversion
    Location loc = op->getLoc();
    assert(op.getAxis() < 3);

-    Value blockId = rewriter.create<::mlir::gpu::GridDimOp>(
-        loc, rewriter.getIndexType(), dims[op.getAxis()]);
-    auto llvmIndexTy = getTypeConverter()->getIndexType();
-    rewriter.replaceOpWithNewOp<UnrealizedConversionCastOp>(
-        op, TypeRange{llvmIndexTy}, ValueRange{blockId});
+    Value blockId =
+        rewriter.create<::mlir::gpu::GridDimOp>(loc, dims[op.getAxis()]);
+    rewriter.replaceOpWithNewOp<arith::TruncIOp>(op, i32_ty, blockId);

    return success();
  }

@@ -372,20 +432,23 @@ struct AddPtrOpConversion
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    auto resultTy = op.getType();
+    auto offsetTy = op.getOffset().getType();
+    auto ptrTy = op.getPtr().getType();
    auto resultTensorTy = resultTy.dyn_cast<RankedTensorType>();
    if (resultTensorTy) {
      unsigned elems = getElemsPerThread(resultTy);
      Type elemTy =
          getTypeConverter()->convertType(resultTensorTy.getElementType());
-      SmallVector<Type> types(elems, elemTy);
-      Type structTy = LLVM::LLVMStructType::getLiteral(getContext(), types);
-      auto ptrs = getElementsFromStruct(loc, adaptor.getPtr(), rewriter);
-      auto offsets = getElementsFromStruct(loc, adaptor.getOffset(), rewriter);
+      auto ptrs = getTypeConverter()->unpackLLElements(loc, adaptor.getPtr(),
+                                                       rewriter, ptrTy);
+      auto offsets = getTypeConverter()->unpackLLElements(
+          loc, adaptor.getOffset(), rewriter, offsetTy);
      SmallVector<Value> resultVals(elems);
      for (unsigned i = 0; i < elems; ++i) {
        resultVals[i] = gep(elemTy, ptrs[i], offsets[i]);
      }
-      Value view = getStructFromElements(loc, resultVals, rewriter, structTy);
+      Value view = getTypeConverter()->packLLElements(loc, resultVals, rewriter,
+                                                      resultTy);
      rewriter.replaceOp(op, view);
    } else {
      assert(resultTy.isa<triton::PointerType>());

@@ -431,12 +494,12 @@ struct AllocTensorOpConversion
};

struct ExtractSliceOpConversion
-    : public ConvertTritonGPUOpToLLVMPattern<tensor::ExtractSliceOp> {
+    : public ConvertTritonGPUOpToLLVMPattern<triton::gpu::ExtractSliceOp> {
  using ConvertTritonGPUOpToLLVMPattern<
-      tensor::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;
+      triton::gpu::ExtractSliceOp>::ConvertTritonGPUOpToLLVMPattern;

  LogicalResult
-  matchAndRewrite(tensor::ExtractSliceOp op, OpAdaptor adaptor,
+  matchAndRewrite(triton::gpu::ExtractSliceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // %dst = extract_slice %src[%offsets]
    Location loc = op->getLoc();

@@ -531,7 +594,7 @@ namespace LLVM {

void vprintf(StringRef msg, ValueRange args,
             ConversionPatternRewriter &rewriter) {
-  PrintfOpConversion::llPrintf(msg, args, rewriter);
+  PrintOpConversion::llPrintf(msg, args, rewriter);
}

void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,

@@ -550,7 +613,7 @@ void vprintf_array(Value thread, ArrayRef<Value> arr, std::string info,
} // namespace mlir

void populateTritonGPUToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -568,5 +631,6 @@ void populateTritonGPUToLLVMPatterns(
  patterns.add<GetNumProgramsOpConversion>(typeConverter, benefit);
  patterns.add<MakeRangeOpConversion>(typeConverter, indexCacheInfo, benefit);
  patterns.add<ReturnOpConversion>(typeConverter, benefit);
-  patterns.add<PrintfOpConversion>(typeConverter, benefit);
+  patterns.add<PrintOpConversion>(typeConverter, benefit);
+  patterns.add<AssertOpConversion>(typeConverter, benefit);
}

@@ -7,7 +7,7 @@ using namespace mlir;
using namespace mlir::triton;

void populateTritonGPUToLLVMPatterns(
-    mlir::LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
+    TritonGPUToLLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
    int numWarps, AxisInfoAnalysis &axisInfoAnalysis,
    const Allocation *allocation, Value smem,
    ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo &indexCacheInfo,

@@ -7,6 +7,7 @@
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "triton/Analysis/Allocation.h"

+#include "TypeConverter.h"
//
#include "DotOpHelpers.h"
#include "Utility.h"

@@ -18,6 +19,7 @@ using namespace mlir::triton;

using ::mlir::LLVM::SharedMemoryObject;
using ::mlir::triton::gpu::BlockedEncodingAttr;
+using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

@@ -141,25 +143,26 @@ protected:
  }
};

-using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
+using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;

struct CacheKeyDenseMapInfo {
  static IndexCacheKeyT getEmptyKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getEmptyKey();
    return std::make_pair(
        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
-        SmallVector<int64_t>{});
+        RankedTensorType{});
  }
  static IndexCacheKeyT getTombstoneKey() {
    auto *pointer = llvm::DenseMapInfo<void *>::getTombstoneKey();
+    auto tombstone = llvm::DenseMapInfo<RankedTensorType>::getTombstoneKey();
    return std::make_pair(
        mlir::Attribute(static_cast<mlir::Attribute::ImplType *>(pointer)),
-        SmallVector<int64_t>{std::numeric_limits<int64_t>::max()});
+        tombstone);
  }
  static unsigned getHashValue(IndexCacheKeyT key) {
-    return llvm::hash_combine(
-        mlir::hash_value(key.first),
-        llvm::hash_combine_range(key.second.begin(), key.second.end()));
+    auto shape = key.second.getShape();
+    return llvm::hash_combine(mlir::hash_value(key.first),
+                              mlir::hash_value(key.second));
  }
  static bool isEqual(IndexCacheKeyT LHS, IndexCacheKeyT RHS) {
    return LHS == RHS;
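Note: switching the cache key's second member to `RankedTensorType` means the `DenseMap` sentinels must come from `DenseMapInfo<RankedTensorType>` rather than hand-rolled shape vectors. Usage is unchanged; a sketch (the `computeIndices` helper is hypothetical):

    DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
             CacheKeyDenseMapInfo> indexCache;
    IndexCacheKeyT key(layout, tensorTy);  // some layout/type in scope
    auto it = indexCache.find(key);
    if (it == indexCache.end())
      indexCache[key] = computeIndices();  // hypothetical helper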
@@ -178,22 +181,22 @@ public:
    OpBuilder::InsertPoint *indexInsertPoint;
  };

-  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter)
+  explicit ConvertTritonGPUOpToLLVMPatternBase(
+      TritonGPUToLLVMTypeConverter &typeConverter)
      : converter(&typeConverter) {}

-  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
-                                               const Allocation *allocation,
-                                               Value smem)
+  explicit ConvertTritonGPUOpToLLVMPatternBase(
+      TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
+      Value smem)
      : converter(&typeConverter), allocation(allocation), smem(smem) {}

-  explicit ConvertTritonGPUOpToLLVMPatternBase(LLVMTypeConverter &typeConverter,
-                                               const Allocation *allocation,
-                                               Value smem,
-                                               IndexCacheInfo indexCacheInfo)
+  explicit ConvertTritonGPUOpToLLVMPatternBase(
+      TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
+      Value smem, IndexCacheInfo indexCacheInfo)
      : converter(&typeConverter), allocation(allocation), smem(smem),
        indexCacheInfo(indexCacheInfo) {}

-  LLVMTypeConverter *getTypeConverter() const { return converter; }
+  TritonGPUToLLVMTypeConverter *getTypeConverter() const { return converter; }

  static Value
  getStructFromSharedMemoryObject(Location loc,

@@ -203,18 +206,20 @@ public:
    auto types = smemObj.getTypes();
    auto structTy =
        LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
-    return getStructFromElements(loc, elems, rewriter, structTy);
+    // pack into struct
+    Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
+    for (const auto &v : llvm::enumerate(elems)) {
+      assert(v.value() && "can not insert null values");
+      llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index());
+    }
+    return llvmStruct;
  }

  Value getThreadId(ConversionPatternRewriter &rewriter, Location loc) const {
-    auto llvmIndexTy = this->getTypeConverter()->getIndexType();
-    auto cast = rewriter.create<UnrealizedConversionCastOp>(
-        loc, TypeRange{llvmIndexTy},
-        ValueRange{rewriter.create<::mlir::gpu::ThreadIdOp>(
-            loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x)});
-    Value threadId = cast.getResult(0);
-
-    return threadId;
+    auto tid = rewriter.create<::mlir::gpu::ThreadIdOp>(
+        loc, rewriter.getIndexType(), ::mlir::gpu::Dimension::x);
+    return rewriter.create<arith::TruncIOp>(loc, i32_ty, tid);
  }

  // -----------------------------------------------------------------------

@@ -223,13 +228,12 @@ public:
  template <typename T>
  Value getSharedMemoryBase(Location loc, ConversionPatternRewriter &rewriter,
                            T value) const {
-
    auto ptrTy = LLVM::LLVMPointerType::get(
        this->getTypeConverter()->convertType(rewriter.getI8Type()), 3);
    auto bufferId = allocation->getBufferId(value);
    assert(bufferId != Allocation::InvalidBufferId && "BufferId not found");
    size_t offset = allocation->getOffset(bufferId);
-    Value offVal = idx_val(offset);
+    Value offVal = i32_val(offset);
    Value base = gep(ptrTy, smem, offVal);
    return base;
  }

@@ -244,8 +248,8 @@ public:
  // This utililty computes the pointers for accessing the provided swizzled
  // shared memory layout `resSharedLayout`. More specifically, it computes,
  // for all indices (row, col) of `srcEncoding` such that idx % inVec = 0,
-  // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] + colOff)
-  // where :
+  // the pointer: ptr[(row, col)] = base + (rowOff * strides[ord[1]] +
+  // colOff) where :
  //   compute phase = (row // perPhase) % maxPhase
  //   rowOff = row
  //   colOff = colOffSwizzled + colOffOrdered

@@ -255,8 +259,8 @@ public:
  // Note 1:
  // -------
  // Because swizzling happens at a granularity of outVec, we need to
-  // decompose the offset into a swizzled factor and a non-swizzled (ordered)
-  // factor
+  // decompose the offset into a swizzled factor and a non-swizzled
+  // (ordered) factor
  //
  // Note 2:
  // -------
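Note: the two comment hunks above only rewrap the swizzling description. Under the usual xor-swizzle, that description expands to the following scalar computation (a sketch under that assumption, not code from this patch):

    unsigned swizzledColOff(unsigned row, unsigned col, unsigned perPhase,
                            unsigned maxPhase, unsigned outVec,
                            unsigned minVec) {
      unsigned phase = (row / perPhase) % maxPhase;
      unsigned colOffSwizzled = ((col / outVec) ^ phase) * outVec;
      unsigned colOffOrdered = (col % outVec) / minVec * minVec;
      return colOffSwizzled + colOffOrdered;
    }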
@@ -282,7 +286,7 @@ public:
    auto inOrder = triton::gpu::getOrder(srcEncoding);
    auto outOrder = triton::gpu::getOrder(resSharedLayout);
    // tensor indices held by the current thread, as LLVM values
-    auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcShape);
+    auto srcIndices = emitIndices(loc, rewriter, srcEncoding, srcTy);
    // return values
    DenseMap<unsigned, Value> ret;
    // cache for non-immediate offsets

@@ -376,7 +380,8 @@ public:
    unsigned maxPhase = dstSharedLayout.getMaxPhase();
    unsigned numElems = triton::gpu::getElemsPerThread(srcTy);
    assert(numElems == srcIndices.size());
-    auto inVals = LLVM::getElementsFromStruct(loc, llSrc, rewriter);
+    auto inVals =
+        getTypeConverter()->unpackLLElements(loc, llSrc, rewriter, srcTy);
    auto wordTy = vec_ty(elemTy, minVec);
    auto elemPtrTy = ptr_ty(elemTy);
    Value outVecVal = i32_val(outVec);

@@ -435,7 +440,7 @@ public:
    } else {
      Value remained = linear;
      for (auto &&en : llvm::enumerate(shape.drop_back())) {
-        Value dimSize = idx_val(en.value());
+        Value dimSize = i32_val(en.value());
        multiDim[en.index()] = urem(remained, dimSize);
        remained = udiv(remained, dimSize);
      }

@@ -454,12 +459,12 @@ public:
  Value linearize(ConversionPatternRewriter &rewriter, Location loc,
                  ArrayRef<Value> multiDim, ArrayRef<unsigned> shape) const {
    auto rank = multiDim.size();
-    Value linear = idx_val(0);
+    Value linear = i32_val(0);
    if (rank > 0) {
      linear = multiDim.back();
      for (auto [dim, dimShape] :
           llvm::reverse(llvm::zip(multiDim.drop_back(), shape.drop_back()))) {
-        Value dimSize = idx_val(dimShape);
+        Value dimSize = i32_val(dimShape);
        linear = add(mul(linear, dimSize), dim);
      }
    }

@@ -469,7 +474,7 @@ public:
  Value dot(ConversionPatternRewriter &rewriter, Location loc,
            ArrayRef<Value> offsets, ArrayRef<Value> strides) const {
    assert(offsets.size() == strides.size());
-    Value ret = idx_val(0);
+    Value ret = i32_val(0);
    for (auto [offset, stride] : llvm::zip(offsets, strides)) {
      ret = add(ret, mul(offset, stride));
    }
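Note: `linearize`/`dot` fold a multi-dimensional index into a flat offset; the `idx_val` to `i32_val` change only narrows the constants to i32. A plain-integer analogue of `linearize` (sketch, assuming rank >= 1):

    unsigned linearizeSketch(ArrayRef<unsigned> multiDim,
                             ArrayRef<unsigned> shape) {
      // Trailing index is most significant; earlier (index, extent) pairs
      // are folded in reverse, matching the IR built above.
      unsigned linear = multiDim.back();
      for (int i = (int)multiDim.size() - 2; i >= 0; --i)
        linear = linear * shape[i] + multiDim[i];
      return linear;
    }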
@@ -499,8 +504,8 @@ public:
|
||||
SmallVector<Value> emitBaseIndexForLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
IndexCacheKeyT key = std::make_pair(layout, llvm::to_vector(shape));
|
||||
RankedTensorType type) const {
|
||||
IndexCacheKeyT key = std::make_pair(layout, type);
|
||||
auto cache = indexCacheInfo.baseIndexCache;
|
||||
assert(cache && "baseIndexCache is nullptr");
|
||||
auto insertPt = indexCacheInfo.indexInsertPoint;
|
||||
@@ -512,12 +517,12 @@ public:
|
||||
SmallVector<Value> result;
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
result =
|
||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, shape);
|
||||
emitBaseIndexForBlockedLayout(loc, rewriter, blockedLayout, type);
|
||||
} else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, shape);
|
||||
result = emitBaseIndexForMmaLayoutV1(loc, rewriter, mmaLayout, type);
|
||||
if (mmaLayout.isAmpere())
|
||||
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, shape);
|
||||
result = emitBaseIndexForMmaLayoutV2(loc, rewriter, mmaLayout, type);
|
||||
} else {
|
||||
llvm_unreachable("unsupported emitBaseIndexForLayout");
|
||||
}
|
||||
@@ -528,14 +533,14 @@ public:
|
||||
}
|
||||
|
||||
SmallVector<SmallVector<unsigned>>
|
||||
emitOffsetForLayout(const Attribute &layout, ArrayRef<int64_t> shape) const {
|
||||
emitOffsetForLayout(const Attribute &layout, RankedTensorType type) const {
|
||||
if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>())
|
||||
return emitOffsetForBlockedLayout(blockedLayout, shape);
|
||||
return emitOffsetForBlockedLayout(blockedLayout, type);
|
||||
if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
if (mmaLayout.isVolta())
|
||||
return emitOffsetForMmaLayoutV1(mmaLayout, shape);
|
||||
return emitOffsetForMmaLayoutV1(mmaLayout, type);
|
||||
if (mmaLayout.isAmpere())
|
||||
return emitOffsetForMmaLayoutV2(mmaLayout, shape);
|
||||
return emitOffsetForMmaLayoutV2(mmaLayout, type);
|
||||
}
|
||||
llvm_unreachable("unsupported emitOffsetForLayout");
|
||||
}
|
||||
@@ -546,8 +551,8 @@ public:
|
||||
SmallVector<SmallVector<Value>> emitIndices(Location loc,
|
||||
ConversionPatternRewriter &b,
|
||||
const Attribute &layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
IndexCacheKeyT key(layout, llvm::to_vector(shape));
|
||||
RankedTensorType type) const {
|
||||
IndexCacheKeyT key(layout, type);
|
||||
auto cache = indexCacheInfo.indexCache;
|
||||
assert(cache && "indexCache is nullptr");
|
||||
auto insertPt = indexCacheInfo.indexInsertPoint;
|
||||
@@ -558,11 +563,11 @@ public:
|
||||
restoreInsertionPointIfSet(insertPt, b);
|
||||
SmallVector<SmallVector<Value>> result;
|
||||
if (auto blocked = layout.dyn_cast<BlockedEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, blocked, shape);
|
||||
result = emitIndicesForDistributedLayout(loc, b, blocked, type);
|
||||
} else if (auto mma = layout.dyn_cast<MmaEncodingAttr>()) {
|
||||
result = emitIndicesForDistributedLayout(loc, b, mma, shape);
|
||||
result = emitIndicesForDistributedLayout(loc, b, mma, type);
|
||||
} else if (auto slice = layout.dyn_cast<SliceEncodingAttr>()) {
|
||||
result = emitIndicesForSliceLayout(loc, b, slice, shape);
|
||||
result = emitIndicesForSliceLayout(loc, b, slice, type);
|
||||
} else {
|
||||
llvm_unreachable(
|
||||
"emitIndices for layouts other than blocked & slice not "
|
||||
@@ -591,16 +596,15 @@ private:
|
||||
// -----------------------------------------------------------------------
|
||||
|
||||
// Get an index-base for each dimension for a \param blocked_layout.
|
||||
SmallVector<Value>
|
||||
emitBaseIndexForBlockedLayout(Location loc,
|
||||
ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout,
|
||||
ArrayRef<int64_t> shape) const {
|
||||
SmallVector<Value> emitBaseIndexForBlockedLayout(
|
||||
Location loc, ConversionPatternRewriter &rewriter,
|
||||
const BlockedEncodingAttr &blocked_layout, RankedTensorType type) const {
|
||||
auto shape = type.getShape();
|
||||
Value threadId = getThreadId(rewriter, loc);
|
||||
#ifdef USE_ROCM
|
||||
Value warpSize = idx_val(64);
|
||||
Value warpSize = i32_val(64);
|
||||
#else
|
||||
Value warpSize = idx_val(32);
|
||||
Value warpSize = i32_val(32);
|
||||
#endif
|
||||
Value laneId = urem(threadId, warpSize);
|
||||
Value warpId = udiv(threadId, warpSize);
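
The laneId/warpId split above is plain integer division on the hardware thread id; only the divisor differs between backends (a 64-wide wavefront on ROCm, a 32-wide warp on NVIDIA, matching the #ifdef USE_ROCM branches). A small illustrative sketch, not part of the patch:

#include <cstdio>
#include <initializer_list>

int main() {
  const unsigned warpSize = 64; // ROCm wavefront; use 32 for NVIDIA warps
  for (unsigned threadId : {0u, 1u, 63u, 64u, 130u}) {
    unsigned laneId = threadId % warpSize; // urem(threadId, warpSize)
    unsigned warpId = threadId / warpSize; // udiv(threadId, warpSize)
    std::printf("thread %3u -> warp %u, lane %2u\n", threadId, warpId, laneId);
  }
  return 0;
}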
@@ -623,13 +627,13 @@ private:
auto maxWarps =
ceil<unsigned>(shape[k], sizePerThread[k] * threadsPerWarp[k]);
auto maxThreads = ceil<unsigned>(shape[k], sizePerThread[k]);
multiDimWarpId[k] = urem(multiDimWarpId[k], idx_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], idx_val(maxThreads));
multiDimWarpId[k] = urem(multiDimWarpId[k], i32_val(maxWarps));
multiDimThreadId[k] = urem(multiDimThreadId[k], i32_val(maxThreads));
// multiDimBase[k] = (multiDimThreadId[k] +
// multiDimWarpId[k] * threadsPerWarp[k]) *
// sizePerThread[k];
Value threadsPerWarpK = idx_val(threadsPerWarp[k]);
Value sizePerThreadK = idx_val(sizePerThread[k]);
Value threadsPerWarpK = i32_val(threadsPerWarp[k]);
Value sizePerThreadK = i32_val(sizePerThread[k]);
multiDimBase[k] =
mul(sizePerThreadK, add(multiDimThreadId[k],
mul(multiDimWarpId[k], threadsPerWarpK)));
@@ -639,7 +643,8 @@ private:

SmallVector<SmallVector<unsigned>>
emitOffsetForBlockedLayout(const BlockedEncodingAttr &blockedLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto sizePerThread = blockedLayout.getSizePerThread();
auto threadsPerWarp = blockedLayout.getThreadsPerWarp();
auto warpsPerCTA = blockedLayout.getWarpsPerCTA();
@@ -668,7 +673,7 @@ private:
threadOffset * sizePerThread[k] + elemOffset);
}

unsigned elemsPerThread = blockedLayout.getElemsPerThread(shape);
unsigned elemsPerThread = triton::gpu::getElemsPerThread(type);
unsigned totalSizePerThread = product<unsigned>(sizePerThread);
SmallVector<SmallVector<unsigned>> reorderedOffset(elemsPerThread);
for (unsigned n = 0; n < elemsPerThread; ++n) {
@@ -696,11 +701,12 @@ private:
SmallVector<Value>
emitBaseIndexForMmaLayoutV1(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();

auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
auto [isARow, isBRow, isAVec4, isBVec4, id] =
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();

Value thread = getThreadId(rewriter, loc);
@@ -717,11 +723,17 @@ private:
Value _fpw0 = i32_val(fpw[0]);
Value _fpw1 = i32_val(fpw[1]);

LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);
// A info
auto aEncoding = DotOperandEncodingAttr::get(ctx, 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding = DotOperandEncodingAttr::get(ctx, 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});

Value lane = urem(thread, warpSize);
@@ -764,16 +776,28 @@ private:

SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV1(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();

auto [isARow, isBRow, isAVec4, isBVec4, id] =
auto [isARow, isBRow, isAVec4, isBVec4, _] =
mmaLayout.decodeVoltaLayoutStates();
LLVM::DotOpMmaV1ConversionHelper::AParam aParam(isARow, isAVec4);
LLVM::DotOpMmaV1ConversionHelper::BParam bParam(isBRow, isBVec4);

// TODO: seems like the pattern below to get `rep`/`spw` appears quite often
// A info
auto aEncoding =
DotOperandEncodingAttr::get(type.getContext(), 0, mmaLayout);
auto aRep = aEncoding.getMMAv1Rep();
auto aSpw = aEncoding.getMMAv1ShapePerWarp();
// B info
auto bEncoding =
DotOperandEncodingAttr::get(type.getContext(), 1, mmaLayout);
auto bSpw = bEncoding.getMMAv1ShapePerWarp();
auto bRep = bEncoding.getMMAv1Rep();

auto wpt = mmaLayout.getWarpsPerCTA();
auto fpw = LLVM::DotOpMmaV1ConversionHelper::fpw;
SmallVector<int, 2> rep({aParam.rep[0], bParam.rep[1]});
SmallVector<int, 2> spw({aParam.spw[0], bParam.spw[1]});
SmallVector<int, 2> rep({aRep[0], bRep[1]});
SmallVector<int, 2> spw({aSpw[0], bSpw[1]});
SmallVector<unsigned, 2> shapePerCTA({spw[0] * wpt[0], spw[1] * wpt[1]});

SmallVector<unsigned> idxM;
@@ -804,34 +828,36 @@ private:
SmallVector<Value>
emitBaseIndexForMmaLayoutV2(Location loc, ConversionPatternRewriter &rewriter,
const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
auto _warpsPerCTA = mmaLayout.getWarpsPerCTA();
assert(_warpsPerCTA.size() == 2);
SmallVector<Value> warpsPerCTA = {idx_val(_warpsPerCTA[0]),
idx_val(_warpsPerCTA[1])};
SmallVector<Value> warpsPerCTA = {i32_val(_warpsPerCTA[0]),
i32_val(_warpsPerCTA[1])};
Value threadId = getThreadId(rewriter, loc);
#ifdef USE_ROCM
Value warpSize = idx_val(64);
Value warpSize = i32_val(64);
#else
Value warpSize = idx_val(32);
Value warpSize = i32_val(32);
#endif
Value laneId = urem(threadId, warpSize);
Value warpId = udiv(threadId, warpSize);
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), idx_val(shape[0] / 16));
Value warpId0 = urem(urem(warpId, warpsPerCTA[0]), i32_val(shape[0] / 16));
Value warpId1 = urem(urem(udiv(warpId, warpsPerCTA[0]), warpsPerCTA[1]),
idx_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, idx_val(16));
Value offWarp1 = mul(warpId1, idx_val(8));
i32_val(shape[1] / 8));
Value offWarp0 = mul(warpId0, i32_val(16));
Value offWarp1 = mul(warpId1, i32_val(8));

SmallVector<Value> multiDimBase(2);
multiDimBase[0] = add(udiv(laneId, idx_val(4)), offWarp0);
multiDimBase[1] = add(mul(idx_val(2), urem(laneId, idx_val(4))), offWarp1);
multiDimBase[0] = add(udiv(laneId, i32_val(4)), offWarp0);
multiDimBase[1] = add(mul(i32_val(2), urem(laneId, i32_val(4))), offWarp1);
return multiDimBase;
}

SmallVector<SmallVector<unsigned>>
emitOffsetForMmaLayoutV2(const MmaEncodingAttr &mmaLayout,
ArrayRef<int64_t> shape) const {
RankedTensorType type) const {
auto shape = type.getShape();
SmallVector<SmallVector<unsigned>> ret;

for (unsigned i = 0; i < shape[0]; i += getShapePerCTA(mmaLayout)[0]) {
@@ -849,31 +875,34 @@ private:
// [elemsPerThread X rank] index matrix.
SmallVector<SmallVector<Value>> emitIndicesForDistributedLayout(
Location loc, ConversionPatternRewriter &rewriter,
const Attribute &layout, ArrayRef<int64_t> shape) const {
const Attribute &layout, RankedTensorType type) const {
// step 1, delinearize threadId to get the base index
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, shape);
auto multiDimBase = emitBaseIndexForLayout(loc, rewriter, layout, type);
// step 2, get offset of each element
auto offset = emitOffsetForLayout(layout, shape);
auto offset = emitOffsetForLayout(layout, type);
// step 3, add offset to base, and reorder the sequence of indices to
// guarantee that elems in the same sizePerThread are adjacent in order
auto shape = type.getShape();
unsigned rank = shape.size();
unsigned elemsPerThread = offset.size();
SmallVector<SmallVector<Value>> multiDimIdx(elemsPerThread,
SmallVector<Value>(rank));
for (unsigned n = 0; n < elemsPerThread; ++n)
for (unsigned k = 0; k < rank; ++k)
multiDimIdx[n][k] = add(multiDimBase[k], idx_val(offset[n][k]));
multiDimIdx[n][k] = add(multiDimBase[k], i32_val(offset[n][k]));
return multiDimIdx;
}
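
emitIndicesForDistributedLayout composes the per-thread base index with the per-element offsets: index[n][k] = base[k] + offset[n][k]. The same computation on plain integers (illustrative names; the real code emits LLVM IR adds instead):

#include <cstddef>
#include <vector>

std::vector<std::vector<unsigned>>
composeIndices(const std::vector<unsigned> &base,
               const std::vector<std::vector<unsigned>> &offset) {
  std::vector<std::vector<unsigned>> idx(offset.size());
  for (size_t n = 0; n < offset.size(); ++n)
    for (size_t k = 0; k < base.size(); ++k)
      idx[n].push_back(base[k] + offset[n][k]); // add(multiDimBase[k], ...)
  return idx;
}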

SmallVector<SmallVector<Value>>
emitIndicesForSliceLayout(Location loc, ConversionPatternRewriter &rewriter,
const SliceEncodingAttr &sliceLayout,
ArrayRef<int64_t> shape) const {
auto parent = sliceLayout.getParent();
RankedTensorType type) const {
auto parentEncoding = sliceLayout.getParent();
unsigned dim = sliceLayout.getDim();
auto parentIndices =
emitIndices(loc, rewriter, parent, sliceLayout.paddedShape(shape));
auto parentShape = sliceLayout.paddedShape(type.getShape());
RankedTensorType parentTy = RankedTensorType::get(
parentShape, type.getElementType(), parentEncoding);
auto parentIndices = emitIndices(loc, rewriter, parentEncoding, parentTy);
unsigned numIndices = parentIndices.size();
SmallVector<SmallVector<Value>> resultIndices;
for (unsigned i = 0; i < numIndices; ++i) {
@@ -885,7 +914,7 @@ private:
}

protected:
LLVMTypeConverter *converter;
TritonGPUToLLVMTypeConverter *converter;
const Allocation *allocation;
Value smem;
IndexCacheInfo indexCacheInfo;
@@ -898,30 +927,29 @@ class ConvertTritonGPUOpToLLVMPattern
public:
using OpAdaptor = typename SourceOp::Adaptor;

explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter) {}

explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem) {}

explicit ConvertTritonGPUOpToLLVMPattern(LLVMTypeConverter &typeConverter,
const Allocation *allocation,
Value smem,
IndexCacheInfo indexCacheInfo,
PatternBenefit benefit = 1)
explicit ConvertTritonGPUOpToLLVMPattern(
TritonGPUToLLVMTypeConverter &typeConverter, const Allocation *allocation,
Value smem, IndexCacheInfo indexCacheInfo, PatternBenefit benefit = 1)
: ConvertOpToLLVMPattern<SourceOp>(typeConverter, benefit),
ConvertTritonGPUOpToLLVMPatternBase(typeConverter, allocation, smem,
indexCacheInfo) {}

protected:
LLVMTypeConverter *getTypeConverter() const {
return ((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
TritonGPUToLLVMTypeConverter *getTypeConverter() const {
LLVMTypeConverter *ret =
((ConvertTritonGPUOpToLLVMPatternBase *)this)->getTypeConverter();
return (TritonGPUToLLVMTypeConverter *)ret;
}
};


@@ -5,8 +5,11 @@
#include "mlir/Conversion/ControlFlowToLLVM//ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
#include "mlir/Conversion/LLVMCommon/VectorPattern.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -26,39 +29,25 @@
#include "TypeConverter.h"
#include "ViewOpToLLVM.h"

#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"

using namespace mlir;
using namespace mlir::triton;

#define GEN_PASS_CLASSES
#include "triton/Conversion/Passes.h.inc"

namespace mlir {

class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};
namespace {

class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<index::IndexDialect>();
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
@@ -67,9 +56,40 @@ public:
}
};

} // namespace mlir
class TritonPTXConversionTarget : public ConversionTarget {
public:
explicit TritonPTXConversionTarget(MLIRContext &ctx) : ConversionTarget(ctx) {
addDynamicallyLegalDialect<LLVM::LLVMDialect>(
[&](Operation *op) { return isLegalElementwiseOp(op); });
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};

namespace {
struct ReturnOpConversion : public ConvertOpToLLVMPattern<func::ReturnOp> {
using ConvertOpToLLVMPattern<func::ReturnOp>::ConvertOpToLLVMPattern;

LogicalResult
matchAndRewrite(func::ReturnOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
unsigned numArguments = op.getNumOperands();

// Currently, Triton kernel functions always return nothing.
// TODO(Superjomn) add support for non-inline device function
if (numArguments > 0) {
return rewriter.notifyMatchFailure(
op, "Only kernel function with nothing returned is supported.");
}

rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(op, TypeRange(), ValueRange(),
op->getAttrs());
return success();
}
};

/// FuncOp legalization pattern that converts MemRef arguments to pointers to
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
@@ -83,8 +103,9 @@ struct FuncOpConversion : public FuncOpConversionBase {
matchAndRewrite(func::FuncOp funcOp, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto newFuncOp = convertFuncOpToLLVMFuncOp(funcOp, rewriter);
if (!newFuncOp)
if (!newFuncOp) {
return failure();
}

auto ctx = funcOp->getContext();

@@ -105,6 +126,24 @@ private:
int numWarps{0};
};

class TritonLLVMConversionTarget : public ConversionTarget {
public:
explicit TritonLLVMConversionTarget(MLIRContext &ctx)
: ConversionTarget(ctx) {
addLegalDialect<LLVM::LLVMDialect>();
#ifdef USE_ROCM
addLegalDialect<ROCDL::ROCDLDialect>();
addLegalDialect<mlir::scf::SCFDialect>();
#else
addLegalDialect<NVVM::NVVMDialect>();
#endif
addIllegalDialect<triton::TritonDialect>();
addIllegalDialect<triton::gpu::TritonGPUDialect>();
addIllegalDialect<mlir::gpu::GPUDialect>();
addLegalOp<mlir::UnrealizedConversionCastOp>();
}
};

class ConvertTritonGPUToLLVM
: public ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {

@@ -115,48 +154,39 @@ public:
void runOnOperation() override {
MLIRContext *context = &getContext();
ModuleOp mod = getOperation();

mlir::LowerToLLVMOptions option(context);
option.overrideIndexBitwidth(32);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
TritonLLVMConversionTarget target(*context);

int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

// Step 1: Decompose unoptimized layout conversions to use shared memory
// Step 2: Decompose insert_slice_async to use load + insert_slice for
// pre-Ampere architectures or unsupported vectorized load sizes
// Step 3: Allocate shared memories and insert barriers
// Step 4: Convert FuncOp to LLVMFuncOp via partial conversion
// Step 5: Get axis and shared memory info
// Step 6: Convert the rest of ops via partial conversion
//
// The reason for a separation between 4/6 is that step 5 is out of the
// scope of Dialect Conversion, thus we need to make sure the smem is not
// revised during the conversion of step 6.

// Step 1
/* preprocess */
decomposeMmaToDotOperand(mod, numWarps);
decomposeBlockedToDotOperand(mod);

// Step 2
if (failed(decomposeInsertSliceAsyncOp(mod)))
return signalPassFailure();

// Step 3
/* allocate shared memory and set barrier */
Allocation allocation(mod);
MembarAnalysis membarPass(&allocation);
membarPass.run();

// Step 4
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps, /*benefit=*/1);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
/* lower functions */
{
mlir::LowerToLLVMOptions option(context);
TritonGPUToLLVMTypeConverter typeConverter(context, option);
TritonLLVMFunctionConversionTarget funcTarget(*context);
RewritePatternSet funcPatterns(context);
funcPatterns.add<FuncOpConversion>(typeConverter, numWarps,
/*benefit=*/1);
funcPatterns.add<ReturnOpConversion>(typeConverter);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
funcPatterns);
if (failed(
applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
return signalPassFailure();
}

// Step 5 - get axis and shared memory info
std::unique_ptr<DataFlowSolver> solver = createDataFlowSolver();
AxisInfoAnalysis *axisInfoAnalysis = solver->load<AxisInfoAnalysis>();
if (failed(solver->initializeAndRun(mod)))
@@ -166,64 +196,58 @@ public:
mlir::IntegerAttr::get(mlir::IntegerType::get(context, 32),
allocation.getSharedMemorySize()));

// Step 6 - rewrite rest of ops
// We set a higher benefit here to ensure triton's patterns run before
// arith patterns for some encodings not supported by the community
// patterns.
/* rewrite ops */
RewritePatternSet patterns(context);
// TritonGPU lowering patterns
OpBuilder::InsertPoint indexInsertPoint;
ConvertTritonGPUOpToLLVMPatternBase::IndexCacheInfo indexCacheInfo{
&baseIndexCache, &indexCache, &indexInsertPoint};

RewritePatternSet patterns(context);

// Normal conversions
populateTritonGPUToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ConvertLayoutOp
populateConvertLayoutOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// DotOp
populateDotOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// ElementwiseOp
populateElementwiseOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);
// LoadStoreOp
populateLoadStoreOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ReduceOp
populateReduceOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
indexCacheInfo, /*benefit=*/10);
// ViewOp
populateViewOpToLLVMPatterns(typeConverter, patterns, numWarps,
*axisInfoAnalysis, &allocation, smem,
/*benefit=*/10);

// Add arith/math's patterns to help convert scalar expression to LLVM.
mlir::arith::populateArithToLLVMConversionPatterns(typeConverter, patterns);
mlir::populateMathToLLVMConversionPatterns(typeConverter, patterns);
mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
auto populatePatterns1 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, indexCacheInfo, /*benefit*/ 1);
};
auto populatePatterns2 = [&](auto populateFunc) {
populateFunc(typeConverter, patterns, numWarps, *axisInfoAnalysis,
&allocation, smem, /*benefit*/ 1);
};
populatePatterns1(populateTritonGPUToLLVMPatterns);
populatePatterns1(populateConvertLayoutOpToLLVMPatterns);
populatePatterns2(populateDotOpToLLVMPatterns);
populatePatterns2(populateElementwiseOpToLLVMPatterns);
populatePatterns1(populateLoadStoreOpToLLVMPatterns);
populatePatterns1(populateReduceOpToLLVMPatterns);
populatePatterns2(populateViewOpToLLVMPatterns);
// Native lowering patterns
#ifdef USE_ROCM
mlir::populateGpuToROCDLConversionPatterns(typeConverter, patterns, mlir::gpu::amd::HIP);
#else
mlir::populateGpuToNVVMConversionPatterns(typeConverter, patterns);
#endif

mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
patterns);
if (failed(applyPartialConversion(mod, target, std::move(patterns))))
return signalPassFailure();

#ifndef USE_ROCM
// Use our custom converters to convert some operations to PTX to avoid
// using NVPTX for two reasons:
// 1. NVPTX backend is flaky on data types like float16 and bfloat16
// 2. In some cases, we may generate faster PTX code than NVPTX backend
TritonPTXConversionTarget ptxTarget(*context);
RewritePatternSet ptxPatterns(context);
// Add patterns to convert LLVM to PTX
populateElementwiseOpToPTXPatterns(typeConverter, ptxPatterns,
/*benefits=*/10);

if (failed(applyPartialConversion(mod, ptxTarget, std::move(ptxPatterns))))
return signalPassFailure();
#endif
}

private:
Value smem;

using IndexCacheKeyT = std::pair<Attribute, SmallVector<int64_t>>;
using IndexCacheKeyT = std::pair<Attribute, RankedTensorType>;
DenseMap<IndexCacheKeyT, SmallVector<Value>, CacheKeyDenseMapInfo>
baseIndexCache;
DenseMap<IndexCacheKeyT, SmallVector<SmallVector<Value>>,
163
lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
Normal file
@@ -0,0 +1,163 @@
#include "TypeConverter.h"
#include "DotOpHelpers.h"
#include "Utility.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](mlir::Float8E4M3FNType type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
addConversion([&](mlir::Float8E5M2Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}

Type TritonGPUToLLVMTypeConverter::convertTritonPointerType(
triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}

Value TritonGPUToLLVMTypeConverter::packLLElements(
Location loc, ValueRange resultVals, ConversionPatternRewriter &rewriter,
Type type) {
auto structType = this->convertType(type);
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}

Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
// llvm::outs() << structType << "\n";
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}

SmallVector<Value> TritonGPUToLLVMTypeConverter::unpackLLElements(
Location loc, Value llvmStruct, ConversionPatternRewriter &rewriter,
Type type) {
assert(bool(llvmStruct) && "can not unpack null values");
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, i);
}
return results;
}
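
packLLElements and unpackLLElements are inverses: one assembles an LLVM struct value with chained insertvalue operations, the other takes it apart with extractvalue. A scalar analogy of that invariant (plain C++ standing in for MLIR values; illustrative only):

#include <cassert>
#include <vector>

using PackedStruct = std::vector<int>; // stands in for an LLVM struct value

PackedStruct pack(const std::vector<int> &elems) {
  PackedStruct s;
  for (int e : elems)
    s.push_back(e); // insert_val(structType, s, e, i)
  return s;
}

std::vector<int> unpack(const PackedStruct &s) {
  std::vector<int> elems;
  for (size_t i = 0; i < s.size(); ++i)
    elems.push_back(s[i]); // extract_val(type, s, i)
  return elems;
}

int main() {
  std::vector<int> vals{1, 2, 3};
  assert(unpack(pack(vals)) == vals); // the round-trip invariant
  return 0;
}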

llvm::Optional<Type>
TritonGPUToLLVMTypeConverter::convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());

if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);

return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, vec_ty(elemTy, 1)},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
// <2xi16>/<4xi8> => i32
// We are doing this because NVPTX inserts extra integer instrs to
// pack & unpack vectors of sub-word integers
// Note: this needs to be synced with
// DotOpMmaV2ConversionHelper::loadX4
if (elemTy.isa<IntegerType>() &&
(elemTy.getIntOrFloatBitWidth() == 8 ||
elemTy.getIntOrFloatBitWidth() == 16))
targetTy = IntegerType::get(ctx, 32);
} else {
assert(false && "Unsupported element type");
}
auto elems = getElemsPerThread(type);
return struct_ty(SmallVector<Type>(elems, targetTy));
}

if (mmaLayout.isVolta()) {
int elems = getElemsPerThread(type);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}

llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return std::nullopt;
}

return std::nullopt;
}
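
For orientation, the branches above map one tensor type to one LLVM struct type whose fields depend on the encoding; a hypothetical example (the element counts are assumed, not computed from a real layout):

// A distributed (blocked/slice/mma) tensor with 4 elements per thread:
//   tensor<...xf32, #blocked>  ->  !llvm.struct<(f32, f32, f32, f32)>
// A rank-2 #shared tensor lowers to a base pointer plus 2 offsets and
// 2 strides:
//   tensor<...xf16, #shared>   ->  !llvm.struct<(ptr<f16, 3>, i32, i32, i32, i32)>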
@@ -1,154 +1,30 @@
#ifndef TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H
#define TRITON_CONVERSION_TRITONGPU_TO_LLVM_TYPECONVERTER_H

#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Conversion/MLIRTypes.h"

#include "DotOpHelpers.h"
#include "Utility.h"

using namespace mlir;
using namespace mlir::triton;

using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::BlockedEncodingAttr;
using ::mlir::triton::gpu::DotOperandEncodingAttr;
using ::mlir::triton::gpu::getElemsPerThread;
using ::mlir::triton::gpu::MmaEncodingAttr;
using ::mlir::triton::gpu::SharedEncodingAttr;
using ::mlir::triton::gpu::SliceEncodingAttr;

class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
public:
using TypeConverter::convertType;

TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
const DataLayoutAnalysis *analysis = nullptr)
: LLVMTypeConverter(ctx, option, analysis) {
addConversion([&](triton::PointerType type) -> llvm::Optional<Type> {
return convertTritonPointerType(type);
});
addConversion([&](RankedTensorType type) -> llvm::Optional<Type> {
return convertTritonTensorType(type);
});
// Internally store float8 as int8
addConversion([&](triton::Float8Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 8);
});
// Internally store bfloat16 as int16
addConversion([&](BFloat16Type type) -> llvm::Optional<Type> {
return IntegerType::get(type.getContext(), 16);
});
}
const DataLayoutAnalysis *analysis = nullptr);

Type convertTritonPointerType(triton::PointerType type) {
// Recursively translate pointee type
return LLVM::LLVMPointerType::get(convertType(type.getPointeeType()),
type.getAddressSpace());
}
Type convertTritonPointerType(triton::PointerType type);

llvm::Optional<Type> convertTritonTensorType(RankedTensorType type) {
auto ctx = type.getContext();
Attribute layout = type.getEncoding();
SmallVector<int64_t> shape(type.getShape().begin(), type.getShape().end());
Value packLLElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter, Type type);

if (layout &&
(layout.isa<BlockedEncodingAttr>() || layout.isa<SliceEncodingAttr>() ||
layout.isa<MmaEncodingAttr>())) {
unsigned numElementsPerThread = getElemsPerThread(type);
SmallVector<Type, 4> types(numElementsPerThread,
convertType(type.getElementType()));
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto shared_layout =
layout.dyn_cast_or_null<SharedEncodingAttr>()) {
SmallVector<Type, 4> types;
// base ptr
auto ptrType =
LLVM::LLVMPointerType::get(convertType(type.getElementType()), 3);
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
} else if (auto dotOpLayout =
layout.dyn_cast_or_null<DotOperandEncodingAttr>()) {
if (dotOpLayout.getParent()
.isa<BlockedEncodingAttr>()) { // for parent is blocked layout
int numElemsPerThread =
DotOpFMAConversionHelper::getNumElemsPerThread(shape, dotOpLayout);
SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter,
Type type);

return LLVM::LLVMStructType::getLiteral(
ctx, SmallVector<Type>(numElemsPerThread, type::f32Ty(ctx)));
} else { // for parent is MMA layout
auto mmaLayout = dotOpLayout.getParent().cast<MmaEncodingAttr>();
auto wpt = mmaLayout.getWarpsPerCTA();
Type elemTy = convertType(type.getElementType());
if (mmaLayout.isAmpere()) {
const llvm::DenseMap<int, Type> targetTyMap = {
{32, vec_ty(elemTy, 1)},
{16, vec_ty(elemTy, 2)},
{8, vec_ty(elemTy, 4)},
};
Type targetTy;
if (targetTyMap.count(elemTy.getIntOrFloatBitWidth())) {
targetTy = targetTyMap.lookup(elemTy.getIntOrFloatBitWidth());
// <2xi16>/<4xi8> => i32
// We are doing this because NVPTX inserts extra integer instrs to
// pack & unpack vectors of sub-word integers
// Note: this needs to be synced with
// DotOpMmaV2ConversionHelper::loadX4
if (elemTy.isa<IntegerType>() &&
(elemTy.getIntOrFloatBitWidth() == 8 ||
elemTy.getIntOrFloatBitWidth() == 16))
targetTy = IntegerType::get(ctx, 32);
} else {
assert(false && "Unsupported element type");
}
if (dotOpLayout.getOpIdx() == 0) { // $a
auto elems =
MMA16816ConversionHelper::getANumElemsPerThread(type, wpt[0]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
auto elems =
MMA16816ConversionHelper::getBNumElemsPerThread(type, wpt[1]);
return struct_ty(SmallVector<Type>(elems, targetTy));
}
}

if (mmaLayout.isVolta()) {
auto [isARow, isBRow, isAVec4, isBVec4, mmaId] =
mmaLayout.decodeVoltaLayoutStates();
DotOpMmaV1ConversionHelper helper(mmaLayout);
if (dotOpLayout.getOpIdx() == 0) { // $a
DotOpMmaV1ConversionHelper::AParam param(isARow, isAVec4);
int elems =
helper.numElemsPerThreadA(shape, isARow, isAVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
if (dotOpLayout.getOpIdx() == 1) { // $b
DotOpMmaV1ConversionHelper::BParam param(isBRow, isBVec4);
int elems =
helper.numElemsPerThreadB(shape, isBRow, isBVec4, param.vec);
Type x2Ty = vec_ty(elemTy, 2);
return struct_ty(SmallVector<Type>(elems, x2Ty));
}
}
}

llvm::errs() << "Unexpected dot operand layout detected in "
"TritonToLLVMTypeConverter";
return std::nullopt;
}

return std::nullopt;
}
llvm::Optional<Type> convertTritonTensorType(RankedTensorType type);
};

#endif

@@ -1,41 +1,11 @@
#include "Utility.h"
#include "TypeConverter.h"

namespace mlir {

namespace LLVM {
using namespace mlir::triton;

Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType) {
if (!structType.isa<LLVM::LLVMStructType>()) {
return *resultVals.begin();
}

Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structType);
for (const auto &v : llvm::enumerate(resultVals)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structType, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}

SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
if (llvmStruct.getType().isIntOrIndexOrFloat() ||
llvmStruct.getType().isa<triton::PointerType>() ||
llvmStruct.getType().isa<LLVM::LLVMPointerType>())
return {llvmStruct};
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> results(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
results[i] = extract_val(type, llvmStruct, i);
}
return results;
}

Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v) {
auto i32ty = rewriter.getIntegerType(32);
return rewriter.create<LLVM::ConstantOp>(loc, i32ty,
@@ -73,7 +43,14 @@ Value createLLVMIntegerConstant(OpBuilder &builder, Location loc, short width,
SharedMemoryObject
getSharedMemoryObjectFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter) {
auto elems = getElementsFromStruct(loc, llvmStruct, rewriter);
ArrayRef<Type> types =
llvmStruct.getType().cast<LLVM::LLVMStructType>().getBody();
SmallVector<Value> elems(types.size());
for (unsigned i = 0; i < types.size(); ++i) {
Type type = types[i];
elems[i] = extract_val(type, llvmStruct, i);
}

auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
@@ -175,5 +152,39 @@ Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
return builder.launch(rewriter, loc, val.getType(), false);
}

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content) {
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
auto ctx = moduleOp.getContext();
unsigned stringNumber = 0;
SmallString<16> stringConstName;
do {
stringConstName.clear();
(key + Twine(stringNumber++)).toStringRef(stringConstName);
} while (moduleOp.lookupSymbol(stringConstName));

llvm::SmallString<64> contentStr(content);
size_t contentSize = contentStr.size_in_bytes();
auto globalType = LLVM::LLVMArrayType::get(i8_ty, contentSize);

LLVM::GlobalOp global;
{
ConversionPatternRewriter::InsertionGuard guard(rewriter);
rewriter.setInsertionPointToStart(moduleOp.getBody());
global = rewriter.create<LLVM::GlobalOp>(
UnknownLoc::get(ctx), globalType,
/*isConstant=*/true, LLVM::Linkage::Internal, stringConstName,
rewriter.getStringAttr(contentStr));
}

Value zero = i32_val(0);
Value globalPtr =
rewriter.create<LLVM::AddressOfOp>(UnknownLoc::get(ctx), global);
Value stringStart =
rewriter.create<LLVM::GEPOp>(UnknownLoc::get(ctx), ptr_ty(i8_ty),
globalPtr, SmallVector<Value>({zero, zero}));
return stringStart;
}

} // namespace LLVM
} // namespace mlir

@@ -14,6 +14,8 @@
#define ptrtoint(...) rewriter.create<LLVM::PtrToIntOp>(loc, __VA_ARGS__)
#define zext(...) rewriter.create<LLVM::ZExtOp>(loc, __VA_ARGS__)
#define trunc(...) rewriter.create<LLVM::TruncOp>(loc, __VA_ARGS__)
#define sext(...) rewriter.create<LLVM::SExtOp>(loc, __VA_ARGS__)
#define fpext(...) rewriter.create<LLVM::FPExtOp>(loc, __VA_ARGS__)
#define udiv(...) rewriter.create<LLVM::UDivOp>(loc, __VA_ARGS__)
#define urem(...) rewriter.create<LLVM::URemOp>(loc, __VA_ARGS__)
#define add(...) rewriter.create<LLVM::AddOp>(loc, __VA_ARGS__)
@@ -30,9 +32,11 @@
#define smin(...) rewriter.create<LLVM::SMinOp>(loc, __VA_ARGS__)
#define umin(...) rewriter.create<LLVM::UMinOp>(loc, __VA_ARGS__)
#define fmin(...) rewriter.create<LLVM::MinNumOp>(loc, __VA_ARGS__)
#define shl(...) rewriter.create<LLVM::ShlOp>(loc, __VA_ARGS__)
#define and_(...) rewriter.create<LLVM::AndOp>(loc, __VA_ARGS__)
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
#define xor_(...) rewriter.create<LLVM::XOrOp>(loc, __VA_ARGS__)
#define or_(...) rewriter.create<LLVM::OrOp>(loc, __VA_ARGS__)
#define bitcast(val__, type__) \
rewriter.create<LLVM::BitcastOp>(loc, type__, val__)
#define addrspacecast(val__, type__) \
@@ -53,6 +57,9 @@
#define fcmp_olt(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::olt, lhs, rhs)
#define fcmp_eq(lhs, rhs) \
rewriter.create<LLVM::FCmpOp>(loc, rewriter.getI1Type(), \
LLVM::FCmpPredicate::oeq, lhs, rhs)
#define icmp_eq(...) \
rewriter.create<LLVM::ICmpOp>(loc, LLVM::ICmpPredicate::eq, __VA_ARGS__)
#define icmp_ne(...) \
@@ -77,30 +84,31 @@
#define address_of(...) rewriter.create<LLVM::AddressOfOp>(loc, __VA_ARGS__)
#define barrier() rewriter.create<mlir::gpu::BarrierOp>(loc)
#define undef(...) rewriter.create<LLVM::UndefOp>(loc, __VA_ARGS__)
#define null(...) rewriter.create<LLVM::NullOp>(loc, __VA_ARGS__)
#define call(...) rewriter.create<LLVM::CallOp>(loc, __VA_ARGS__)

// Types
#define i64_ty rewriter.getIntegerType(64)
#define i32_ty rewriter.getIntegerType(32)
#define i16_ty rewriter.getIntegerType(16)
#define ui32_ty rewriter.getIntegerType(32, false)
#define f16_ty rewriter.getF16Type()
#define bf16_ty rewriter.getBF16Type()
#define i8_ty rewriter.getIntegerType(8)
#define i1_ty rewriter.getI1Type()
#define f32_ty rewriter.getF32Type()
#define f64_ty rewriter.getF64Type()
#define vec_ty(type, num) VectorType::get(num, type)
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define void_ty(ctx) LLVM::LLVMVoidType::get(ctx)
#define struct_ty(...) LLVM::LLVMStructType::getLiteral(ctx, __VA_ARGS__)
#define array_ty(elemTy, count) LLVM::LLVMArrayType::get(elemTy, count)

// Constants
#define f32_val(...) LLVM::createConstantF32(loc, rewriter, __VA_ARGS__)
#define f64_val(...) LLVM::createConstantF64(loc, rewriter, __VA_ARGS__)
#define i32_val(...) LLVM::createConstantI32(loc, rewriter, __VA_ARGS__)
#define int_val(width, val) \
LLVM::createLLVMIntegerConstant(rewriter, loc, width, val)
#define idx_val(...) \
LLVM::createIndexConstant(rewriter, loc, this->getTypeConverter(), \
__VA_ARGS__)
#define tid_val() getThreadId(rewriter, loc)
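
These helpers are plain textual macros, so `loc` and `rewriter` must be in scope at every expansion site. For example (hypothetical call site inside a conversion pattern):

// Value four = i32_val(4);
// expands to:
// Value four = LLVM::createConstantI32(loc, rewriter, 4);
// The older idx_val(...) built an index-typed constant through the type
// converter; this patch moves call sites to 32-bit i32_val constants.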

// Attributes
@@ -173,14 +181,6 @@ T getLinearIndex(llvm::ArrayRef<T> multiDimIndex, llvm::ArrayRef<T> shape,
namespace LLVM {
using namespace mlir::triton;

Value getStructFromElements(Location loc, ValueRange resultVals,
ConversionPatternRewriter &rewriter,
Type structType);

SmallVector<Value> getElementsFromStruct(Location loc, Value llvmStruct,
ConversionPatternRewriter &rewriter);

/// Create a 32-bit integer constant.
Value createConstantI32(Location loc, PatternRewriter &rewriter, int32_t v);

/// Create a 32-bit float constant.
@@ -271,6 +271,9 @@ Value storeShared(ConversionPatternRewriter &rewriter, Location loc, Value ptr,
Value shflSync(Location loc, ConversionPatternRewriter &rewriter, Value val,
int i);

Value addStringToModule(Location loc, ConversionPatternRewriter &rewriter,
StringRef key, StringRef content);

} // namespace LLVM
} // namespace mlir


@@ -7,9 +7,7 @@ using namespace mlir::triton;
using ::mlir::LLVM::DotOpFMAConversionHelper;
using ::mlir::LLVM::DotOpMmaV1ConversionHelper;
using ::mlir::LLVM::DotOpMmaV2ConversionHelper;
using ::mlir::LLVM::getElementsFromStruct;
using ::mlir::LLVM::getSharedMemoryObjectFromStruct;
using ::mlir::LLVM::getStructFromElements;
using ::mlir::LLVM::MMA16816ConversionHelper;
using ::mlir::triton::gpu::getElemsPerThread;

@@ -25,7 +23,7 @@ struct SplatOpConversion
// @resType: the return type of the Splat-like op.
// @constVal: a LLVM::ConstantOp or other scalar value.
static Value convertSplatLikeOp(Type elemType, Type resType, Value constVal,
TypeConverter *typeConverter,
TritonGPUToLLVMTypeConverter *typeConverter,
ConversionPatternRewriter &rewriter,
Location loc) {
auto tensorTy = resType.cast<RankedTensorType>();
@@ -33,11 +31,7 @@ struct SplatOpConversion
auto llSrc = bitcast(constVal, srcType);
size_t elemsPerThread = getElemsPerThread(tensorTy);
llvm::SmallVector<Value> elems(elemsPerThread, llSrc);
llvm::SmallVector<Type> elemTypes(elems.size(), srcType);
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), elemTypes);

return getStructFromElements(loc, elems, rewriter, structTy);
return typeConverter->packLLElements(loc, elems, rewriter, resType);
}

LogicalResult matchAndRewrite(triton::SplatOp op, OpAdaptor adaptor,
@@ -95,7 +89,7 @@ struct ArithConstantSplatOpConversion
struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
using OpAdaptor = typename CatOp::Adaptor;

explicit CatOpConversion(LLVMTypeConverter &typeConverter,
explicit CatOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<CatOp>(typeConverter, benefit) {}

@@ -109,8 +103,10 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
// unpack input values
auto lhsVals = getElementsFromStruct(loc, adaptor.getLhs(), rewriter);
auto rhsVals = getElementsFromStruct(loc, adaptor.getRhs(), rewriter);
auto lhsVals = getTypeConverter()->unpackLLElements(
loc, adaptor.getLhs(), rewriter, op.getOperand(0).getType());
auto rhsVals = getTypeConverter()->unpackLLElements(
loc, adaptor.getRhs(), rewriter, op.getOperand(1).getType());
// concatenate (and potentially reorder) values
SmallVector<Value> retVals;
for (Value v : lhsVals)
@@ -118,8 +114,8 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
for (Value v : rhsVals)
retVals.push_back(v);
// pack and replace
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
Value ret = getStructFromElements(loc, retVals, rewriter, structTy);
Value ret =
getTypeConverter()->packLLElements(loc, retVals, rewriter, resultTy);
rewriter.replaceOp(op, ret);
return success();
}
@@ -128,24 +124,19 @@ struct CatOpConversion : public ConvertTritonGPUOpToLLVMPattern<CatOp> {
template <typename SourceOp>
struct ViewLikeOpConversion : public ConvertTritonGPUOpToLLVMPattern<SourceOp> {
using OpAdaptor = typename SourceOp::Adaptor;
explicit ViewLikeOpConversion(LLVMTypeConverter &typeConverter,
explicit ViewLikeOpConversion(TritonGPUToLLVMTypeConverter &typeConverter,
PatternBenefit benefit = 1)
: ConvertTritonGPUOpToLLVMPattern<SourceOp>(typeConverter, benefit) {}

LogicalResult
matchAndRewrite(SourceOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// We cannot directly run `rewriter.replaceOp(op, adaptor.getSrc())`
// due to MLIR's restrictions
Location loc = op->getLoc();
auto resultTy = op.getType().template cast<RankedTensorType>();
unsigned elems = getElemsPerThread(resultTy);
Type elemTy =
this->getTypeConverter()->convertType(resultTy.getElementType());
SmallVector<Type> types(elems, elemTy);
Type structTy = LLVM::LLVMStructType::getLiteral(this->getContext(), types);
auto vals = getElementsFromStruct(loc, adaptor.getSrc(), rewriter);
Value view = getStructFromElements(loc, vals, rewriter, structTy);
auto vals = this->getTypeConverter()->unpackLLElements(
loc, adaptor.getSrc(), rewriter, op.getOperand().getType());
Value view =
this->getTypeConverter()->packLLElements(loc, vals, rewriter, resultTy);
rewriter.replaceOp(op, view);
return success();
}
@@ -174,7 +165,7 @@ struct TransOpConversion
}
};

void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

@@ -6,7 +6,7 @@
using namespace mlir;
using namespace mlir::triton;

void populateViewOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
void populateViewOpToLLVMPatterns(TritonGPUToLLVMTypeConverter &typeConverter,
RewritePatternSet &patterns, int numWarps,
AxisInfoAnalysis &axisInfoAnalysis,
const Allocation *allocation, Value smem,

@@ -3,6 +3,7 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
@@ -22,6 +23,13 @@ using namespace mlir::triton;

namespace {

// pass named attrs (e.g., tt.contiguity) from Triton to TritonGPU
static void addNamedAttrs(Operation *op, DictionaryAttr dictAttrs) {
for (const NamedAttribute attr : dictAttrs.getValue())
if (!op->hasAttr(attr.getName()))
op->setAttr(attr.getName(), attr.getValue());
}
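
addNamedAttrs copies only those attributes the new op does not already define, and Op classes convert implicitly to Operation *, so the patterns below can feed it the result of replaceOpWithNewOp (which returns the created op) directly. Typical shape of a call site, as used throughout this file:

// addNamedAttrs(
//     rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
//     adaptor.getAttributes());
// Named attrs such as tt.contiguity thus survive the dialect conversion.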
|
||||
|
||||
template <class Op> class GenericOpPattern : public OpConversionPattern<Op> {
|
||||
public:
|
||||
using OpConversionPattern<Op>::OpConversionPattern;
|
||||
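
The dozen hunks that follow all make the same mechanical change: every `rewriter.replaceOpWithNewOp<...>(...)` call gets wrapped in `addNamedAttrs(..., adaptor.getAttributes())`, so user-supplied attributes such as tt.contiguity survive the Triton-to-TritonGPU conversion. The wrapping composes as a single expression because `replaceOpWithNewOp<OpTy>` returns the freshly created op. A sketch of the idiom for an arbitrary pattern (this is the shape of the change, with placeholder names, not one specific hunk):

template <class Op> class PreservingPattern : public OpConversionPattern<Op> {
public:
  using OpConversionPattern<Op>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    // replaceOpWithNewOp returns the new op, so any attribute the builder
    // did not already set is copied over before we report success.
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
        adaptor.getAttributes());
    return success();
  }
};
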
@@ -30,7 +38,9 @@ public:
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
        adaptor.getAttributes());
    return success();
  }
};
@@ -44,8 +54,10 @@ public:
  matchAndRewrite(SrcOp op, typename SrcOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                       adaptor.getLhs(), adaptor.getRhs());
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<DstOp>(op, retType, adaptor.getPredicate(),
                                           adaptor.getLhs(), adaptor.getRhs()),
        adaptor.getAttributes());
    return success();
  }
};
@@ -66,7 +78,9 @@
    else
      // This is a hack. We just want to add encoding
      value = value.reshape(retType);
    rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value);
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, retType, value),
        adaptor.getAttributes());
    return success();
  }
};
@@ -138,9 +152,10 @@ public:
  matchAndRewrite(arith::SelectOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
        op, retType, adaptor.getCondition(), adaptor.getTrueValue(),
        adaptor.getFalseValue());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::gpu::SelectOp>(
                      op, retType, adaptor.getCondition(),
                      adaptor.getTrueValue(), adaptor.getFalseValue()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -177,8 +192,9 @@ struct TritonMakeRangePattern
  matchAndRewrite(triton::MakeRangeOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
        op, retType, adaptor.getStart(), adaptor.getEnd());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::MakeRangeOp>(
                      op, retType, adaptor.getStart(), adaptor.getEnd()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -221,8 +237,9 @@ struct TritonExpandDimsPattern
    // construct new op
    auto newSrc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op.getLoc(), newArgType, adaptor.getSrc());
    rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(op, newSrc,
                                                      adaptor.getAxis());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExpandDimsOp>(
                      op, newSrc, adaptor.getAxis()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -280,8 +297,9 @@ struct TritonDotPattern : public OpConversionPattern<triton::DotOp> {
    }
    c = rewriter.create<triton::gpu::ConvertLayoutOp>(c.getLoc(), retType, c);

    rewriter.replaceOpWithNewOp<triton::DotOp>(op, retType, a, b, c,
                                               adaptor.getAllowTF32());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::DotOp>(
                      op, retType, a, b, c, adaptor.getAllowTF32()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -296,8 +314,9 @@ struct TritonCatPattern : public OpConversionPattern<triton::CatOp> {
    // For now, this behaves like generic, but this will evolve when
    // we add support for `can_reorder=False`
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::CatOp>(op, retType,
                                               adaptor.getOperands());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::CatOp>(
                      op, retType, adaptor.getOperands()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -329,7 +348,8 @@ struct TritonTransPattern : public OpConversionPattern<triton::TransOp> {
      src = rewriter.create<triton::gpu::ConvertLayoutOp>(src.getLoc(), srcType,
                                                          src);
    }
    rewriter.replaceOpWithNewOp<triton::TransOp>(op, src);
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::TransOp>(op, src),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -340,10 +360,12 @@ struct TritonLoadPattern : public OpConversionPattern<triton::LoadOp> {
  LogicalResult
  matchAndRewrite(triton::LoadOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::LoadOp>(
        op, typeConverter->convertType(op.getType()), adaptor.getPtr(),
        adaptor.getMask(), adaptor.getOther(), adaptor.getCache(),
        adaptor.getEvict(), adaptor.getIsVolatile());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::LoadOp>(
                      op, typeConverter->convertType(op.getType()),
                      adaptor.getPtr(), adaptor.getMask(), adaptor.getOther(),
                      adaptor.getCache(), adaptor.getEvict(),
                      adaptor.getIsVolatile()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -354,9 +376,11 @@ struct TritonStorePattern : public OpConversionPattern<triton::StoreOp> {
  LogicalResult
  matchAndRewrite(triton::StoreOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::StoreOp>(
        op, adaptor.getPtr(), adaptor.getValue(), adaptor.getMask(),
        adaptor.getCache(), adaptor.getEvict());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::StoreOp>(
                      op, adaptor.getPtr(), adaptor.getValue(),
                      adaptor.getMask(), adaptor.getCache(),
                      adaptor.getEvict()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -368,9 +392,10 @@ struct TritonAtomicCASPattern
  LogicalResult
  matchAndRewrite(triton::AtomicCASOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
        op, typeConverter->convertType(op.getType()), adaptor.getPtr(),
        adaptor.getCmp(), adaptor.getVal());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicCASOp>(
                      op, typeConverter->convertType(op.getType()),
                      adaptor.getPtr(), adaptor.getCmp(), adaptor.getVal()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -382,9 +407,11 @@ struct TritonAtomicRMWPattern
  LogicalResult
  matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
        op, typeConverter->convertType(op.getType()), adaptor.getAtomicRmwOp(),
        adaptor.getPtr(), adaptor.getVal(), adaptor.getMask());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AtomicRMWOp>(
                      op, typeConverter->convertType(op.getType()),
                      adaptor.getAtomicRmwOp(), adaptor.getPtr(),
                      adaptor.getVal(), adaptor.getMask()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -396,9 +423,11 @@ struct TritonExtElemwisePattern
  LogicalResult
  matchAndRewrite(triton::ExtElemwiseOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
        op, typeConverter->convertType(op.getType()), adaptor.getArgs(),
        adaptor.getLibname(), adaptor.getLibpath(), adaptor.getSymbol());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::ExtElemwiseOp>(
                      op, typeConverter->convertType(op.getType()),
                      adaptor.getArgs(), adaptor.getLibname(),
                      adaptor.getLibpath(), adaptor.getSymbol()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -411,7 +440,9 @@ struct TritonGenericPattern : public OpConversionPattern<Op> {
  matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands());
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<Op>(op, retType, adaptor.getOperands()),
        adaptor.getAttributes());
    return success();
  }
};
@@ -432,8 +463,9 @@ struct TritonBroadcastPattern
    Type retType = RankedTensorType::get(opType.getShape(),
                                         opType.getElementType(), srcEncoding);
    // Type retType = this->getTypeConverter()->convertType(op.getType());
    rewriter.replaceOpWithNewOp<triton::BroadcastOp>(op, retType,
                                                     adaptor.getOperands());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::BroadcastOp>(
                      op, retType, adaptor.getOperands()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -444,20 +476,38 @@ struct TritonReducePattern : public OpConversionPattern<triton::ReduceOp> {
  LogicalResult
  matchAndRewrite(triton::ReduceOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::ReduceOp>(
        op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis());
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<triton::ReduceOp>(
            op, adaptor.getRedOp(), adaptor.getOperand(), adaptor.getAxis()),
        adaptor.getAttributes());
    return success();
  }
};

struct TritonPrintfPattern : public OpConversionPattern<triton::PrintfOp> {
  using OpConversionPattern<PrintfOp>::OpConversionPattern;
struct TritonPrintPattern : public OpConversionPattern<triton::PrintOp> {
  using OpConversionPattern<triton::PrintOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(PrintfOp op, typename PrintfOp::Adaptor adaptor,
  matchAndRewrite(triton::PrintOp op, typename triton::PrintOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<triton::PrintfOp>(op, op.getPrefixAttr(),
                                                  adaptor.getOperands());
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::PrintOp>(
                      op, op.getPrefixAttr(), adaptor.getOperands()),
                  adaptor.getAttributes());
    return success();
  }
};

struct TritonAssertPattern : public OpConversionPattern<triton::AssertOp> {
  using OpConversionPattern<triton::AssertOp>::OpConversionPattern;

  LogicalResult
  matchAndRewrite(triton::AssertOp op,
                  typename triton::AssertOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    addNamedAttrs(rewriter.replaceOpWithNewOp<triton::AssertOp>(
                      op, adaptor.getCondition(), op.getMessageAttr(),
                      op.getFileAttr(), op.getFuncAttr(), op.getLineAttr()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -476,8 +526,8 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
      TritonGenericPattern<triton::AddPtrOp>, TritonCatPattern,
      TritonReducePattern, TritonTransPattern, TritonExpandDimsPattern,
      TritonMakeRangePattern, TritonDotPattern, TritonLoadPattern,
      TritonStorePattern, TritonExtElemwisePattern, TritonPrintfPattern,
      TritonAtomicRMWPattern>(typeConverter, context);
      TritonStorePattern, TritonExtElemwisePattern, TritonPrintPattern,
      TritonAssertPattern, TritonAtomicRMWPattern>(typeConverter, context);
}

//
@@ -536,7 +586,9 @@ struct SCFYieldPattern : public OpConversionPattern<scf::YieldOp> {
    // rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
    // rewriter.create<scf::YieldOp>(op.getLoc(), adaptor.getOperands());
    // op.erase();
    rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands());
    addNamedAttrs(
        rewriter.replaceOpWithNewOp<scf::YieldOp>(op, adaptor.getOperands()),
        adaptor.getAttributes());
    return success();
  }
};
@@ -641,8 +693,9 @@ public:
  LogicalResult
  matchAndRewrite(cf::BranchOp op, cf::BranchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOpWithNewOp<cf::BranchOp>(op, op.getSuccessor(),
                                              adaptor.getOperands());
    addNamedAttrs(rewriter.replaceOpWithNewOp<cf::BranchOp>(
                      op, op.getSuccessor(), adaptor.getOperands()),
                  adaptor.getAttributes());
    return success();
  }
};
@@ -659,6 +712,7 @@ public:
        op, adaptor.getCondition(), op.getTrueDest(),
        adaptor.getTrueDestOperands(), op.getFalseDest(),
        adaptor.getFalseDestOperands());
    addNamedAttrs(newOp, adaptor.getAttributes());

    if (failed(rewriter.convertRegionTypes(newOp.getTrueDest()->getParent(),
                                           *converter)))

@@ -123,29 +123,6 @@ void StoreOp::print(OpAsmPrinter &printer) {
namespace mlir {
namespace triton {

//-- FpToFpOp --
bool FpToFpOp::areCastCompatible(::mlir::TypeRange inputs,
                                 ::mlir::TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  auto srcEltType = inputs.front();
  auto dstEltType = outputs.front();
  auto srcTensorType = srcEltType.dyn_cast<mlir::RankedTensorType>();
  auto dstTensorType = dstEltType.dyn_cast<mlir::RankedTensorType>();
  if (srcTensorType && dstTensorType) {
    srcEltType = srcTensorType.getElementType();
    dstEltType = dstTensorType.getElementType();
  }
  // Check whether fp8 <=> fp16, bf16, f32, f64
  // Make `srcEltType` always the fp8 side
  if (dstEltType.dyn_cast<mlir::triton::Float8Type>())
    std::swap(srcEltType, dstEltType);
  if (!srcEltType.dyn_cast<mlir::triton::Float8Type>())
    return false;
  return dstEltType.isF16() || dstEltType.isBF16() || dstEltType.isF32() ||
         dstEltType.isF64();
}

//-- StoreOp --
void StoreOp::build(::mlir::OpBuilder &builder, ::mlir::OperationState &state,
                    ::mlir::Value ptr, ::mlir::Value value,

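
The removed FpToFpOp::areCastCompatible above encodes a simple rule: exactly one side of the cast must be fp8, and the other side must be one of f16, bf16, f32, or f64. A standalone restatement of that rule, runnable outside MLIR (the Elt enum is illustrative, not Triton's type system):

#include <algorithm>
#include <cassert>

enum class Elt { FP8, F16, BF16, F32, F64, I32 };

// Mirrors the removed predicate: swap so `src` is the fp8 side, then require
// the destination to be one of the four wide float types.
bool areCastCompatible(Elt src, Elt dst) {
  if (dst == Elt::FP8)
    std::swap(src, dst);
  if (src != Elt::FP8)
    return false;
  return dst == Elt::F16 || dst == Elt::BF16 || dst == Elt::F32 ||
         dst == Elt::F64;
}

int main() {
  assert(areCastCompatible(Elt::FP8, Elt::F16));  // fp8 -> fp16: ok
  assert(areCastCompatible(Elt::F32, Elt::FP8));  // f32 -> fp8: ok
  assert(!areCastCompatible(Elt::F16, Elt::F32)); // no fp8 side: rejected
  assert(!areCastCompatible(Elt::FP8, Elt::I32)); // fp8 -> int: rejected
}
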
@@ -43,9 +43,7 @@ namespace mlir {
unsigned getPointeeBitWidth(RankedTensorType tensorTy) {
  auto ptrTy = tensorTy.getElementType().cast<triton::PointerType>();
  auto pointeeType = ptrTy.getPointeeType();
  return pointeeType.isa<triton::Float8Type>()
             ? 8
             : pointeeType.getIntOrFloatBitWidth();
  return pointeeType.getIntOrFloatBitWidth();
}

} // namespace mlir

@@ -46,17 +46,18 @@ namespace gpu {
// so that all distributed layouts implement
// these utilities

unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape,
                           Type eltTy) {
  if (auto blockedLayout = layout.dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getElemsPerThread(shape);
    return blockedLayout.getElemsPerThread(shape, eltTy);
  } else if (auto sliceLayout = layout.dyn_cast<SliceEncodingAttr>()) {
    return sliceLayout.getElemsPerThread(shape);
    return sliceLayout.getElemsPerThread(shape, eltTy);
  } else if (auto mmaLayout = layout.dyn_cast<MmaEncodingAttr>()) {
    return mmaLayout.getElemsPerThread(shape);
    return mmaLayout.getElemsPerThread(shape, eltTy);
  } else if (auto sharedLayout = layout.dyn_cast<SharedEncodingAttr>()) {
    return sharedLayout.getElemsPerThread(shape);
    return sharedLayout.getElemsPerThread(shape, eltTy);
  } else if (auto dotLayout = layout.dyn_cast<DotOperandEncodingAttr>()) {
    return dotLayout.getElemsPerThread(shape);
    return dotLayout.getElemsPerThread(shape, eltTy);
  } else {
    assert(0 && "getElemsPerThread not implemented");
    return 0;
@@ -64,11 +65,11 @@ unsigned getElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
}

unsigned getElemsPerThread(Type type) {
  if (type.isIntOrIndexOrFloat() || type.isa<triton::Float8Type>() ||
      type.isa<triton::PointerType>())
  if (type.isIntOrIndexOrFloat() || type.isa<triton::PointerType>())
    return 1;
  auto tensorType = type.cast<RankedTensorType>();
  return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape());
  return getElemsPerThread(tensorType.getEncoding(), tensorType.getShape(),
                           tensorType.getElementType());
}

SmallVector<unsigned> getThreadsPerWarp(const Attribute &layout) {
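
The element type is now threaded through every getElemsPerThread overload because MMA dot-operand layouts need the element bit width (note the eltTy.getIntOrFloatBitWidth() call further down); for blocked layouts the count stays a pure shape computation. A standalone model of the blocked case, under the usual tiling definition that each dimension is covered by sizePerThread x threadsPerWarp x warpsPerCTA per CTA tile (my reading of the upstream formula, not a verbatim copy):

#include <cstdint>
#include <cstdio>
#include <vector>

// Per-thread element count for a blocked layout: each dimension contributes
// sizePerThread elements per CTA tile the thread participates in.
uint64_t blockedElemsPerThread(const std::vector<int64_t> &shape,
                               const std::vector<int64_t> &sizePerThread,
                               const std::vector<int64_t> &threadsPerWarp,
                               const std::vector<int64_t> &warpsPerCTA) {
  uint64_t elems = 1;
  for (size_t d = 0; d < shape.size(); ++d) {
    int64_t shapePerCTA = sizePerThread[d] * threadsPerWarp[d] * warpsPerCTA[d];
    int64_t tiles = (shape[d] + shapePerCTA - 1) / shapePerCTA; // ceil-div
    elems *= tiles * sizePerThread[d];
  }
  return elems;
}

int main() {
  // 128x128 tensor, sizePerThread=[1,4], threadsPerWarp=[8,4], warps=[4,1]:
  // dim0: ceil(128/32)=4 tiles * 1; dim1: ceil(128/16)=8 tiles * 4 -> 128.
  std::printf("%llu\n", (unsigned long long)blockedElemsPerThread(
                            {128, 128}, {1, 4}, {8, 4}, {4, 1}));
}
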
@@ -330,7 +331,8 @@ SliceEncodingAttr BlockedEncodingAttr::squeeze(int axis) {
  return SliceEncodingAttr::get(getContext(), axis, *this);
}

unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned BlockedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                                Type eltTy) const {
  size_t rank = shape.size();
  auto sizePerThread = getSizePerThread();
  auto warpsPerCTA = getWarpsPerCTA();
@@ -365,12 +367,14 @@ SliceEncodingAttr::paddedShape<unsigned>(ArrayRef<unsigned> shape) const;
template SmallVector<int64_t>
SliceEncodingAttr::paddedShape<int64_t>(ArrayRef<int64_t> shape) const;

unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned SliceEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                              Type eltTy) const {
  auto parent = getParent();
  return ::getElemsPerThread(parent, paddedShape(shape));
  return ::getElemsPerThread(parent, paddedShape(shape), eltTy);
}

unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                            Type eltTy) const {
  size_t rank = shape.size();
  assert(rank == 2 && "Unexpected rank of mma layout");
  assert((isVolta() || isAmpere()) && "Only version 1 and 2 is supported");
@@ -401,18 +405,99 @@ unsigned MmaEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  return res;
}

unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  // TODO:
  assert(0 && "SharedEncodingAttr::getElemsPerThread not implemented");
unsigned SharedEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                               Type eltTy) const {
  llvm_unreachable("Unexpected shared layout");
  return 0;
}

unsigned
DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape) const {
  if (auto blockedLayout = getParent().dyn_cast<BlockedEncodingAttr>()) {
    return blockedLayout.getElemsPerThread(shape);
unsigned DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
                                                   Type eltTy) const {
  if (auto mmaParent = getParent().dyn_cast<MmaEncodingAttr>()) {
    int warpsPerCTAM = mmaParent.getWarpsPerCTA()[0];
    int warpsPerCTAN = mmaParent.getWarpsPerCTA()[1];
    // A100
    if (mmaParent.isAmpere()) {
      int bitwidth = eltTy.getIntOrFloatBitWidth();
      int shapePerWarpM = 16;
      int shapePerWarpN = 8;
      int shapePerWarpK = 4 * 64 / bitwidth;
      int shapePerCTAM = shapePerWarpM * warpsPerCTAM;
      int shapePerCTAN = shapePerWarpN * warpsPerCTAN;

      if (getOpIdx() == 0) {
        int repM = std::max<int>(1, shape[0] / shapePerCTAM);
        int repK = std::max<int>(1, shape[1] / shapePerWarpK);
        return 4 * repM * repK;
      }
      if (getOpIdx() == 1) {
        int repN = std::max<int>(1, shape[1] / shapePerCTAN);
        int repK = std::max<int>(1, shape[0] / shapePerWarpK);
        return 4 * std::max(repN / 2, 1) * repK;
      }
    }
    // V100
    if (mmaParent.isVolta()) {
      bool isRow = getMMAv1IsRow();
      bool isVec4 = getMMAv1IsVec4();
      if (getOpIdx() == 0) {
        int packSizeM = (isRow || isVec4) ? 1 : 2;
        int repM = 2 * packSizeM;
        int spwM = 2 * 4 * repM;
        int numM = getMMAv1NumOuter(shape);
        int NK = shape[1];
        int vec = 2 * repM;
        // Here we mimic the logic in loadA, the result cannot be calculated
        // directly.
        llvm::DenseSet<std::pair<int, int>> visited;
        auto ld = [&](int m, int k) {
          visited.insert({m, k});
          if (vec > 4) {
            if (isRow)
              visited.insert({m, k + 4});
            else
              visited.insert({m + 1, k});
          }
        };
        for (unsigned k = 0; k < NK; k += 4)
          for (unsigned m = 0; m < numM / 2; ++m)
            if (!visited.count({m, k}))
              ld(m, k);
        return visited.size() * 2;
      }
      if (getOpIdx() == 1) {
        int packSizeN = (isRow && !isVec4) ? 2 : 1;
        int repN = 2 * packSizeN;
        int spwN = 2 * 4 * repN;
        int numN = getMMAv1NumOuter(shape);
        int vec = 2 * repN;

        int NK = shape[0];
        // Here we mimic the logic in loadA, the result cannot be calculated
        // directly.
        llvm::DenseSet<std::pair<int, int>> visited;
        int elemsPerLd = vec > 4 ? 4 : 2;
        auto ld = [&](int n, int k) {
          visited.insert({n, k});
          if (vec > 4) {
            if (isRow)
              visited.insert({n + 1, k});
            else
              visited.insert({n, k + 4});
          }
        };

        for (unsigned k = 0; k < NK; k += 4)
          for (unsigned n = 0; n < numN / 2; ++n) {
            if (!visited.count({n, k}))
              ld(n, k);
          }

        return visited.size() * 2;
      }
    }
  }
  assert(0 && "DotOperandEncodingAttr::getElemsPerThread not implemented");
  llvm_unreachable("unknown mma version");
  return 0;
}

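
For the Ampere path above, the count reduces to closed-form arithmetic: each warp-level mma tile contributes 4 registers per operand, replicated over the M (or N) and K repetitions. A standalone check of that arithmetic for operand A (a sketch of my reading of the formula in the hunk, not Triton code):

#include <algorithm>
#include <cassert>

// repM = max(1, M / (16 * warpsM)), repK = max(1, K / (4*64/bitwidth)).
int ampereOperandAElems(int M, int K, int bitwidth, int warpsM) {
  int shapePerWarpK = 4 * 64 / bitwidth;
  int repM = std::max(1, M / (16 * warpsM));
  int repK = std::max(1, K / shapePerWarpK);
  return 4 * repM * repK;
}

int main() {
  // 128x64 fp16 operand A with 4 warps along M:
  // repM = 128/(16*4) = 2, repK = 64/(4*64/16) = 4 -> 4*2*4 = 32 elements.
  assert(ampereOperandAElems(128, 64, 16, 4) == 32);
}
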
@@ -630,26 +715,69 @@ Attribute DotOperandEncodingAttr::parse(AsmParser &parser, Type type) {
    return {};
  unsigned opIdx = attrs.get("opIdx").cast<IntegerAttr>().getInt();
  Attribute parent = attrs.get("parent");
  Attribute isMMAv1Row;
  if (parent.isa<MmaEncodingAttr>() &&
      parent.cast<MmaEncodingAttr>().isVolta()) {
    isMMAv1Row = attrs.get("isMMAv1Row");
    if (!isMMAv1Row)
      llvm::report_fatal_error("isMMAv1Row attribute is missing");
  }
  return parser.getChecked<DotOperandEncodingAttr>(parser.getContext(), opIdx,
                                                   parent, isMMAv1Row);
                                                   parent);
}

void DotOperandEncodingAttr::print(mlir::AsmPrinter &printer) const {
  printer << "<{"
          << "opIdx = " << getOpIdx() << ", "
          << "parent = " << getParent();
  if (getIsMMAv1Row())
    printer << ", isMMAv1Row = " << getIsMMAv1Row();
  printer << "}>";
}

bool DotOperandEncodingAttr::getMMAv1IsRow() const {
  auto [isARow, isBRow, _0, _1, _2] =
      getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
  return getOpIdx() == 0 ? isARow : isBRow;
}

bool DotOperandEncodingAttr::getMMAv1IsVec4() const {
  auto [_0, _1, isAVec4, isBVec4, _2] =
      getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
  return getOpIdx() == 0 ? isAVec4 : isBVec4;
}

SmallVector<int> DotOperandEncodingAttr::getMMAv1Rep() const {
  auto [isARow, isBRow, isAVec4, isBVec4, _] =
      getParent().cast<MmaEncodingAttr>().decodeVoltaLayoutStates();
  // A
  if (getOpIdx() == 0) {
    int packSize = (isARow || isAVec4) ? 1 : 2;
    return {2 * packSize, 0, 1};
  }
  // B
  else {
    int packSize = (isBRow && !isBVec4) ? 2 : 1;
    return {0, 2 * packSize, 1};
  }
}

SmallVector<int> DotOperandEncodingAttr::getMMAv1ShapePerWarp() const {
  auto rep = getMMAv1Rep();
  if (getOpIdx() == 0) {
    return {8 * rep[0], 0, 1};
  } else {
    return {0, 8 * rep[1], 1};
  }
}

int DotOperandEncodingAttr::getMMAv1Vec() const {
  size_t opIdx = getOpIdx();
  return 2 * getMMAv1Rep()[opIdx];
}

int DotOperandEncodingAttr::getMMAv1NumOuter(ArrayRef<int64_t> shape) const {
  auto spw = getMMAv1ShapePerWarp();
  auto rep = getMMAv1Rep();
  auto warpsPerCTA = getParent().cast<MmaEncodingAttr>().getWarpsPerCTA();
  if (getOpIdx() == 0) {
    return rep[0] * shape[0] / (spw[0] * warpsPerCTA[0]);
  } else {
    return rep[1] * shape[1] / (spw[1] * warpsPerCTA[1]);
  }
}

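
These helpers recompute, from the parent MMA layout alone, what the removed isMMAv1Row attribute used to store in the attribute syntax. Plugging numbers in shows how they compose; a standalone trace for operand A, assuming a row-major A that is not vec4 (my arithmetic from the code above, not a documented example):

#include <cassert>

int main() {
  // Operand A, isARow = true, isAVec4 = false, so packSize = 1:
  int packSize = 1;
  int rep0 = 2 * packSize; // getMMAv1Rep()[0] == 2
  int spw0 = 8 * rep0;     // getMMAv1ShapePerWarp()[0] == 16
  int vec = 2 * rep0;      // getMMAv1Vec() == 4

  // getMMAv1NumOuter for shape[0] == 128 with warpsPerCTA[0] == 4:
  int numOuter = rep0 * 128 / (spw0 * 4); // 2*128/(16*4) == 4
  assert(rep0 == 2 && spw0 == 16 && vec == 4 && numOuter == 4);
}
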
//===----------------------------------------------------------------------===//
// InsertSliceAsyncOp
//===----------------------------------------------------------------------===//
@@ -851,7 +979,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
    return mlir::success();
  }
  // cvt(extract_slice(x), type2) -> extract_slice(cvt(x, type2))
  auto extract_slice = dyn_cast<tensor::ExtractSliceOp>(arg);
  auto extract_slice = dyn_cast<triton::gpu::ExtractSliceOp>(arg);
  if (extract_slice) {
    if (!isSharedEncoding(op->getResult(0))) {
      return mlir::failure();
@@ -872,7 +1000,7 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,
    rewriter.setInsertionPoint(extract_slice);
    auto newArg = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), newType, extract_slice.getSource());
    rewriter.replaceOpWithNewOp<tensor::ExtractSliceOp>(
    rewriter.replaceOpWithNewOp<triton::gpu::ExtractSliceOp>(
        op, resType, newArg.getResult(), extract_slice.offsets(),
        extract_slice.sizes(), extract_slice.strides(),
        extract_slice.static_offsets(), extract_slice.static_sizes(),
@@ -925,6 +1053,29 @@ LogicalResult ConvertLayoutOp::canonicalize(ConvertLayoutOp op,

//===----------------------------------------------------------------------===//

/// Build an ExtractSliceOp with mixed static and dynamic entries and custom
/// result type. If the type passed is nullptr, it is inferred.
void ExtractSliceOp::build(OpBuilder &b, OperationState &result,
                           RankedTensorType resultType, Value source,
                           ArrayRef<OpFoldResult> offsets,
                           ArrayRef<OpFoldResult> sizes,
                           ArrayRef<OpFoldResult> strides,
                           ArrayRef<NamedAttribute> attrs) {
  SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
  SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
  dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets);
  dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes);
  dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides);
  auto sourceRankedTensorType = source.getType().cast<RankedTensorType>();
  build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
        dynamicStrides, b.getDenseI64ArrayAttr(staticOffsets),
        b.getDenseI64ArrayAttr(staticSizes),
        b.getDenseI64ArrayAttr(staticStrides));
  result.addAttributes(attrs);
}

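
This builder lets callers mix Attribute and Value entries in one OpFoldResult list; dispatchIndexOpFoldResults splits them into the static arrays and dynamic operands the op actually stores. The pipeliner hunks later in this commit use it exactly this way; a sketch of the call shape (int_attr is the small helper those hunks assume, and sliceType/src/shape come from context):

// All-static slice: offsets/sizes/strides are attributes, so no dynamic
// operands are created.
Value slice = builder.create<triton::gpu::ExtractSliceOp>(
    loc, sliceType, src,
    SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
    SmallVector<OpFoldResult>{int_attr(1), int_attr(shape[0]),
                              int_attr(shape[1])},
    SmallVector<OpFoldResult>{int_attr(1), int_attr(1), int_attr(1)});
// A dynamic index can be dropped in without changing the call site:
//   offsets[0] = extractSliceIndex; // a Value instead of an Attribute
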
//===----------------------------------------------------------------------===//

void TritonGPUDialect::initialize() {
  addAttributes<
#define GET_ATTRDEF_LIST

@@ -8,6 +8,7 @@
using namespace mlir;
namespace {
using triton::DotOp;
using triton::gpu::BlockedEncodingAttr;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
@@ -43,12 +44,6 @@ SmallVector<int64_t, 2> mmaVersionToShapePerWarp(int version) {
  }
}

SmallVector<unsigned, 2> warpsPerTileV1(const ArrayRef<int64_t> shape,
                                        int numWarps) {
  // Set a default value that ensures product of wpt equals numWarps
  return {static_cast<unsigned>(numWarps), 1};
}

SmallVector<unsigned, 2> warpsPerTileV2(triton::DotOp dotOp,
                                        const ArrayRef<int64_t> shape,
                                        int numWarps) {
@@ -92,19 +87,6 @@ public:
      : mlir::RewritePattern(triton::DotOp::getOperationName(), 2, context),
        computeCapability(computeCapability) {}

  static SmallVector<unsigned, 2> getWarpsPerTile(triton::DotOp dotOp,
                                                  const ArrayRef<int64_t> shape,
                                                  int version, int numWarps) {
    switch (version) {
    case 1:
      return warpsPerTileV1(shape, numWarps);
    case 2:
      return warpsPerTileV2(dotOp, shape, numWarps);
    default:
      llvm_unreachable("unsupported MMA version");
    }
  }

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
@@ -125,13 +107,46 @@ public:
    auto mod = op->getParentOfType<mlir::ModuleOp>();
    int numWarps = triton::gpu::TritonGPUDialect::getNumWarps(mod);

    auto warpsPerTile =
        getWarpsPerTile(dotOp, retShape, versionMajor, numWarps);
    // operands
    Value a = dotOp.getA();
    Value b = dotOp.getB();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();

    triton::gpu::MmaEncodingAttr mmaEnc;
    if (versionMajor == 1) {
      SetVector<Operation *> aBwdSlices, bBwdSlices;
      auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
      getBackwardSlice(a, &aBwdSlices, isCvt);
      getBackwardSlice(b, &bBwdSlices, isCvt);
      // get the source of the first conversion found in slices
      auto getCvtArgOrder = [](Operation *op) {
        return cast<ConvertLayoutOp>(op)
            .getOperand()
            .getType()
            .cast<RankedTensorType>()
            .getEncoding()
            .cast<BlockedEncodingAttr>()
            .getOrder();
      };
      bool isARow = true;
      bool isBRow = true;
      Operation *aOp = a.getDefiningOp();
      Operation *bOp = b.getDefiningOp();
      if (!aBwdSlices.empty())
        aOp = aBwdSlices[0];
      if (!bBwdSlices.empty())
        bOp = bBwdSlices[0];
      if (aOp)
        isARow = getCvtArgOrder(aOp)[0] == 1;
      if (bOp)
        isBRow = getCvtArgOrder(bOp)[0] == 1;

      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, numWarps, mmaV1Counter++);
          oldRetType.getContext(), versionMajor, numWarps, oldAType.getShape(),
          oldBType.getShape(), retShape, isARow, isBRow, mmaV1Counter++);
    } else if (versionMajor == 2) {
      auto warpsPerTile = warpsPerTileV2(dotOp, retShape, numWarps);
      mmaEnc = triton::gpu::MmaEncodingAttr::get(
          oldRetType.getContext(), versionMajor, 0 /*versionMinor*/,
          warpsPerTile);
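
The MMAv1 branch above derives a row/column-major hint for each dot operand by walking backward to the first layout conversion feeding it and reading the order of that conversion's blocked source. A condensed sketch of that probe, using the same MLIR utilities the hunk itself calls but with a null-safe cast added; names and the row-major default are assumptions taken from the pattern above:

// Report whether the fastest-varying dimension of the blocked layout feeding
// `v` (through the nearest ConvertLayoutOp, if any) is the row axis.
static bool operandIsRowMajor(Value v) {
  SetVector<Operation *> slice;
  auto isCvt = [](Operation *op) { return isa<ConvertLayoutOp>(op); };
  getBackwardSlice(v, &slice, isCvt);
  Operation *def = slice.empty() ? v.getDefiningOp() : slice[0];
  auto cvt = dyn_cast_or_null<ConvertLayoutOp>(def);
  if (!cvt)
    return true; // the pattern above defaults to row-major
  auto order = cvt.getOperand()
                   .getType()
                   .cast<RankedTensorType>()
                   .getEncoding()
                   .cast<BlockedEncodingAttr>()
                   .getOrder();
  return order[0] == 1;
}
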
@@ -145,10 +160,6 @@ public:
    auto oldAcc = dotOp.getOperand(2);
    auto newAcc = rewriter.create<triton::gpu::ConvertLayoutOp>(
        oldAcc.getLoc(), newRetType, oldAcc);
    Value a = dotOp.getA();
    Value b = dotOp.getB();
    auto oldAType = a.getType().cast<RankedTensorType>();
    auto oldBType = b.getType().cast<RankedTensorType>();
    auto oldAOrder = oldAType.getEncoding()
                         .cast<triton::gpu::DotOperandEncodingAttr>()
                         .getParent()
@@ -159,21 +170,15 @@ public:
                         .getParent()
                         .cast<triton::gpu::BlockedEncodingAttr>()
                         .getOrder();
    Attribute isMMAv1RowA;
    Attribute isMMAv1RowB;
    if (versionMajor == 1) {
      isMMAv1RowA = BoolAttr::get(getContext(), oldAOrder[0] == 1);
      isMMAv1RowB = BoolAttr::get(getContext(), oldBOrder[0] == 1);
    }

    auto newAType = RankedTensorType::get(
        oldAType.getShape(), oldAType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldAType.getContext(), 0, newRetType.getEncoding(), isMMAv1RowA));
        triton::gpu::DotOperandEncodingAttr::get(oldAType.getContext(), 0,
                                                 newRetType.getEncoding()));
    auto newBType = RankedTensorType::get(
        oldBType.getShape(), oldBType.getElementType(),
        triton::gpu::DotOperandEncodingAttr::get(
            oldBType.getContext(), 1, newRetType.getEncoding(), isMMAv1RowB));
        triton::gpu::DotOperandEncodingAttr::get(oldBType.getContext(), 1,
                                                 newRetType.getEncoding()));

    a = rewriter.create<triton::gpu::ConvertLayoutOp>(a.getLoc(), newAType, a);
    b = rewriter.create<triton::gpu::ConvertLayoutOp>(b.getLoc(), newBType, b);

@@ -2,13 +2,12 @@ add_mlir_dialect_library(TritonGPUTransforms
  AccelerateMatmul.cpp
  Coalesce.cpp
  DecomposeConversions.cpp
  FuseTranspositions.cpp
  OptimizeDotOperands.cpp
  Pipeline.cpp
  Prefetch.cpp
  RemoveLayoutConversions.cpp
  ReorderInstructions.cpp
  TritonGPUConversion.cpp
  UpdateMmaForVolta.cpp
  Utility.cpp

  DEPENDS

@@ -16,54 +16,6 @@ using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SliceEncodingAttr;

class OptimizeConvertToDotOperand : public mlir::RewritePattern {
public:
  explicit OptimizeConvertToDotOperand(mlir::MLIRContext *context)
      : RewritePattern(triton::gpu::ConvertLayoutOp::getOperationName(), 1,
                       context) {}

  mlir::LogicalResult
  matchAndRewrite(mlir::Operation *op,
                  mlir::PatternRewriter &rewriter) const override {
    auto cvt = cast<triton::gpu::ConvertLayoutOp>(op);
    auto srcType = cvt.getOperand().getType().cast<RankedTensorType>();
    auto dstType = cvt.getResult().getType().cast<RankedTensorType>();
    // order
    ArrayRef<unsigned> order;
    if (auto srcBlockedLayout =
            srcType.getEncoding().dyn_cast<triton::gpu::BlockedEncodingAttr>())
      order = srcBlockedLayout.getOrder();
    else if (auto srcSharedLayout =
                 srcType.getEncoding()
                     .dyn_cast<triton::gpu::SharedEncodingAttr>())
      order = srcSharedLayout.getOrder();
    else
      return failure();
    // dot operand output
    auto dstDotOperandLayout =
        dstType.getEncoding().dyn_cast<triton::gpu::DotOperandEncodingAttr>();
    if (!dstDotOperandLayout)
      return failure();
    if (!dstDotOperandLayout.getIsMMAv1Row())
      return failure();
    bool isMMAv1Row =
        dstDotOperandLayout.getIsMMAv1Row().cast<BoolAttr>().getValue();
    if ((order[0] == 1 && isMMAv1Row) || (order[0] == 0 && !isMMAv1Row))
      return failure();

    auto newIsRow = BoolAttr::get(op->getContext(), !isMMAv1Row);
    auto newDstEncoding = triton::gpu::DotOperandEncodingAttr::get(
        op->getContext(), dstDotOperandLayout.getOpIdx(),
        dstDotOperandLayout.getParent(), newIsRow);
    auto newDstType = RankedTensorType::get(
        dstType.getShape(), dstType.getElementType(), newDstEncoding);
    auto newCvt = rewriter.create<triton::gpu::ConvertLayoutOp>(
        op->getLoc(), newDstType, cvt.getOperand());
    rewriter.replaceOp(op, newCvt.getResult());
    return success();
  }
};

// convert(trans(convert(arg)))
// x = convert_layout arg: #distributed -> #shared_x
// y = trans x: #shared_x -> #shared_y
@@ -125,10 +77,11 @@ public:
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

class TritonGPUFuseTranspositionsPass
    : public TritonGPUFuseTranspositionsBase<TritonGPUFuseTranspositionsPass> {
class TritonGPUOptimizeDotOperandsPass
    : public TritonGPUOptimizeDotOperandsBase<
          TritonGPUOptimizeDotOperandsPass> {
public:
  TritonGPUFuseTranspositionsPass() = default;
  TritonGPUOptimizeDotOperandsPass() = default;

  void runOnOperation() override {
    MLIRContext *context = &getContext();
@@ -139,7 +92,6 @@ public:
    auto ret = pm.run(m);

    mlir::RewritePatternSet patterns(context);
    patterns.add<OptimizeConvertToDotOperand>(context);
    patterns.add<ConvertTransConvert>(context);
    if (applyPatternsAndFoldGreedily(m, std::move(patterns)).failed())
      signalPassFailure();
@@ -148,6 +100,6 @@ public:
  }
};

std::unique_ptr<Pass> mlir::createTritonGPUFuseTranspositionsPass() {
  return std::make_unique<TritonGPUFuseTranspositionsPass>();
std::unique_ptr<Pass> mlir::createTritonGPUOptimizeDotOperandsPass() {
  return std::make_unique<TritonGPUOptimizeDotOperandsPass>();
}
@@ -1,6 +1,8 @@
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/AxisInfo.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -173,6 +175,8 @@ LogicalResult LoopPipeliner::initialize() {
    if (auto loadOp = dyn_cast<triton::LoadOp>(&op)) {
      auto ptr = loadOp.getPtr();
      unsigned vec = axisInfoAnalysis->getPtrContiguity(ptr);
      if (auto mask = loadOp.getMask())
        vec = std::min<unsigned>(vec, axisInfoAnalysis->getMaskAlignment(mask));
      auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>();
      if (!tensorTy)
        continue;
@@ -394,7 +398,7 @@ void LoopPipeliner::emitPrologue() {
      sliceType = RankedTensorType::get({bufferShape[1], bufferShape[2]},
                                        sliceType.getElementType(),
                                        loadsBufferType[loadOp].getEncoding());
      Value extractSlice = builder.create<tensor::ExtractSliceOp>(
      Value extractSlice = builder.create<triton::gpu::ExtractSliceOp>(
          loadOp.getLoc(), sliceType, loadStageBuffer[loadOp][numStages - 1],
          SmallVector<OpFoldResult>{int_attr(0), int_attr(0), int_attr(0)},
          SmallVector<OpFoldResult>{int_attr(1),
@@ -480,15 +484,25 @@ scf::ForOp LoopPipeliner::createNewForOp() {

  // 3. replace loads with block args (from prologue)
  for (size_t idx = 0; idx < loads.size(); ++idx) {
    OpBuilder::InsertionGuard guard(builder);
    Value load = loads[idx];
    assert(load.hasOneUse() &&
           "we assume that this load has one use (ConvertLayout)");
    Value loadUse = load.getUsers().begin()->getResult(0);
    mapping.lookup(loadUse).replaceAllUsesWith(
    // set insertion point
    Value newLoad = mapping.lookup(load);
    Value newLoadUse = mapping.lookup(loadUse);
    builder.setInsertionPoint(newLoadUse.getDefiningOp());
    // create conversion
    auto cvt = builder.create<ttg::ConvertLayoutOp>(
        loadUse.getLoc(), loadUse.getType(),
        newForOp.getRegionIterArgs()[loadIdx + idx]);

    // replace uses
    newLoadUse.replaceAllUsesWith(cvt.getResult());
    // delete old load and layout conversion
    mapping.lookup(loadUse).getDefiningOp()->erase();
    mapping.lookup(load).getDefiningOp()->erase();
    newLoadUse.getDefiningOp()->erase();
    newLoad.getDefiningOp()->erase();
  }

  // 4. prefetch the next iteration
@@ -532,8 +546,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
  Value extractSliceIndex = builder.create<arith::RemSIOp>(
      nextIV.getLoc(), loopIterIdx,
      builder.create<arith::ConstantIntOp>(nextIV.getLoc(), numStages, 32));
  extractSliceIndex = builder.create<arith::IndexCastOp>(
      extractSliceIndex.getLoc(), builder.getIndexType(), extractSliceIndex);

  for (Operation *op : orderedDeps)
    if (!loads.contains(op->getResult(0))) {
@@ -591,7 +603,7 @@ scf::ForOp LoopPipeliner::createNewForOp() {
          sliceType.getElementType(),
          loadsBufferType[loadOp].getEncoding());

      nextOp = builder.create<tensor::ExtractSliceOp>(
      nextOp = builder.create<triton::gpu::ExtractSliceOp>(
          op->getLoc(), sliceType, insertAsyncOp,
          SmallVector<OpFoldResult>{extractSliceIndex, int_attr(0),
                                    int_attr(0)},
@@ -618,35 +630,6 @@ scf::ForOp LoopPipeliner::createNewForOp() {
    }
  }

  {
    OpBuilder::InsertionGuard guard(builder);
    for (Operation &op : *newForOp.getBody()) {
      if (auto dotOp = llvm::dyn_cast<triton::DotOp>(&op)) {
        builder.setInsertionPoint(&op);
        auto dotType = dotOp.getType().cast<RankedTensorType>();
        Value a = dotOp.getA();
        Value b = dotOp.getB();
        auto layoutCast = [&](Value dotOperand, int opIdx) -> Value {
          auto tensorType = dotOperand.getType().cast<RankedTensorType>();
          if (!tensorType.getEncoding().isa<ttg::DotOperandEncodingAttr>()) {
            auto newEncoding = ttg::DotOperandEncodingAttr::get(
                tensorType.getContext(), opIdx, dotType.getEncoding());
            auto newType =
                RankedTensorType::get(tensorType.getShape(),
                                      tensorType.getElementType(), newEncoding);
            return builder.create<ttg::ConvertLayoutOp>(dotOperand.getLoc(),
                                                        newType, dotOperand);
          }
          return dotOperand;
        };
        a = layoutCast(a, 0);
        b = layoutCast(b, 1);
        dotOp->setOperand(0, a);
        dotOp->setOperand(1, b);
      }
    }
  }

  // async.wait & extract_slice
  Operation *asyncWait = builder.create<ttg::AsyncWaitOp>(
      loads[0].getLoc(), loads.size() * (numStages - 2));
@@ -698,6 +681,17 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
    if (numStages <= 1)
      return;

    // Pre-processing
    // we make sure element-wise ops are done *after* the conversion
    // to dot operands
    // we can achieve this with simple recursive pattern matching
    // MLIRContext *context = &getContext();
    // mlir::RewritePatternSet patterns(context);
    // patterns.add<MoveOpAfterLayoutConversion>(context);
    // auto didPreprocess =
    //     applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));

    // Do the pipelining
    getOperation()->walk([&](scf::ForOp forOp) -> void {
      LoopPipeliner pipeliner(forOp, numStages);

@@ -707,7 +701,6 @@ struct PipelinePass : public TritonGPUPipelineBase<PipelinePass> {
      pipeliner.emitPrologue();

      scf::ForOp newForOp = pipeliner.createNewForOp();

      pipeliner.emitEpilogue();

      // replace the original loop

@@ -103,11 +103,9 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
  if (offsetK)
    offset[kIdx] = *offsetK;

  Value newSmem = builder.create<tensor::ExtractSliceOp>(
      v.getLoc(),
      // TODO: encoding?
      RankedTensorType::get(shape, elementType, type.getEncoding()), v,
      SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
  Value newSmem = builder.create<triton::gpu::ExtractSliceOp>(
      v.getLoc(), RankedTensorType::get(shape, elementType, type.getEncoding()),
      v, SmallVector<OpFoldResult>{intAttr(offset[0]), intAttr(offset[1])},
      SmallVector<OpFoldResult>{intAttr(shape[0]), intAttr(shape[1])},
      SmallVector<OpFoldResult>{intAttr(1), intAttr(1)});

@@ -108,7 +108,7 @@ public:
    auto newReduce = rewriter.create<triton::ReduceOp>(
        op->getLoc(), reduce.getRedOp(), reduceArg.getOperand(),
        reduce.getAxis());
    if (isa<triton::gpu::ConvertLayoutOp>(
    if (isa_and_nonnull<triton::gpu::ConvertLayoutOp>(
            *reduceArg.getOperand().getDefiningOp()))
      return mlir::failure();
    Value newRet = newReduce.getResult();
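
getDefiningOp() returns null when a value is a block argument rather than an op result, and plain isa<>/dyn_cast<> assert on null, hence the switch to the null-tolerant variants here and in the dyn_cast_or_null hunk further down. The usual pointer-based shape of the guard (a generic sketch, not this pattern's exact code):

// Safe: block arguments yield a null defining op, which both helpers accept.
Operation *def = value.getDefiningOp();
if (auto cvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(def)) {
  // ... value is produced by a layout conversion ...
} else if (isa_and_nonnull<triton::ReduceOp>(def)) {
  // ... produced by a reduction ...
} // else: block argument or some other producer
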
@@ -146,157 +146,6 @@ public:
//
// -----------------------------------------------------------------------------

// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                             Attribute &ret) {
  ret = targetEncoding;
  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
    ret = triton::gpu::SliceEncodingAttr::get(
        op->getContext(), expand_dims.getAxis(), targetEncoding);
  }
  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
    auto sliceEncoding =
        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
    if (!sliceEncoding)
      return failure();
    if (sliceEncoding.getDim() != reduce.getAxis())
      return failure();
    ret = sliceEncoding.getParent();
  }
  if (auto view = dyn_cast<triton::ViewOp>(op)) {
    return failure();
  }
  return success();
}

inline bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
  // Case 1: A size 1 tensor is not expensive since all threads will load the
  // same
  if (isSingleValue(op->getOperand(0)))
    return false;
  auto ptr = op->getOperand(0);
  // Case 2: We assume that `evict_last` loads/stores have high hit rate
  if (auto load = dyn_cast<triton::LoadOp>(op))
    if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
      return false;
  if (auto store = dyn_cast<triton::StoreOp>(op))
    if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
      return false;
  if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
    auto encoding = tensorTy.getEncoding();
    // Case 3: Different type conversion is expensive (e.g., mma <-> block)
    if (encoding.getTypeID() != targetEncoding.getTypeID())
      return true;
    auto sizePerThread = triton::gpu::getSizePerThread(encoding);
    auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
    auto order = triton::gpu::getOrder(encoding);
    auto targetOrder = triton::gpu::getOrder(targetEncoding);
    // Case 4: The targetEncoding may expose more vectorization opportunities
    return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
  }
  return false;
}

inline bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
  if (!op)
    return true;
  if (isa<triton::LoadOp, triton::StoreOp>(op))
    return expensiveLoadOrStore(op, targetEncoding);
  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
          triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
          triton::AtomicCASOp, triton::DotOp>(op))
    return true;
  if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
          op))
    return true;
  return false;
}

LogicalResult simulateBackwardRematerialization(
    Operation *initOp, SetVector<Operation *> &processed,
    SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
    const Attribute &targetEncoding) {
  // DFS
  std::vector<std::pair<Operation *, Attribute>> queue;
  queue.emplace_back(initOp, targetEncoding);
  // We want to see the effect of converting `initOp` to a new layout
  // so we initialize `numCvts = 1`.
  int numCvts = 1;
  while (!queue.empty()) {
    Operation *currOp;
    Attribute currLayout;
    std::tie(currOp, currLayout) = queue.back();
    queue.pop_back();
    // If the current operation is expensive to rematerialize,
    // we stop everything
    if (expensiveToRemat(currOp, currLayout))
      break;
    // A conversion will be removed here (i.e. transferred to operands)
    numCvts -= 1;
    // Done processing
    processed.insert(currOp);
    layout.insert(currLayout);
    // Add all operands to the queue
    for (Value argI : currOp->getOperands()) {
      Attribute newEncoding;
      // Cannot invert the current encoding for this operand
      // we stop everything
      if (failed(invertEncoding(currLayout, currOp, newEncoding)))
        return mlir::failure();
      if (toConvert.count(argI) && toConvert[argI] != newEncoding)
        return mlir::failure();
      Operation *opArgI = argI.getDefiningOp();
      toConvert.insert({argI, newEncoding});
      // 1. Only convert RankedTensorType
      // 2. Skip if there's no defining op
      // 3. Skip if the defining op has already been processed
      // 4. Skip if the defining op is in a different block
      if (!argI.getType().isa<RankedTensorType>() || !opArgI ||
          processed.contains(opArgI) ||
          opArgI->getBlock() != currOp->getBlock())
        continue;
      // If the conversion can be folded into opArgI then
      // we don't count this conversion as expensive
      if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
              triton::MakeRangeOp, triton::SplatOp>(*opArgI))
        continue;
      if (auto view = dyn_cast<triton::ViewOp>(opArgI))
        continue;

      // We add one expensive conversion for the current operand
      numCvts += 1;
      queue.emplace_back(opArgI, newEncoding);
    }
  }
  // if rematerialization would add more conversions than it removes
  // then we don't do it
  if (numCvts > 0)
    return mlir::failure();
  return mlir::success();
}

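
The removed cost model is worth restating, since the call sites below now compare the (relocated) helper's result against 0 directly: the counter starts at 1 for the conversion we hope to eliminate, every rematerializable op on the backward walk absorbs one conversion, and every operand that cannot fold a conversion re-introduces one; the transform is applied only if the final balance is at most 0. A toy standalone model of that accounting (illustrative, not the IR walk itself):

#include <cassert>
#include <vector>

struct Node {
  bool expensive; // stops the walk (load/store/dot/scf ops, ...)
  int costlyArgs; // operands that would each need a fresh conversion
};

// True if rematerializing the chain removes at least as many layout
// conversions as it introduces (numCvts <= 0 at the end).
bool profitable(const std::vector<Node> &chain) {
  int numCvts = 1; // the conversion we are trying to eliminate
  for (const Node &n : chain) {
    if (n.expensive)
      break;
    numCvts -= 1;            // this op absorbs one conversion
    numCvts += n.costlyArgs; // each costly operand re-introduces one
  }
  return numCvts <= 0;
}

int main() {
  // One cheap op whose operands all fold away: 1 - 1 + 0 = 0 -> profitable.
  assert(profitable({{false, 0}}));
  // One cheap op with two costly operands: 1 - 1 + 2 = 2 -> rejected.
  assert(!profitable({{false, 2}}));
  // An expensive producer stops the walk immediately: numCvts stays 1.
  assert(!profitable({{true, 0}}));
}
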
//

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              IRMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(), argType.getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {
    SmallVector<Type, 1> newTypes;
    auto success = typeInfer.inferReturnTypes(
        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
        newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
    if (succeeded(success))
      newOp->getResult(0).setType(newTypes.front());
  }
  return newOp;
}

// op(cvt(arg_0), arg_1, ..., arg_n)
// -> cvt(op(arg_0, cvt(arg_1), ..., cvt(arg_n)))
void pushConversionForward(triton::gpu::ConvertLayoutOp cvt,
@@ -368,11 +217,11 @@ public:

    IRMapping mapping;
    for (size_t i = 0; i < numOps; i++) {
      auto thenCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
      auto thenCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
          thenYield.getOperand(i).getDefiningOp());
      if (hasElse) {
        auto elseYield = ifOp.elseYield();
        auto elseCvt = dyn_cast<triton::gpu::ConvertLayoutOp>(
        auto elseCvt = dyn_cast_or_null<triton::gpu::ConvertLayoutOp>(
            elseYield.getOperand(i).getDefiningOp());
        if (thenCvt && elseCvt &&
            std::distance(elseCvt->user_begin(), elseCvt->user_end()) == 1 &&
@@ -492,8 +341,8 @@ public:
    SetVector<Attribute> layout;
    llvm::MapVector<Value, Attribute> toConvert;
    if (argOp && (argOp != cvt) && cvtSlices.count(argOp) == 0 &&
        failed(simulateBackwardRematerialization(argOp, processed, layout,
                                                 toConvert, srcEncoding))) {
        simulateBackwardRematerialization(argOp, processed, layout,
                                          toConvert, srcEncoding) > 0) {
      return failure();
    }
  }
@@ -539,50 +388,13 @@ public:
    SetVector<Attribute> layout;
    llvm::MapVector<Value, Attribute> toConvert;
    std::vector<std::pair<Operation *, Attribute>> queue;
    if (failed(simulateBackwardRematerialization(
            cvt, processed, layout, toConvert, targetType.getEncoding())))
    if (simulateBackwardRematerialization(cvt, processed, layout, toConvert,
                                          targetType.getEncoding()) > 0)
      return mlir::failure();

    SmallVector<Value, 4> sortedValues;
    SetVector<Operation *> tmp;
    for (auto &item : toConvert) {
      Value v = item.first;
      if (v.getDefiningOp())
        tmp.insert(v.getDefiningOp());
      else
        sortedValues.push_back(v);
    }
    tmp = mlir::multiRootTopologicalSort(tmp);
    for (Operation *op : tmp)
      sortedValues.push_back(op->getResult(0));

    IRMapping mapping;
    for (Value currOperand : sortedValues) {
      // unpack information
      Attribute targetLayout = toConvert.lookup(currOperand);
      // rematerialize the operand if necessary
      Operation *currOperation = currOperand.getDefiningOp();
      if (processed.contains(currOperation)) {
        Operation *newOperation =
            cloneWithInferType(rewriter, currOperation, mapping);
        newOperation->moveAfter(currOperation);
        currOperation = newOperation;
        currOperand = currOperation->getResult(0);
      }
      // compute target type for the layout cast
      auto currType = currOperand.getType().cast<RankedTensorType>();
      auto newType = RankedTensorType::get(
          currType.getShape(), currType.getElementType(), targetLayout);
      auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
          currOperand.getLoc(), newType, currOperand);
      if (currOperation)
        newOperand->moveAfter(currOperation);
      else {
        Block *block = currOperand.cast<BlockArgument>().getOwner();
        newOperand->moveAfter(block, block->begin());
      }
      mapping.map(currOperand, newOperand);
    }
    rematerializeConversionChain(toConvert, rewriter, processed, mapping);

    rewriter.replaceOp(cvt, mapping.lookup(cvt->getOperand(0)));
    return mlir::success();
  }

@@ -1,212 +0,0 @@
#include "Utility.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
#include "triton/Dialect/TritonGPU/Transforms/Passes.h"
#define GEN_PASS_CLASSES
#include "triton/Dialect/TritonGPU/Transforms/Passes.h.inc"

namespace mlir {
namespace {
using triton::DotOp;
using triton::gpu::ConvertLayoutOp;
using triton::gpu::DotOperandEncodingAttr;
using triton::gpu::MmaEncodingAttr;
using triton::gpu::SharedEncodingAttr;
using triton::gpu::SliceEncodingAttr;

// Get the wpt for MMAv1 using more information.
// Reference the original logic here
// https://github.com/openai/triton/blob/0e4691e6dd91e001a8d33b71badf8b3314325459/lib/codegen/analysis/layout.cc#L223
SmallVector<unsigned> getWarpsPerCTA(ArrayRef<int64_t> shape, bool isARow,
                                     bool isBRow, bool isAVec4, bool isBVec4,
                                     int numWarps) {
  // TODO[Superjomn]: Share code with
  // DotOpMmaV1ConversionHelper::AParam/BParam, since same code to compute the
  // rep,spw and fpw.
  SmallVector<unsigned> wpt({1, 1});
  SmallVector<unsigned> wpt_nm1;

  SmallVector<int, 2> rep(2), spw(2);
  std::array<int, 3> fpw{{2, 2, 1}};
  int packSize0 = (isARow || isAVec4) ? 1 : 2;
  rep[0] = 2 * packSize0;
  spw[0] = fpw[0] * 4 * rep[0];

  int packSize1 = (isBRow && !isBVec4) ? 2 : 1;
  rep[1] = 2 * packSize1;
  spw[1] = fpw[1] * 4 * rep[1];

  do {
    wpt_nm1 = wpt;
    if (wpt[0] * wpt[1] < numWarps)
      wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw[0]);
    if (wpt[0] * wpt[1] < numWarps)
      wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw[1]);
  } while (wpt_nm1 != wpt);

  return wpt;
}

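
The deleted getWarpsPerCTA grows the warp tile greedily: double along M, then along N, each step clamped so the tile never exceeds the shape, until the warp budget is met or the assignment stops changing. The body is self-contained arithmetic, so it runs unchanged outside MLIR; a standalone reproduction with one worked case (LLVM containers swapped for std:: equivalents):

#include <algorithm>
#include <array>
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<unsigned> warpsPerCTA(std::array<int64_t, 2> shape, bool isARow,
                                  bool isBRow, bool isAVec4, bool isBVec4,
                                  int numWarps) {
  std::vector<unsigned> wpt{1, 1}, wpt_nm1;
  std::array<int, 3> fpw{{2, 2, 1}};
  int rep0 = 2 * ((isARow || isAVec4) ? 1 : 2);
  int rep1 = 2 * ((isBRow && !isBVec4) ? 2 : 1);
  int spw0 = fpw[0] * 4 * rep0; // shape covered per warp along M
  int spw1 = fpw[1] * 4 * rep1; // shape covered per warp along N
  do {
    wpt_nm1 = wpt;
    if (wpt[0] * wpt[1] < (unsigned)numWarps)
      wpt[0] = std::clamp<int>(wpt[0] * 2, 1, shape[0] / spw0);
    if (wpt[0] * wpt[1] < (unsigned)numWarps)
      wpt[1] = std::clamp<int>(wpt[1] * 2, 1, shape[1] / spw1);
  } while (wpt_nm1 != wpt);
  return wpt;
}

int main() {
  // 128x128 output, row-major A and B, not vec4, 8 warps:
  // spw = {16, 32}; doubling gives wpt = {4, 2}, and 4*2 == 8 warps.
  auto wpt = warpsPerCTA({128, 128}, true, true, false, false, 8);
  assert(wpt[0] == 4 && wpt[1] == 2);
}
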
// Given a (potentially malformed) DotOp, determines the optimal
|
||||
// MMAEncoding to use on V100
|
||||
LogicalResult getOptimizedV100MMaLayout(triton::DotOp dotOp,
|
||||
MmaEncodingAttr &old,
|
||||
MmaEncodingAttr &ret) {
|
||||
auto *ctx = dotOp->getContext();
|
||||
auto AT = dotOp.getA().getType().cast<RankedTensorType>();
|
||||
auto BT = dotOp.getB().getType().cast<RankedTensorType>();
|
||||
auto DT = dotOp.getD().getType().cast<RankedTensorType>();
|
||||
auto shapeA = AT.getShape();
|
||||
auto shapeB = BT.getShape();
|
||||
if (!DT.getEncoding())
|
||||
return mlir::failure();
|
||||
auto mmaLayout = DT.getEncoding().dyn_cast<MmaEncodingAttr>();
|
||||
if (!(mmaLayout && mmaLayout.isVolta()))
|
||||
return mlir::failure();
|
||||
// We have an MmaEncodingAttr here. Find the correct layout for it.
|
||||
auto dotOperandA = AT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
auto dotOperandB = BT.getEncoding().cast<DotOperandEncodingAttr>();
|
||||
bool isARow = dotOperandA.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
bool isBRow = dotOperandB.getIsMMAv1Row().cast<BoolAttr>().getValue();
|
||||
auto [isARow_, isBRow_, isAVec4_, isBVec4_, mmaId] =
|
||||
mmaLayout.decodeVoltaLayoutStates();
|
||||
bool isAVec4 = !isARow && (shapeA[isARow] <= 16);
|
||||
bool isBVec4 = isBRow && (shapeB[isBRow] <= 16);
|
||||
// The wpt of MMAv1 is also determined by isARow, isBRow and shape, and it
|
||||
// could only be set here for those states might be updated by previous
|
||||
// patterns in the Combine Pass.
|
||||
auto tgtWpt = getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
if (isARow == isARow_ && isBRow == isBRow_ && isAVec4 == isAVec4_ &&
|
||||
isBVec4 == isBVec4_) {
|
||||
if (tgtWpt == mmaLayout.getWarpsPerCTA())
|
||||
return mlir::failure();
|
||||
}
|
||||
// Recalculate the wpt, for here we could get the latest information, the
|
||||
// wpt should be updated.
|
||||
auto updatedWpt =
|
||||
getWarpsPerCTA(DT.getShape(), isARow, isBRow, isAVec4, isBVec4,
|
||||
product(mmaLayout.getWarpsPerCTA()));
|
||||
// return results
|
||||
old = mmaLayout;
|
||||
ret =
|
||||
MmaEncodingAttr::get(ctx, mmaLayout.getVersionMajor(), updatedWpt,
|
||||
AT.getShape(), BT.getShape(), isARow, isBRow, mmaId);
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
// Replace result op type
|
||||
void setOpResultType(Operation *op, ArrayRef<Type> newTypes) {
|
||||
if (op->getNumResults() != newTypes.size())
|
||||
llvm_unreachable("number of types different from number of results");
|
||||
// nothing to do
|
||||
if (op->getNumResults() == 0)
|
||||
return;
|
||||
// replace types
|
||||
for (unsigned i = 0; i < op->getNumResults(); i++) {
|
||||
Type newType = newTypes[i];
|
||||
op->getResult(i).setType(newType);
|
||||
}
|
||||
// special case: arith.constant: we need to change the value attr
|
||||
if (isa<arith::ConstantOp>(op)) {
|
||||
Type newType = newTypes[0];
|
||||
auto attr = op->getAttrDictionary()
|
||||
.get("value")
|
||||
.dyn_cast<mlir::DenseElementsAttr>();
|
||||
if (attr) {
|
||||
auto newAttr =
|
||||
mlir::DenseElementsAttr::getFromRawBuffer(newType, attr.getRawData());
|
||||
op->setAttr("value", newAttr);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// update style type given the provided layoutMap
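// Illustrative example (assumption, not from the original source): if
// layoutMap maps #mma0 -> #mma1, a value of type
//   tensor<128x128xf16, #triton_gpu.slice<{dim = 0, parent = #mma0}>>
// is rewritten to use #triton_gpu.slice<{dim = 0, parent = #mma1}>,
// and likewise for plain mma and dot-operand encodings.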
Type updateStaleType(
    const DenseMap<MmaEncodingAttr, MmaEncodingAttr> &layoutMap,
    RankedTensorType type) {
  auto encoding = type.getEncoding();
  // mma encoding
  if (auto mma = encoding.dyn_cast<MmaEncodingAttr>()) {
    auto newMma = layoutMap.lookup(mma);
    if (!newMma)
      return Type();
    return RankedTensorType::get(type.getShape(), type.getElementType(),
                                 newMma);
  }
  // slice encoding
  else if (auto slice = encoding.dyn_cast<SliceEncodingAttr>()) {
    if (auto mma = slice.getParent().dyn_cast<MmaEncodingAttr>()) {
      auto newMma = layoutMap.lookup(mma);
      if (!newMma)
        return Type();
      auto newSlice =
          SliceEncodingAttr::get(slice.getContext(), slice.getDim(), newMma);
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   newSlice);
    }
  }
  // dot operand encoding
  else if (auto dotOp = encoding.dyn_cast<DotOperandEncodingAttr>()) {
    if (auto mma = dotOp.getParent().dyn_cast<MmaEncodingAttr>()) {
      auto newMma = layoutMap.lookup(mma);
      if (!newMma)
        return Type();
      auto newDotOp = DotOperandEncodingAttr::get(
          dotOp.getContext(), dotOp.getOpIdx(), newMma, dotOp.getIsMMAv1Row());
      return RankedTensorType::get(type.getShape(), type.getElementType(),
                                   newDotOp);
    }
  }
  return Type();
}

} // namespace

class UpdateMmaForVoltaPass
    : public UpdateMmaForVoltaBase<UpdateMmaForVoltaPass> {
public:
  UpdateMmaForVoltaPass() = default;
  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp m = getOperation();
    // Step 1:
    // Build a map from old MMA encoding to new MMA encoding.
    DenseMap<MmaEncodingAttr, MmaEncodingAttr> layoutMap;
    m.walk([&layoutMap](triton::DotOp dotOp) {
      MmaEncodingAttr newLayout;
      MmaEncodingAttr oldLayout;
      if (failed(getOptimizedV100MMaLayout(dotOp, oldLayout, newLayout)))
        return;
      layoutMap[oldLayout] = newLayout;
    });
    // Step 2:
    // Replace all wrong layouts with the right ones.
    m.walk([&layoutMap](Operation *op) {
      if (op->getNumResults() != 1)
        return;
      auto type = op->getResult(0).getType().dyn_cast<RankedTensorType>();
      if (!type)
        return;
      Type newType = updateStaleType(layoutMap, type);
      if (!newType)
        return;
      setOpResultType(op, {newType});
    });
    // Step 3:
    // We may have messed up some loops in the process.
    // Fix them up.
    if (fixupLoops(m).failed())
      signalPassFailure();
  }
};

std::unique_ptr<Pass> createTritonGPUUpdateMmaForVoltaPass() {
  return std::make_unique<UpdateMmaForVoltaPass>();
}

} // namespace mlir
@@ -1,7 +1,10 @@
#include "Utility.h"
#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "triton/Analysis/Utility.h"
#include "triton/Dialect/TritonGPU/IR/Dialect.h"

namespace mlir {

@@ -60,4 +63,199 @@ LogicalResult fixupLoops(ModuleOp mod) {
  return success();
}

// -------------------------------------------------------------------------- //

// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                             Attribute &ret) {
  ret = targetEncoding;
  if (auto expand_dims = dyn_cast<triton::ExpandDimsOp>(op)) {
    ret = triton::gpu::SliceEncodingAttr::get(
        op->getContext(), expand_dims.getAxis(), targetEncoding);
  }
  if (auto reduce = dyn_cast<triton::ReduceOp>(op)) {
    auto sliceEncoding =
        targetEncoding.dyn_cast<triton::gpu::SliceEncodingAttr>();
    if (!sliceEncoding)
      return failure();
    if (sliceEncoding.getDim() != reduce.getAxis())
      return failure();
    ret = sliceEncoding.getParent();
  }
  if (auto view = dyn_cast<triton::ViewOp>(op)) {
    return failure();
  }
  return success();
}
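
// Illustrative reading (assumption, not from the original source): for an
// expand_dims with axis = 1 whose result must carry a target encoding
// #blocked, the operand is required to carry
// #triton_gpu.slice<{dim = 1, parent = #blocked}>; reduce is the inverse
// case, and view is rejected because its encoding cannot be inverted.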

bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding) {
  // Case 1: A size 1 tensor is not expensive since all threads will load the
  // same value.
  if (isSingleValue(op->getOperand(0)))
    return false;
  // auto ptr = op->getOperand(0);
  //// Case 2: We assume that `evict_last` loads/stores have high hit rate
  // if (auto load = dyn_cast<triton::LoadOp>(op))
  //   if (load.getEvict() == triton::EvictionPolicy::EVICT_LAST)
  //     return false;
  // if (auto store = dyn_cast<triton::StoreOp>(op))
  //   if (store.getEvict() == triton::EvictionPolicy::EVICT_LAST)
  //     return false;
  // if (auto tensorTy = ptr.getType().dyn_cast<RankedTensorType>()) {
  //   auto encoding = tensorTy.getEncoding();
  //   // Case 3: Different type conversion is expensive (e.g., mma <-> block)
  //   if (encoding.getTypeID() != targetEncoding.getTypeID())
  //     return true;
  //   auto sizePerThread = triton::gpu::getSizePerThread(encoding);
  //   auto targetSizePerThread = triton::gpu::getSizePerThread(targetEncoding);
  //   auto order = triton::gpu::getOrder(encoding);
  //   auto targetOrder = triton::gpu::getOrder(targetEncoding);
  //   // Case 4: The targetEncoding may expose more vectorization opportunities
  //   return sizePerThread[order[0]] >= targetSizePerThread[targetOrder[0]];
  // }
  return true;
}

bool expensiveToRemat(Operation *op, Attribute &targetEncoding) {
  if (!op)
    return true;
  if (isa<triton::LoadOp, triton::StoreOp>(op))
    return expensiveLoadOrStore(op, targetEncoding);
  if (isa<tensor::ExtractSliceOp, triton::gpu::AllocTensorOp,
          triton::gpu::InsertSliceAsyncOp, triton::AtomicRMWOp,
          triton::AtomicCASOp, triton::DotOp>(op))
    return true;
  if (isa<scf::YieldOp, scf::ForOp, scf::IfOp, scf::WhileOp, scf::ConditionOp>(
          op))
    return true;
  return false;
}

int simulateBackwardRematerialization(
    Operation *initOp, SetVector<Operation *> &processed,
    SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
    const Attribute &targetEncoding) {
  // DFS
  std::vector<std::pair<Operation *, Attribute>> queue;
  queue.emplace_back(initOp, targetEncoding);
  // We want to see the effect of converting `initOp` to a new layout,
  // so we initialize `numCvts = 1`.
  int numCvts = 1;
  while (!queue.empty()) {
    Operation *currOp;
    Attribute currLayout;
    std::tie(currOp, currLayout) = queue.back();
    queue.pop_back();
    // If the current operation is expensive to rematerialize,
    // we stop everything.
    if (expensiveToRemat(currOp, currLayout))
      break;
    // A conversion will be removed here (i.e. transferred to operands)
    numCvts -= 1;
    // Done processing
    processed.insert(currOp);
    layout.insert(currLayout);
    // Add all operands to the queue
    for (Value argI : currOp->getOperands()) {
      Attribute newEncoding;
      // If we cannot invert the current encoding for this operand,
      // we stop everything.
      if (failed(invertEncoding(currLayout, currOp, newEncoding)))
        return INT_MAX;
      if (toConvert.count(argI) && toConvert[argI] != newEncoding)
        return INT_MAX;
      Operation *opArgI = argI.getDefiningOp();
      toConvert.insert({argI, newEncoding});
      // 1. Only convert RankedTensorType
      // 2. Skip if there's no defining op
      // 3. Skip if the defining op has already been processed
      // 4. Skip if the defining op is in a different block
      if (!argI.getType().isa<RankedTensorType>() || !opArgI ||
          processed.contains(opArgI) ||
          opArgI->getBlock() != currOp->getBlock())
        continue;
      // If the conversion can be folded into opArgI then
      // we don't count this conversion as expensive.
      if (isa<triton::gpu::ConvertLayoutOp, arith::ConstantOp,
              triton::MakeRangeOp, triton::SplatOp>(*opArgI))
        continue;
      if (auto view = dyn_cast<triton::ViewOp>(opArgI))
        continue;

      // We add one expensive conversion for the current operand
      numCvts += 1;
      queue.emplace_back(opArgI, newEncoding);
    }
  }
  // return net number of conversions
  return numCvts;
}
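
// Interpretation (assumption, not stated in the original source): a result
// <= 0 means pushing the conversion backward removes at least as many
// layout conversions as it introduces, so rematerialization is profitable;
// INT_MAX flags a chain whose encodings cannot be inverted consistently.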

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              IRMapping &mapping) {
  Operation *newOp = rewriter.clone(*op, mapping);
  auto origType = op->getResult(0).getType().cast<RankedTensorType>();
  auto argType = newOp->getOperand(0).getType().cast<RankedTensorType>();
  auto newType = RankedTensorType::get(
      origType.getShape(), origType.getElementType(), argType.getEncoding());
  newOp->getResult(0).setType(newType);
  auto typeInfer = dyn_cast<InferTypeOpInterface>(newOp);
  if (typeInfer) {
    SmallVector<Type, 1> newTypes;
    auto success = typeInfer.inferReturnTypes(
        newOp->getContext(), newOp->getLoc(), newOp->getOperands(),
        newOp->getAttrDictionary(), newOp->getRegions(), newTypes);
    if (succeeded(success))
      newOp->getResult(0).setType(newTypes.front());
  }
  return newOp;
}

void rematerializeConversionChain(
    const llvm::MapVector<Value, Attribute> &toConvert,
    mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
    IRMapping &mapping) {
  SmallVector<Value, 4> sortedValues;
  SetVector<Operation *> tmp;
  for (auto &item : toConvert) {
    Value v = item.first;
    if (v.getDefiningOp())
      tmp.insert(v.getDefiningOp());
    else
      sortedValues.push_back(v);
  }
  tmp = mlir::multiRootTopologicalSort(tmp);
  for (Operation *op : tmp)
    sortedValues.push_back(op->getResult(0));

  for (Value currOperand : sortedValues) {
    // unpack information
    Attribute targetLayout = toConvert.lookup(currOperand);
    // rematerialize the operand if necessary
    Operation *currOperation = currOperand.getDefiningOp();
    if (processed.contains(currOperation)) {
      Operation *newOperation =
          cloneWithInferType(rewriter, currOperation, mapping);
      newOperation->moveAfter(currOperation);
      currOperation = newOperation;
      currOperand = currOperation->getResult(0);
    }
    // compute target type for the layout cast
    auto currType = currOperand.getType().cast<RankedTensorType>();
    auto newType = RankedTensorType::get(
        currType.getShape(), currType.getElementType(), targetLayout);
    auto newOperand = rewriter.create<triton::gpu::ConvertLayoutOp>(
        currOperand.getLoc(), newType, currOperand);
    if (currOperation)
      newOperand->moveAfter(currOperation);
    else {
      Block *block = currOperand.cast<BlockArgument>().getOwner();
      newOperand->moveAfter(block, block->begin());
    }
    mapping.map(currOperand, newOperand);
  }
}

} // namespace mlir

@@ -2,11 +2,32 @@
#define TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/ADT/MapVector.h"

namespace mlir {

LogicalResult fixupLoops(ModuleOp mod);

// TODO: Interface
LogicalResult invertEncoding(Attribute targetEncoding, Operation *op,
                             Attribute &ret);

bool expensiveLoadOrStore(Operation *op, Attribute &targetEncoding);

bool expensiveToRemat(Operation *op, Attribute &targetEncoding);

int simulateBackwardRematerialization(
    Operation *initOp, SetVector<Operation *> &processed,
    SetVector<Attribute> &layout, llvm::MapVector<Value, Attribute> &toConvert,
    const Attribute &targetEncoding);

Operation *cloneWithInferType(mlir::PatternRewriter &rewriter, Operation *op,
                              IRMapping &mapping);

void rematerializeConversionChain(
    const llvm::MapVector<Value, Attribute> &toConvert,
    mlir::PatternRewriter &rewriter, SetVector<Operation *> &processed,
    IRMapping &mapping);
} // namespace mlir

#endif // TRITON_LIB_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_
@@ -13,6 +13,7 @@
#include "mlir/Target/LLVMIR/Export.h"
#include "mlir/Target/LLVMIR/LLVMTranslationInterface.h"
#include "mlir/Transforms/Passes.h"
#include "triton/Conversion/TritonGPUToLLVM/ArithToIndexPass.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
#include "triton/Tools/Sys/GetEnv.hpp"
#include "llvm/IR/CallingConv.h"
@@ -309,12 +310,14 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
      /*printAfterOnlyOnFailure*/ false, llvm::dbgs(), printingFlags);

  pm.addPass(mlir::createConvertSCFToCFPass());
  pm.addPass(createTritonConvertArithToIndexPass());
  pm.addPass(mlir::createConvertIndexToLLVMPass());
  pm.addPass(createConvertTritonGPUToLLVMPass(computeCapability));
  // Canonicalize to eliminate the remaining UnrealizedConversionCastOp
  pm.addPass(mlir::createArithToLLVMConversionPass());
  pm.addPass(mlir::createCanonicalizerPass());
  pm.addPass(mlir::createCSEPass()); // Simplify the IR to improve readability.
  // Simplify the IR
  pm.addPass(mlir::createCSEPass());
  pm.addPass(mlir::createSymbolDCEPass());
  pm.addPass(mlir::createCanonicalizerPass());
#ifdef USE_ROCM
  pm.addPass(mlir::createConvertSCFToCFPass());
  pm.addPass(createConvertControlFlowToLLVMPass());
@@ -325,6 +328,7 @@ translateTritonGPUToLLVMIR(llvm::LLVMContext *llvmContext,
    return nullptr;
  }

  // llvm::outs() << module << "\n";
  auto llvmIR = translateLLVMToLLVMIR(llvmContext, module);
  if (!llvmIR) {
    llvm::errs() << "Translate to LLVM IR failed";

@@ -50,7 +50,6 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  int ptxMajor = maxPTX / 10;
  int ptxMinor = maxPTX % 10;
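  // e.g. maxPTX == 81 gives ptxMajor == 8 and ptxMinor == 1, so the header
  // emitted by the backend is later rewritten to ".version 8.1".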
  // create
  llvm::SmallVector<char, 0> buffer;
  std::string triple = "nvptx64-nvidia-cuda";
  std::string proc = "sm_" + std::to_string(maxCC);
  std::string layout = "";
@@ -82,17 +81,19 @@ std::string translateLLVMIRToPTX(llvm::Module &module, int cc, int version) {
  else
    module.setDataLayout(layout);
  // emit machine code
  for (llvm::Function &f : module.functions())
    f.addFnAttr(llvm::Attribute::AlwaysInline);
  llvm::legacy::PassManager pass;
  llvm::raw_svector_ostream stream(buffer);
  // emit
  machine->addPassesToEmitFile(pass, stream, nullptr,
                               llvm::CodeGenFileType::CGFT_AssemblyFile);
  pass.run(module);

  std::string result;
  {
    llvm::raw_string_ostream stream(result);
    llvm::buffer_ostream pstream(stream);
    for (llvm::Function &f : module.functions())
      f.addFnAttr(llvm::Attribute::AlwaysInline);
    llvm::legacy::PassManager pass;
    // emit
    machine->addPassesToEmitFile(pass, pstream, nullptr,
                                 llvm::CodeGenFileType::CGFT_AssemblyFile);
    pass.run(module);
  }
  // post-process
  std::string result(buffer.begin(), buffer.end());
  findAndReplace(result, ".version", "\n",
                 ".version " + std::to_string(ptxMajor) + "." +
                     std::to_string(ptxMinor) + "\n");

@@ -9,6 +9,7 @@ import tarfile
import tempfile
import urllib.request
from distutils.version import LooseVersion
from pathlib import Path
from typing import NamedTuple

from setuptools import Extension, setup
@@ -38,7 +39,6 @@ class Package(NamedTuple):
    package: str
    name: str
    url: str
    test_file: str
    include_flag: str
    lib_flag: str
    syspath_var_name: str
@@ -49,7 +49,7 @@ class Package(NamedTuple):
def get_pybind11_package_info():
    name = "pybind11-2.10.0"
    url = "https://github.com/pybind/pybind11/archive/refs/tags/v2.10.0.tar.gz"
    return Package("pybind11", name, url, "include/pybind11/pybind11.h", "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")
    return Package("pybind11", name, url, "PYBIND11_INCLUDE_DIR", "", "PYBIND11_SYSPATH")

# llvm

@@ -65,12 +65,13 @@ def get_llvm_package_info():
        linux_suffix = 'ubuntu-18.04' if vglibc > 217 else 'centos-7'
        system_suffix = f"linux-gnu-{linux_suffix}"
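        # e.g. glibc 2.27 gives vglibc == 227 > 217, so the 'ubuntu-18.04'
        # tarball is selected; older glibc falls back to 'centos-7'.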
    else:
        return Package("llvm", "LLVM-C.lib", "", "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
        return Package("llvm", "LLVM-C.lib", "", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
    use_assert_enabled_llvm = check_env_flag("TRITON_USE_ASSERT_ENABLED_LLVM", "False")
    release_suffix = "assert" if use_assert_enabled_llvm else "release"
    name = f'llvm+mlir-17.0.0-x86_64-{system_suffix}-{release_suffix}'
    url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/llvm-17.0.0-37b7a60cd74b/{name}.tar.xz"
    return Package("llvm", name, url, "lib", "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")
    version = "llvm-17.0.0-8e5a41e8271f"
    url = f"https://github.com/ptillet/triton-llvm-releases/releases/download/{version}/{name}.tar.xz"
    return Package("llvm", name, url, "LLVM_INCLUDE_DIRS", "LLVM_LIBRARY_DIR", "LLVM_SYSPATH")


def get_thirdparty_packages(triton_cache_path):
@@ -81,8 +82,9 @@ def get_thirdparty_packages(triton_cache_path):
        package_dir = os.path.join(package_root_dir, p.name)
        if p.syspath_var_name in os.environ:
            package_dir = os.environ[p.syspath_var_name]
        test_file_path = os.path.join(package_dir, p.test_file)
        if not os.path.exists(test_file_path):
        version_file_path = os.path.join(package_dir, "version.txt")
        if not os.path.exists(version_file_path) or\
                Path(version_file_path).read_text() != p.url:
            try:
                shutil.rmtree(package_root_dir)
            except Exception:
@@ -92,6 +94,9 @@ def get_thirdparty_packages(triton_cache_path):
            ftpstream = urllib.request.urlopen(p.url)
            file = tarfile.open(fileobj=ftpstream, mode="r|*")
            file.extractall(path=package_root_dir)
            # write version url to package_dir
            with open(os.path.join(package_dir, "version.txt"), "w") as f:
                f.write(p.url)
        if p.include_flag:
            thirdparty_cmake_args.append(f"-D{p.include_flag}={package_dir}/include")
        if p.lib_flag:
@@ -208,14 +213,14 @@ download_and_copy_ptxas()

setup(
    name="triton",
    version="2.0.0",
    version="2.1.0",
    author="Philippe Tillet",
    author_email="phil@openai.com",
    description="A language and compiler for custom Deep Learning operations",
    long_description="",
    packages=["triton", "triton/_C", "triton/language", "triton/tools", "triton/impl", "triton/ops", "triton/runtime", "triton/ops/blocksparse"],
    install_requires=[
        "cmake",
        "cmake>=3.20",
        "filelock",
        "torch",
        "lit",

@@ -13,6 +13,8 @@

#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "triton/Analysis/Allocation.h"
#include "triton/Conversion/TritonGPUToLLVM/TritonGPUToLLVMPass.h"
@@ -45,6 +47,7 @@
#include <pybind11/stl.h>
#include <pybind11/stl_bind.h>
#include <regex>
#include <signal.h>
#include <sstream>
#include <stdexcept>
#include <string>
@@ -119,6 +122,7 @@ void init_triton_ir(py::module &&m) {
      .def(py::init<>())
      .def("load_triton", [](mlir::MLIRContext &self) {
        self.getOrLoadDialect<mlir::triton::TritonDialect>();
        self.getOrLoadDialect<mlir::index::IndexDialect>();
        // we load LLVM because the frontend uses LLVM.undef for
        // some placeholders
        self.getOrLoadDialect<mlir::triton::TritonDialect>();
@@ -394,8 +398,8 @@ void init_triton_ir(py::module &&m) {
        registry.insert<mlir::triton::TritonDialect,
                        mlir::triton::gpu::TritonGPUDialect,
                        mlir::math::MathDialect, mlir::arith::ArithDialect,
                        mlir::func::FuncDialect, mlir::scf::SCFDialect,
                        mlir::cf::ControlFlowDialect>();
                        mlir::index::IndexDialect, mlir::func::FuncDialect,
                        mlir::scf::SCFDialect, mlir::cf::ControlFlowDialect>();
        context.appendDialectRegistry(registry);
        context.loadAllAvailableDialects();

@@ -513,6 +517,12 @@ void init_triton_ir(py::module &&m) {
             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
                 loc, v, self.getI8Type()));
           })
      .def("get_int16",
           [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
                 loc, v, self.getI16Type()));
           })
      .def("get_int32",
           [](mlir::OpBuilder &self, int64_t v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
@@ -525,16 +535,15 @@ void init_triton_ir(py::module &&m) {
             return mlir::Value(self.create<mlir::arith::ConstantIntOp>(
                 loc, v, self.getI64Type()));
           })
      // bfloat16 cannot be initialized as it is treated as int16 for now
      //.def("get_bf16",
      //     [](mlir::OpBuilder &self, float v) -> mlir::Value {
      //       auto loc = self.getUnknownLoc();
      //       auto type = self.getBF16Type();
      //       return self.create<mlir::arith::ConstantFloatOp>(
      //           loc,
      //           mlir::APFloat(type.getFloatSemantics(), std::to_string(v)),
      //           type);
      //     })
      .def("get_bf16",
           [](mlir::OpBuilder &self, float v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             auto type = self.getBF16Type();
             return self.create<mlir::arith::ConstantFloatOp>(
                 loc,
                 mlir::APFloat(type.getFloatSemantics(), std::to_string(v)),
                 type);
           })
      .def("get_fp16",
           [](mlir::OpBuilder &self, float v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
@@ -547,6 +556,12 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::arith::ConstantOp>(
                 loc, self.getF32FloatAttr(v));
           })
      .def("get_fp64",
           [](mlir::OpBuilder &self, double v) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::arith::ConstantOp>(
                 loc, self.getF64FloatAttr(v));
           })
      .def("get_null_value",
           [](mlir::OpBuilder &self, mlir::Type type) -> mlir::Value {
             auto loc = self.getUnknownLoc();
@@ -598,9 +613,13 @@ void init_triton_ir(py::module &&m) {
      .def(
          "get_int64_ty",
          [](mlir::OpBuilder &self) -> mlir::Type { return self.getI64Type(); })
      .def("get_fp8_ty",
      .def("get_fp8e4_ty",
           [](mlir::OpBuilder &self) -> mlir::Type {
             return self.getType<mlir::triton::Float8Type>();
             return self.getType<mlir::Float8E4M3FNType>();
           })
      .def("get_fp8e5_ty",
           [](mlir::OpBuilder &self) -> mlir::Type {
             return self.getType<mlir::Float8E5M2Type>();
           })
      .def(
          "get_half_ty",
@@ -817,7 +836,7 @@ void init_triton_ir(py::module &&m) {
           [](mlir::OpBuilder &self, mlir::Value &input) -> mlir::Value {
             auto loc = self.getUnknownLoc();
             return self.create<mlir::arith::IndexCastOp>(
                 loc, self.getI32Type(), input);
                 loc, self.getI64Type(), input);
           })
      .def("create_fmul",
           [](mlir::OpBuilder &self, mlir::Value &lhs,
@@ -1428,16 +1447,32 @@ void init_triton_ir(py::module &&m) {
             return self.create<mlir::arith::SelectOp>(loc, condition,
                                                       trueValue, falseValue);
           })
      .def("create_printf",
      .def("create_print",
           [](mlir::OpBuilder &self, const std::string &prefix,
              const std::vector<mlir::Value> &values) -> void {
             auto loc = self.getUnknownLoc();
             self.create<mlir::triton::PrintfOp>(
             self.create<mlir::triton::PrintOp>(
                 loc,
                 mlir::StringAttr::get(self.getContext(),
                                       llvm::StringRef(prefix)),
                 values);
           })
      .def("create_assert",
           [](mlir::OpBuilder &self, mlir::Value &condition,
              const std::string &message, const std::string &fileName,
              const std::string &funcName, unsigned lineNo) -> void {
             auto loc = self.getUnknownLoc();
             auto messageAttr = mlir::StringAttr::get(self.getContext(),
                                                      llvm::StringRef(message));
             auto fileNameAttr = mlir::StringAttr::get(
                 self.getContext(), llvm::StringRef(fileName));
             auto funcNameAttr = mlir::StringAttr::get(
                 self.getContext(), llvm::StringRef(funcName));
             auto lineNoAttr = self.getI32IntegerAttr(lineNo);
             self.create<mlir::triton::AssertOp>(loc, condition, messageAttr,
                                                 fileNameAttr, funcNameAttr,
                                                 lineNoAttr);
           })
      // Undef
      .def("create_undef",
           [](mlir::OpBuilder &self, mlir::Type &type) -> mlir::Value {
@@ -1521,18 +1556,14 @@ void init_triton_ir(py::module &&m) {
             self.addPass(
                 mlir::createTritonGPUAccelerateMatmulPass(computeCapability));
           })
      .def("add_tritongpu_fuse_transpositions_pass",
      .def("add_tritongpu_optimize_dot_operands_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUFuseTranspositionsPass());
             self.addPass(mlir::createTritonGPUOptimizeDotOperandsPass());
           })
      .def("add_tritongpu_remove_layout_conversions_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPURemoveLayoutConversionsPass());
           })
      .def("add_tritongpu_update_mma_for_volta_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUUpdateMmaForVoltaPass());
           })
      .def("add_tritongpu_reorder_instructions_pass",
           [](mlir::PassManager &self) {
             self.addPass(mlir::createTritonGPUReorderInstructionsPass());
@@ -1611,7 +1642,6 @@ void init_triton_translation(py::module &m) {
        llvm::sys::fs::createTemporaryFile("compile-ptx-src", "", fsrc);
        llvm::sys::fs::createTemporaryFile("compile-ptx-log", "", flog);
        std::string fbin = std::string(fsrc) + ".o";
        llvm::FileRemover srcRemover(fsrc);
        llvm::FileRemover logRemover(flog);
        llvm::FileRemover binRemover(fbin);
        const char *_fsrc = fsrc.c_str();
@@ -1628,16 +1658,30 @@ void init_triton_translation(py::module &m) {

        err = system(cmd.c_str());
        if (err != 0) {
          err >>= 8;
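          // Note: on POSIX, system() returns the raw wait status; the exit
          // code sits in bits 8-15, so the shift recovers it, and the
          // 128 + SIGSEGV case below matches a shell reporting that ptxas
          // died from a segmentation fault.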
          std::ifstream _log(_flog);
          std::string log(std::istreambuf_iterator<char>(_log), {});
          throw std::runtime_error("Internal Triton PTX codegen error: \n" +
                                   log);
          if (err == 255) {
            throw std::runtime_error("Internal Triton PTX codegen error: \n" +
                                     log);
          } else if (err == 128 + SIGSEGV) {
            throw std::runtime_error("Please run `ptxas " + fsrc.str().str() +
                                     "` to confirm that this is a "
                                     "bug in `ptxas`\n" +
                                     log);
          } else {
            throw std::runtime_error("`ptxas` failed with error code " +
                                     std::to_string(err) + ": \n" + log);
          }
          return {};
        } else {
          llvm::FileRemover srcRemover(fsrc);
          std::ifstream _cubin(_fbin, std::ios::binary);
          std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
          _cubin.close();
          py::bytes bytes(cubin);
          return std::move(bytes);
        }
        std::ifstream _cubin(_fbin, std::ios::binary);
        std::string cubin(std::istreambuf_iterator<char>(_cubin), {});
        _cubin.close();
        py::bytes bytes(cubin);
        return std::move(bytes);
      });

m.def("add_external_libs",

45
python/test/unit/language/assert_helper.py
Normal file
@@ -0,0 +1,45 @@
import sys

import torch
from torch.testing import assert_close

import triton
import triton.language as tl


@triton.jit
def kernel_device_assert(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.device_assert(x == 0, "x != 0")
    tl.store(Y + tl.arange(0, BLOCK), x)


@triton.jit
def kernel_assert(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    assert x == 0, "x != 0"
    tl.store(Y + tl.arange(0, BLOCK), x)


@triton.jit
def kernel_static_assert(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.static_assert(BLOCK == 128, "BLOCK != 128")
    tl.store(Y + tl.arange(0, BLOCK), x)


def test_assert(func: str):
    shape = (128, )
    x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda')
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    if func == "device_assert":
        kernel_device_assert[(1,)](x, y, BLOCK=shape[0])
    elif func == "assert":
        kernel_assert[(1,)](x, y, BLOCK=shape[0])
    elif func == "static_assert":
        kernel_static_assert[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


if __name__ == "__main__":
    test_assert(sys.argv[1])
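    # Example invocation: python assert_helper.py device_assert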
46
python/test/unit/language/print_helper.py
Normal file
@@ -0,0 +1,46 @@
import sys

import torch
from torch.testing import assert_close

import triton
import triton.language as tl


@triton.jit
def kernel_device_print(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.device_print("", x)
    tl.store(Y + tl.arange(0, BLOCK), x)


@triton.jit
def kernel_print(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    print("", x)
    tl.store(Y + tl.arange(0, BLOCK), x)


@triton.jit
def kernel_static_print(X, Y, BLOCK: tl.constexpr):
    x = tl.load(X + tl.arange(0, BLOCK))
    tl.static_print(x)
    tl.store(Y + tl.arange(0, BLOCK), x)


def test_print(func: str, data_type: str):
    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = torch.arange(0, shape[0], dtype=torch.int32, device='cuda').to(getattr(torch, data_type))
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    if func == "device_print":
        kernel_device_print[(1,)](x, y, BLOCK=shape[0])
    elif func == "print":
        kernel_print[(1,)](x, y, BLOCK=shape[0])
    elif func == "static_print":
        kernel_static_print[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


if __name__ == "__main__":
    test_print(sys.argv[1], sys.argv[2])
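    # Example invocation: python print_helper.py device_print int32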
@@ -1,56 +0,0 @@
import torch
from torch.testing import assert_close

import triton
import triton.language as tl

torch_type = {
    "bool": torch.bool,
    'int8': torch.int8,
    'uint8': torch.uint8,
    'int16': torch.int16,
    "int32": torch.int32,
    'int64': torch.long,
    'float16': torch.float16,
    'bfloat16': torch.bfloat16,
    "float32": torch.float32,
    "float64": torch.float64
}


def get_tensor(shape, data_type, b_positive=False):
    x = None
    if data_type.startswith('int'):
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')
    else:
        x = torch.arange(0, shape[0], dtype=torch_type[data_type], device='cuda')

    return x

# @pytest.mark.parametrize('data_type',
#                          [("int8"),
#                           ('int16'),
#                           ('int32'),
#                           ("int64"),
#                           ('float16'),
#                           ("float32"),
#                           ("float64")])


def printf(data_type):
    @triton.jit
    def kernel(X, Y, BLOCK: tl.constexpr):
        x = tl.load(X + tl.arange(0, BLOCK))
        tl.printf("", x)
        tl.store(Y + tl.arange(0, BLOCK), x)

    shape = (128, )
    # limit the range of integers so that the sum does not overflow
    x = get_tensor(shape, data_type)
    y = torch.zeros(shape, dtype=x.dtype, device="cuda")
    kernel[(1,)](x, y, BLOCK=shape[0])
    assert_close(y, x)


printf("float16")
printf("int8")
@@ -385,17 +385,22 @@ def test_where(dtype):
    @triton.jit
    def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr,
                     TEST_POINTERS: tl.constexpr):
                     TEST_POINTERS: tl.constexpr,
                     TEST_SCALAR_POINTERS: tl.constexpr):
        offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        mask = offsets < n_elements
        decide = tl.load(cond_ptr + offsets, mask=mask)
        if TEST_POINTERS:
            a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
            b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
        if TEST_SCALAR_POINTERS:
            ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
            output = tl.load(ptr + offsets, mask=mask)
        else:
            a = tl.load(a_ptr + offsets, mask=mask)
            b = tl.load(b_ptr + offsets, mask=mask)
            output = tl.where(decide, a, b)
            if TEST_POINTERS:
                a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
                b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
            else:
                a = tl.load(a_ptr + offsets, mask=mask)
                b = tl.load(b_ptr + offsets, mask=mask)
            output = tl.where(decide, a, b)
        tl.store(output_ptr + offsets, output, mask=mask)

    SIZE = 1_000
@@ -411,8 +416,12 @@ def test_where(dtype):
    z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)

    grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
    where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
    assert (z == to_numpy(z_tri)).all()
    if select_ptrs:
        where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
        z = np.where(cond[0], x, y)
        assert (z == to_numpy(z_tri)).all()


def test_where_broadcast():
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
    np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)


def test_tensor_atomic_rmw_block(device="cuda"):
    shape = (8, 8)

    @triton.jit
    def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
        off0 = tl.arange(0, SHAPE0)
        off1 = tl.arange(0, SHAPE1)
        offs = off0[:, None] * SHAPE1 + off1[None, :]
        val = offs.to(tl.float32)
        x = X + offs
        tl.atomic_min(x, val)
    x = torch.ones((8, 8), device=device, dtype=torch.float32)
    kernel[(2,)](x, shape[0], shape[1])
    assert torch.min(x).item() == 0.0


def test_atomic_cas():
    # 1. make sure that atomic_cas changes the original value (Lock)
    @triton.jit
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
    assert torch.all(output == ref)


@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
def test_f8_xf16_roundtrip(dtype):
def test_load_store_same_ptr():
    @triton.jit()
    def kernel(in_out_ptr):
        pid = tl.program_id(axis=0)
        x = tl.load(in_out_ptr + pid)
        out = x * 2
        tl.store(in_out_ptr + pid, out)

    for _ in range(1000):
        x = torch.ones((65536,), device="cuda", dtype=torch.float32)
        kernel[(65536,)](x, num_warps=32)
        assert torch.all(x == 2)


@pytest.mark.parametrize("in_dtype", [tl.float8e4, tl.float8e5])
@pytest.mark.parametrize("out_dtype", [torch.float16, torch.float32])
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
    """Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
    check_type_supported(dtype)
    check_type_supported(out_dtype)

    @triton.jit
    def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
        tl.store(output_ptr + offsets, output, mask=mask)

    f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
    f8 = triton.reinterpret(f8_tensor, tl.float8)
    # f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
    all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
    f8_tensor[all_exp_ones] = 0
    f8 = triton.reinterpret(f8_tensor, in_dtype)
    n_elements = f8_tensor.numel()
    xf16 = torch.empty_like(f8_tensor, dtype=dtype)
    xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)

    f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
    copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)

    assert torch.all(f8_tensor == f8_output_tensor)


def test_f16_to_f8_rounding():
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
def test_f16_to_f8_rounding(in_dtype):
    """Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
    error is the minimum over all float8.
    Or the same explanation a bit mathier:
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
    f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
    n_elements = f16_input.numel()
    f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
    f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
    f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
    grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
    copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)

@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
    abs_error = torch.abs(f16_input - f16_output)

    all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
    all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
    all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
    copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)

@@ -1240,6 +1284,32 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
        assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx


@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
    dtype = getattr(torch, dtype_str)
    check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

    @triton.jit
    def kernel_static(out):
        a = GENERATE_TEST_HERE
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    @triton.jit
    def kernel_dynamic(out, val, dtype: tl.constexpr):
        a = tl.full((128,), val, dtype)
        out_ptr = out + tl.arange(0, 128)[:]
        tl.store(out_ptr, a)

    kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
    out_static = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_static_patched[(1,)](out_static)
    out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
    kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
    assert torch.all(out_static == 2)
    assert torch.all(out_dynamic == 2)


# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
@@ -1409,6 +1479,28 @@ def test_vectorization(N):
        assert "ld.global.b32" in ptx
    # triton.testing.assert_almost_equal(dst, src[:N])


@pytest.mark.parametrize("has_hints", [False, True])
def test_vectorization_hints(has_hints):
    src = torch.empty(1024, device='cuda')
    dst = torch.empty(1024, device='cuda')
    off = torch.zeros(1, device='cuda', dtype=torch.int32)

    @triton.jit
    def _kernel(dst, src, off, N, BLOCK_SIZE: tl.constexpr, HINT: tl.constexpr):
        offsets = tl.program_id(0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
        offsets = offsets + tl.load(off)
        if HINT:
            tl.max_contiguous(tl.multiple_of(offsets, 1024), 1024)
        x = tl.load(src + offsets, mask=offsets < N)
        tl.store(dst + offsets, x, mask=offsets < N)
    pgm = _kernel[(1,)](dst, src, off, N=1024, BLOCK_SIZE=src.shape[0], HINT=has_hints)
    ptx = pgm.asm["ptx"]
    if has_hints:
        assert "ld.global.v4.b32" in ptx
    else:
        assert "ld.global.v4.b32" not in ptx

# ---------------
# test store
# ---------------
@@ -1479,7 +1571,7 @@ def test_pointer_arguments(device):

@pytest.mark.parametrize("value, value_type", [
    (-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
    (2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
    (2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
    (-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
@@ -1739,6 +1831,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------


def test_for_iv_int64():

    @triton.jit
    def kernel(Out, lo, hi):
        acc = 0
        acc = acc.to(tl.int64)
        for i in range(lo, hi):
            acc += i
        tl.store(Out, acc)

    lo = 2**35
    hi = 2**35 + 20
    out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
    kernel[(1,)](out, lo, hi)
    assert out[0] == sum(range(lo, hi))
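    # Sanity check: sum(range(lo, hi)) == (lo + hi - 1) * (hi - lo) // 2,
    # about 20 * 2**35 here, far beyond 2**31, so a 32-bit accumulator
    # would overflow; hence the explicit cast of acc to tl.int64 above.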


def test_if_else():

    @triton.jit
@@ -1972,3 +2081,42 @@ def test_load_scalar_with_mask():
    Out = torch.empty_like(Index, device='cuda')
    kernel[(1,)](Input, Index, Out, Index.numel())
    assert Out.data[0] == 0


# This test is used to test our own PTX codegen for float16 and int16 conversions
# maybe delete it later once ptxas has been fixed
@pytest.mark.parametrize("dtype_str", ['float16', 'int16'])
def test_ptx_cast(dtype_str):
    @triton.jit
    def kernel(in_ptr0, out_ptr2, xnumel, rnumel, dtype: tl.constexpr, XBLOCK: tl.constexpr, RBLOCK: tl.constexpr):
        xoffset = tl.program_id(0) * XBLOCK
        xindex = xoffset + tl.arange(0, XBLOCK)[:, None]
        xmask = xindex < xnumel
        rbase = tl.arange(0, RBLOCK)[None, :]
        x0 = xindex
        _tmp4 = (tl.zeros([XBLOCK, RBLOCK], dtype) - 10000).to(dtype)
        for roffset in range(0, rnumel, RBLOCK):
            rindex = roffset + rbase
            rmask = rindex < rnumel
            r1 = rindex
            tmp0 = tl.load(in_ptr0 + (r1 + (197 * x0)), rmask & xmask).to(dtype)
            tmp1 = 2
            tmp2 = tmp0 * tmp1
            tmp3 = tmp2.to(dtype)
            tmp5 = _tmp4 < tmp3
            _tmp4 = tl.where(rmask & xmask & tmp5, tmp3, _tmp4)
        tl.store(out_ptr2 + (r1 + (197 * x0) + tl.zeros([XBLOCK, RBLOCK], tl.int32)), _tmp4, rmask & xmask)

    torch.manual_seed(123)
    if dtype_str == 'int16':
        torch_dtype = torch.int16
        triton_dtype = tl.int32
    else:
        torch_dtype = torch.float16
        triton_dtype = tl.float32

    s0 = 4
    buf11 = -torch.ones((6 * s0, 197, 197), device='cuda', dtype=torch_dtype)
    buf14 = -torch.ones((s0, 6, 197, 197), device='cuda', dtype=torch_dtype)
    kernel[(4728,)](buf11, buf14, 1182 * s0, 197, triton_dtype, 1, 256, num_warps=2)
    assert buf14.to(torch.float32).mean() == -2.0

@@ -385,17 +385,22 @@ def test_where(dtype):
|
||||
@triton.jit
|
||||
def where_kernel(cond_ptr, a_ptr, b_ptr, output_ptr, n_elements,
|
||||
BLOCK_SIZE: tl.constexpr,
|
||||
TEST_POINTERS: tl.constexpr):
|
||||
TEST_POINTERS: tl.constexpr,
|
||||
TEST_SCALAR_POINTERS: tl.constexpr):
|
||||
offsets = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
|
||||
mask = offsets < n_elements
|
||||
decide = tl.load(cond_ptr + offsets, mask=mask)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
if TEST_SCALAR_POINTERS:
|
||||
ptr = tl.where(tl.load(cond_ptr), a_ptr, b_ptr)
|
||||
output = tl.load(ptr + offsets, mask=mask)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
if TEST_POINTERS:
|
||||
a = tl.load(a_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
b = tl.load(b_ptr + offsets, mask=mask).to(tl.pi32_t)
|
||||
else:
|
||||
a = tl.load(a_ptr + offsets, mask=mask)
|
||||
b = tl.load(b_ptr + offsets, mask=mask)
|
||||
output = tl.where(decide, a, b)
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
SIZE = 1_000
|
||||
@@ -411,8 +416,12 @@ def test_where(dtype):
|
||||
z_tri = to_triton(np.empty(SIZE, dtype=z.dtype), device='cuda', dst_type=dtype)
|
||||
|
||||
grid = lambda meta: (triton.cdiv(SIZE, meta['BLOCK_SIZE']),)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs)
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=False)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
if select_ptrs:
|
||||
where_kernel[grid](cond_tri, x_tri, y_tri, z_tri, SIZE, BLOCK_SIZE=1024, TEST_POINTERS=select_ptrs, TEST_SCALAR_POINTERS=True)
|
||||
z = np.where(cond[0], x, y)
|
||||
assert (z == to_numpy(z_tri)).all()
|
||||
|
||||
|
||||
def test_where_broadcast():
|
||||
@@ -683,6 +692,22 @@ def test_tensor_atomic_rmw(shape, axis, device="cuda"):
|
||||
np.testing.assert_allclose(z_ref, to_numpy(z_tri), rtol=1e-4)
|
||||
|
||||
|
||||
def test_tensor_atomic_rmw_block(device="cuda"):
|
||||
shape = (8, 8)
|
||||
|
||||
@triton.jit
|
||||
def kernel(X, SHAPE0: tl.constexpr, SHAPE1: tl.constexpr):
|
||||
off0 = tl.arange(0, SHAPE0)
|
||||
off1 = tl.arange(0, SHAPE1)
|
||||
offs = off0[:, None] * SHAPE1 + off1[None, :]
|
||||
val = offs.to(tl.float32)
|
||||
x = X + offs
|
||||
tl.atomic_min(x, val)
|
||||
x = torch.ones((8, 8), device=device, dtype=torch.float32)
|
||||
kernel[(2,)](x, shape[0], shape[1])
|
||||
assert torch.min(x).item() == 0.0
|
||||
|
||||
|
||||
def test_atomic_cas():
|
||||
# 1. make sure that atomic_cas changes the original value (Lock)
|
||||
@triton.jit
|
||||
@@ -798,10 +823,25 @@ def test_store_constant(dtype_str):
|
||||
assert torch.all(output == ref)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
|
||||
def test_f8_xf16_roundtrip(dtype):
|
||||
def test_load_store_same_ptr():
|
||||
@triton.jit()
|
||||
def kernel(in_out_ptr):
|
||||
pid = tl.program_id(axis=0)
|
||||
x = tl.load(in_out_ptr + pid)
|
||||
out = x * 2
|
||||
tl.store(in_out_ptr + pid, out)
|
||||
|
||||
for _ in range(1000):
|
||||
x = torch.ones((65536,), device="cuda", dtype=torch.float32)
|
||||
kernel[(65536,)](x, num_warps=16)
|
||||
assert torch.all(x == 2)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4]) # TODO: support tl.float8e5
|
||||
@pytest.mark.parametrize("out_dtype", [torch.float16]) # TODO: support torch.float32
|
||||
def test_f8_xf16_roundtrip(in_dtype, out_dtype):
|
||||
"""Tests that converting an f8 to f16 and back to f8 doesn't change its value"""
|
||||
check_type_supported(dtype)
|
||||
check_type_supported(out_dtype)
|
||||
|
||||
@triton.jit
|
||||
def copy_kernel(input_ptr, output_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
|
||||
@@ -812,20 +852,24 @@ def test_f8_xf16_roundtrip(dtype):
|
||||
tl.store(output_ptr + offsets, output, mask=mask)
|
||||
|
||||
f8_tensor = torch.tensor(range(-128, 128), dtype=torch.int8, device='cuda')
|
||||
f8 = triton.reinterpret(f8_tensor, tl.float8)
|
||||
# f32_to_f8 doesn't handle nan, so we make sure f8_tensor doesn't contain any nan
|
||||
all_exp_ones = (f8_tensor & 0b01111100) == 128 - 2**in_dtype.fp_mantissa_width
|
||||
f8_tensor[all_exp_ones] = 0
|
||||
f8 = triton.reinterpret(f8_tensor, in_dtype)
|
||||
n_elements = f8_tensor.numel()
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=dtype)
|
||||
xf16 = torch.empty_like(f8_tensor, dtype=out_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f8, xf16, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
f8_output_tensor = torch.empty_like(xf16, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
copy_kernel[grid](xf16, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
assert torch.all(f8_tensor == f8_output_tensor)
|
||||
|
||||
|
||||
def test_f16_to_f8_rounding():
|
||||
@pytest.mark.parametrize("in_dtype", [tl.float8e4])
|
||||
def test_f16_to_f8_rounding(in_dtype):
|
||||
"""Takes all float16s, converts them to float8 and back to float16. Checks that the absolute
|
||||
error is the minimum over all float8.
|
||||
Or the same explanation a bit mathier:
|
||||
@@ -848,7 +892,7 @@ def test_f16_to_f8_rounding():
|
||||
f16_input = torch.tensor(f16_input_np, dtype=torch.float16, device='cuda')
|
||||
n_elements = f16_input.numel()
|
||||
f8_output_tensor = torch.empty_like(f16_input, dtype=torch.int8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, tl.float8)
|
||||
f8_output = triton.reinterpret(f8_output_tensor, in_dtype)
|
||||
grid = lambda meta: (triton.cdiv(n_elements, meta['BLOCK_SIZE']),)
|
||||
copy_kernel[grid](f16_input, f8_output, n_elements, BLOCK_SIZE=1024)
|
||||
|
||||
@@ -858,7 +902,7 @@ def test_f16_to_f8_rounding():
|
||||
abs_error = torch.abs(f16_input - f16_output)
|
||||
|
||||
all_f8_vals_tensor = torch.tensor(range(2 ** 8), dtype=torch.uint8, device='cuda')
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, tl.float8)
|
||||
all_f8_vals = triton.reinterpret(all_f8_vals_tensor, in_dtype)
|
||||
all_f8_vals_in_f16 = torch.empty_like(all_f8_vals_tensor, dtype=torch.float16)
|
||||
copy_kernel[grid](all_f8_vals, all_f8_vals_in_f16, n_elements=256, BLOCK_SIZE=1024)
@@ -1036,8 +1080,8 @@ def test_reduce2d(op, dtype_str, shape, axis, device='cuda'):
[(dtype, shape, perm)
# TODO: bfloat16
for dtype in ['float16', 'float32']
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
for shape in [(64, 64), (128, 128)]
for perm in [(1, 0)]])
def test_permute(dtype_str, shape, perm, device='cuda'):
check_type_supported(dtype_str)  # bfloat16 on cc < 80 will not be tested
if torch.version.hip is not None:
@@ -1248,6 +1292,51 @@ def test_dot(M, N, K, num_warps, col_a, col_b, epilogue, allow_tf32, dtype, devi
assert 'mma.sync.aligned.m16n8k32.row.col.satfinite.s32.s8.s8.s32' in ptx

@pytest.mark.parametrize("dtype_str", int_dtypes + float_dtypes + ['bfloat16'])
def test_full(dtype_str):
dtype = getattr(torch, dtype_str)
check_type_supported(dtype)  # bfloat16 on cc < 80 will not be tested

@triton.jit
def kernel_static(out):
a = GENERATE_TEST_HERE
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

@triton.jit
def kernel_dynamic(out, val, dtype: tl.constexpr):
a = tl.full((128,), val, dtype)
out_ptr = out + tl.arange(0, 128)[:]
tl.store(out_ptr, a)

kernel_static_patched = patch_kernel(kernel_static, {'GENERATE_TEST_HERE': f"tl.full((128,), 2, tl.{dtype_str})"})
out_static = torch.zeros((128), dtype=dtype, device="cuda")
kernel_static_patched[(1,)](out_static)
out_dynamic = torch.zeros((128), dtype=dtype, device="cuda")
kernel_dynamic[(1,)](out_dynamic, 2, getattr(triton.language, dtype_str))
assert torch.all(out_static == 2)
assert torch.all(out_dynamic == 2)

# TODO: uncomment once DotOperandEncoding::getElemsPerThread is implemented
# @pytest.mark.parametrize("dtype_str", ['float32', 'float16'])
# def test_dot_without_load(dtype_str):
# @triton.jit
# def _kernel(out):
# a = GENERATE_TEST_HERE
# b = GENERATE_TEST_HERE
# c = tl.dot(a, b)
# out_ptr = out + tl.arange(0, 32)[:, None] * 32 + tl.arange(0, 32)[None, :]
# tl.store(out_ptr, c)

# kernel = patch_kernel(_kernel, {'GENERATE_TEST_HERE': f"tl.full((32, 32), 1.0, tl.{dtype_str})"})
# a = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# b = torch.ones((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# out_ref = torch.matmul(a, b)
# out = torch.zeros((32, 32), dtype=getattr(torch, dtype_str), device="cuda")
# kernel[(1,)](out)
# assert torch.all(out == out_ref)

# ---------------
# test arange
# ---------------
@@ -1426,7 +1515,7 @@ def test_pointer_arguments(device):

@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(2**31, 'i64'), (2**32 - 1, 'i64'), (2**32, 'i64'), (2**63 - 1, 'i64'),
(-2**63, 'i64'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:
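The updated cases encode the new specialization rule for Python ints: anything that fits in int32 stays i32, everything else that fits in int64 is now i64 (u32 is no longer produced), and only values above 2**63 - 1 fall back to u64. A plain-Python sketch of that rule, mirroring the cases above rather than quoting JITFunction itself:

def key_for_int(arg: int) -> str:
    # sketch of the rule implied by the parametrized cases, not the exact implementation
    if -2**31 <= arg <= 2**31 - 1:
        return "i32"
    if -2**63 <= arg <= 2**63 - 1:
        return "i64"   # the 2**31 .. 2**63 - 1 range used to include u32
    if arg <= 2**64 - 1:
        return "u64"
    raise ValueError("integer does not fit in 64 bits")

assert key_for_int(2**31) == "i64"      # previously specialized as u32
assert key_for_int(2**64 - 1) == "u64"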
@@ -1686,6 +1775,23 @@ def test_libdevice_scalar(dtype_str, expr, lib_path):
# -----------------------

def test_for_iv_int64():

@triton.jit
def kernel(Out, lo, hi):
acc = 0
acc = acc.to(tl.int64)
for i in range(lo, hi):
acc += i
tl.store(Out, acc)

lo = 2**35
hi = 2**35 + 20
out = to_triton(np.zeros((1,), dtype=np.int64), device='cuda')
kernel[(1,)](out, lo, hi)
assert out[0] == sum(range(lo, hi))

def test_if_else():

@triton.jit
@@ -1863,7 +1969,7 @@ else:
BlockedLayout([1, 8], [2, 16], [4, 1], [1, 0]),
BlockedLayout([1, 4], [4, 8], [2, 2], [1, 0]),
BlockedLayout([1, 1], [1, 32], [2, 2], [1, 0]),
# BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([8, 1], [16, 2], [1, 4], [0, 1]),
BlockedLayout([4, 1], [8, 4], [2, 2], [0, 1]),
BlockedLayout([1, 1], [32, 1], [2, 2], [0, 1]),
BlockedLayout([4, 4], [1, 32], [4, 1], [1, 0])

@@ -1,22 +0,0 @@
import os
import subprocess
import sys

dir_path = os.path.dirname(os.path.realpath(__file__))
printf_path = os.path.join(dir_path, "printf_helper.py")

def test_printf():
proc = subprocess.Popen([sys.executable, printf_path], stdout=subprocess.PIPE, shell=False)
(outs, err) = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128
53
python/test/unit/language/test_subprocess.py
Normal file
@@ -0,0 +1,53 @@
import os
import subprocess
import sys

import pytest

dir_path = os.path.dirname(os.path.realpath(__file__))
print_path = os.path.join(dir_path, "print_helper.py")
assert_path = os.path.join(dir_path, "assert_helper.py")

# TODO: bfloat16 after LLVM-15
func_types = ["device_assert", "assert", "static_assert"]
torch_types = ["int8", "uint8", "int16", "int32", "long", "float16", "float32", "float64"]

@pytest.mark.parametrize("func_type, data_type",
[("device_print", data_type) for data_type in torch_types] + [("print", "int32"), ("static_print", "int32")])
def test_print(func_type: str, data_type: str):
proc = subprocess.Popen([sys.executable, print_path, func_type, data_type], stdout=subprocess.PIPE, shell=False)
outs, _ = proc.communicate()
outs = outs.split()
new_lines = set()
for line in outs:
try:
value = line
if func_type != "static_print":
value = int(float(line))
new_lines.add(value)
except Exception as e:
print(e)
if func_type != "static_print":
for i in range(128):
assert i in new_lines
assert len(new_lines) == 128
else:
assert len(new_lines) == 1

@pytest.mark.parametrize("func_type", func_types)
def test_assert(func_type: str):
os.environ["TRITON_DEBUG"] = "1"
proc = subprocess.Popen([sys.executable, assert_path, func_type], stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=False)
_, errs = proc.communicate()
errs = errs.splitlines()
num_errs = 0
for err in errs:
if "x != 0" in err.decode("utf-8"):
num_errs += 1
os.environ["TRITON_DEBUG"] = "0"
if func_type != "static_assert":
assert num_errs == 127
else:
assert num_errs == 0
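The helper scripts print_helper.py and assert_helper.py are not included in this diff. One plausible shape for the assert helper, consistent with the expectation of 127 "x != 0" failures out of 128 lanes, would be the following (hypothetical, for illustration only; the real helper may differ):

import torch
import triton
import triton.language as tl

@triton.jit
def assert_kernel(out_ptr, BLOCK: tl.constexpr):
    x = tl.arange(0, BLOCK)
    # the condition holds only for lane 0, so 127 of 128 lanes report "x != 0"
    tl.device_assert(x == 0, "x != 0")
    tl.store(out_ptr + x, x)

out = torch.zeros(128, dtype=torch.int32, device='cuda')
assert_kernel[(1,)](out, BLOCK=128)
torch.cuda.synchronize()  # flush device-side assert output before the process exits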
@@ -5,7 +5,8 @@ import triton

@pytest.mark.parametrize('Z, H, N_CTX, D_HEAD', [(4, 48, 1024, 64)])
def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
@pytest.mark.parametrize('dtype', [torch.float16, torch.bfloat16])
def test_op(Z, H, N_CTX, D_HEAD, dtype):
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
pytest.skip("Flash attention only supported for compute capability >= 80")
@@ -21,7 +22,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
for z in range(Z):
for h in range(H):
p[:, :, M == 0] = float("-inf")
p = torch.softmax(p.float(), dim=-1).half()
p = torch.softmax(p.float(), dim=-1).to(dtype)
# p = torch.exp(p)
ref_out = torch.matmul(p, v)
ref_out.backward(dout)
@@ -38,6 +39,7 @@ def test_op(Z, H, N_CTX, D_HEAD, dtype=torch.float16):
tri_dq, q.grad = q.grad.clone(), None
# compare
triton.testing.assert_almost_equal(ref_out, tri_out)
triton.testing.assert_almost_equal(ref_dv, tri_dv)
decimal = 1 if dtype == torch.bfloat16 else 2
triton.testing.assert_almost_equal(ref_dv, tri_dv, decimal=decimal)
triton.testing.assert_almost_equal(ref_dk, tri_dk)
triton.testing.assert_almost_equal(ref_dq, tri_dq)
@@ -1,6 +1,5 @@
import multiprocessing
import os
import re
import shutil
from collections import namedtuple

@@ -107,33 +106,6 @@ def test_specialize(mode):
assert counter == target

@pytest.mark.parametrize("value, value_type", [
(-1, 'i32'), (0, 'i32'), (1, 'i32'), (-2**31, 'i32'), (2**31 - 1, 'i32'),
(2**32, 'i64'), (2**63 - 1, 'i64'), (-2**63, 'i64'),
(2**31, 'u32'), (2**32 - 1, 'u32'), (2**63, 'u64'), (2**64 - 1, 'u64')
])
def test_value_specialization(value: int, value_type: str, device='cuda') -> None:

@triton.jit
def kernel(VALUE, X):
pass

cache_str = None

def get_cache_str(*args, **kwargs):
nonlocal cache_str
cache_str = kwargs["repr"]
triton.JITFunction.cache_hook = get_cache_str
reset_tmp_dir()
x = torch.tensor([3.14159], device='cuda')
kernel[(1, )](value, x)
triton.JITFunction.cache_hook = None

cache_str_match = re.match(r".*VALUE: (\w+).*", cache_str)
spec_type = None if cache_str_match is None else cache_str_match.group(1)
assert spec_type == value_type

def test_constexpr_not_callable() -> None:
@triton.jit
def kernel(X, c: tl.constexpr):
@@ -176,6 +148,26 @@ def test_jit_warmup_cache() -> None:
assert len(kernel_add.cache) == 1

def test_jit_debug() -> None:
@triton.jit
def kernel_add(a, b, o, N: tl.constexpr):
idx = tl.arange(0, N)
tl.device_assert(idx < 32, "idx < 32")
tl.store(o + idx,
tl.load(a + idx) + tl.load(b + idx))

device = torch.cuda.current_device()
assert len(kernel_add.cache[device]) == 0
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = False
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 1
kernel_add.debug = True
kernel_add.warmup(torch.float32, torch.float32, torch.float32, 32, grid=(1,))
assert len(kernel_add.cache[device]) == 2
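Since self.debug is folded into the cache key, flipping the flag between warmups forces a recompile, which is exactly what the extra cache entry above asserts. In user code the flag can be set per kernel or inherited from the TRITON_DEBUG environment variable; a usage sketch:

import os

import torch
import triton
import triton.language as tl

os.environ["TRITON_DEBUG"] = "1"   # alternative to passing debug=True below

@triton.jit(debug=True)
def checked_copy(src, dst, N: tl.constexpr):
    idx = tl.arange(0, N)
    tl.device_assert(idx < N, "idx out of range")  # compiled in only when debug is on
    tl.store(dst + idx, tl.load(src + idx))

a = torch.arange(32, dtype=torch.float32, device='cuda')
b = torch.empty_like(a)
checked_copy[(1,)](a, b, N=32)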

def test_compile_in_subproc() -> None:
@triton.jit
def kernel_sub(a, b, o, N: tl.constexpr):

@@ -39,7 +39,8 @@ def str_to_ty(name):
ty = str_to_ty(name[1:])
return triton.language.pointer_type(ty)
tys = {
"fp8": triton.language.float8,
"fp8e5": triton.language.float8e5,
"fp8e4": triton.language.float8e4,
"fp16": triton.language.float16,
"bf16": triton.language.bfloat16,
"fp32": triton.language.float32,
@@ -111,7 +112,7 @@ class enter_sub_region:
self.generator.local_defs = self.prev_defs

class CodeGenerator(ast.NodeVisitor):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict()):
def __init__(self, context, prototype, gscope, attributes, constants, function_name, module=None, is_kernel=False, function_types=dict(), debug=False):
self.builder = _triton.ir.builder(context)
self.module = self.builder.create_module() if module is None else module
self.function_ret_types = function_types
@@ -123,15 +124,19 @@ class CodeGenerator(ast.NodeVisitor):
self.function_name = function_name
self.is_kernel = is_kernel
self.last_node = None
self.debug = debug
self.builtins = {
'range': range,
'min': triton.language.minimum,
'float': float,
'int': int,
'print': print,
'print': triton.language.core.device_print,
'isinstance': isinstance,
'getattr': getattr,
}
self.static_functions = [
'static_print', 'static_assert'
]
self.scf_stack = []
# SSA-construction
# name => triton.language.tensor
@@ -183,9 +188,23 @@ class CodeGenerator(ast.NodeVisitor):
break
return stmts and isinstance(stmt, ast.Return)

# TODO: should be its own AST visitor
def contains_return_op(self, node):
if isinstance(node, ast.Return):
return True
elif isinstance(node, ast.Assign):
return self.contains_return_op(node.value)
elif isinstance(node, ast.Module):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.FunctionDef):
pred = lambda s: self.contains_return_op(s)
return any(pred(s) for s in node.body)
elif isinstance(node, ast.Call):
fn = self.visit(node.func)
if isinstance(fn, triton.JITFunction):
return self.contains_return_op(fn.parse())
return False
elif isinstance(node, ast.If):
pred = lambda s: self.contains_return_op(s)
ret = any(pred(s) for s in node.body)
@@ -670,17 +689,24 @@ class CodeGenerator(ast.NodeVisitor):
step = triton.language.constexpr(-step.value)
negative_step = True
lb, ub = ub, lb
lb = triton.language.core._to_tensor(lb, self.builder)
ub = triton.language.core._to_tensor(ub, self.builder)
step = triton.language.core._to_tensor(step, self.builder)
# induction variable type
iv_type = triton.language.semantic.integer_promote_impl(lb.dtype, ub.dtype)
iv_type = triton.language.semantic.integer_promote_impl(iv_type, step.dtype)
iv_ir_type = iv_type.to_ir(self.builder)
# lb/ub/step might be constexpr, we need to cast them to tensor
lb = triton.language.core._to_tensor(lb, self.builder).handle
ub = triton.language.core._to_tensor(ub, self.builder).handle
step = triton.language.core._to_tensor(step, self.builder).handle
lb = lb.handle
ub = ub.handle
step = step.handle
# ForOp can only accept IndexType as lb/ub/step. Cast integer to Index
lb = self.builder.create_to_index(lb)
ub = self.builder.create_to_index(ub)
step = self.builder.create_to_index(step)
# Create placeholder for the loop induction variable
iv = self.builder.create_undef(self.builder.get_int32_ty())
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
iv = self.builder.create_undef(iv_ir_type)
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))

with enter_sub_region(self) as sr:
liveins, insert_block = sr
@@ -737,11 +763,13 @@ class CodeGenerator(ast.NodeVisitor):
# update induction variable with actual value, and replace all uses
self.builder.set_insertion_point_to_start(for_op.get_body(0))
iv = self.builder.create_index_to_si(for_op.get_induction_var())
iv = self.builder.create_int_cast(iv, iv_ir_type, True)
if negative_step:
ub_si = self.builder.create_index_to_si(ub)
ub_si = self.builder.create_int_cast(ub_si, iv_ir_type, True)
iv = self.builder.create_sub(ub_si, iv)
self.lscope[node.target.id].handle.replace_all_uses_with(iv)
self.set_value(node.target.id, triton.language.core.tensor(iv, triton.language.core.int32))
self.set_value(node.target.id, triton.language.core.tensor(iv, iv_type))

# update lscope & local_defs (ForOp defines new values)
for i, name in enumerate(names):
@@ -763,6 +791,14 @@ class CodeGenerator(ast.NodeVisitor):
def visit_keyword(self, node):
return {node.arg: self.visit(node.value)}

def visit_Assert(self, node) -> Any:
if not self.debug:
return
test = self.visit(node.test)
msg = self.visit(node.msg)
# Convert assert to triton's device_assert which happens on the device
return triton.language.core.device_assert(test, msg, _builder=self.builder)

def visit_Call(self, node):
fn = self.visit(node.func)
if isinstance(fn, triton.language.constexpr):
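With visit_Assert in place, a bare Python assert inside a jitted function is no longer evaluated on the host: under debug it is lowered to a device-side assert, and with debug off it compiles away entirely. A sketch of what that looks like from the kernel author's side (illustrative names):

import torch
import triton
import triton.language as tl

@triton.jit(debug=True)
def clamped(ptr, N: tl.constexpr):
    x = tl.load(ptr + tl.arange(0, N))
    # rewritten by visit_Assert into a device_assert; a no-op when debug is off
    assert x >= 0, "negative input"
    tl.store(ptr + tl.arange(0, N), x)

data = torch.ones(16, device='cuda')
clamped[(1,)](data, N=16)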
@@ -771,6 +807,18 @@ class CodeGenerator(ast.NodeVisitor):
for keyword in node.keywords:
kws.update(self.visit(keyword))
args = [self.visit(arg) for arg in node.args]
if fn.__name__ == "print":
fn = self.builtins["print"]
elif fn.__name__ == "device_assert":
if not self.debug:
return
elif fn.__name__ in self.static_functions:
if fn.__name__ == "static_print":
print(*args, **kws)
return
elif fn.__name__ == "static_assert":
assert args[0], args[1]
return
if isinstance(fn, triton.runtime.JITFunction):
from inspect import getcallargs
args = getcallargs(fn.fn, *args, **kws)
@@ -790,7 +838,7 @@ class CodeGenerator(ast.NodeVisitor):
if not self.module.has_function(fn_name):
prototype = triton.language.function_type([], arg_types)
gscope = sys.modules[fn.fn.__module__].__dict__
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types)
generator = CodeGenerator(self.builder.context, prototype, gscope, attributes, constants, module=self.module, function_name=fn_name, function_types=self.function_ret_types, debug=self.debug)
generator.visit(fn.parse())
callee_ret_type = generator.last_ret_type
self.function_ret_types[fn_name] = callee_ret_type
@@ -928,7 +976,7 @@ def parse_mlir_module(path, context):
return module

def build_triton_ir(fn, signature, specialization, constants):
def build_triton_ir(fn, signature, specialization, constants, debug=False):
# canonicalize signature
if isinstance(signature, str):
signature = {k: v.strip() for k, v in enumerate(signature.split(","))}
@@ -948,7 +996,7 @@ def build_triton_ir(fn, signature, specialization, constants):
arg_types = [str_to_ty(v) for k, v in signature.items() if k not in constants]

prototype = triton.language.function_type([], arg_types)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True)
generator = CodeGenerator(context, prototype, gscope=gscope, constants=all_constants, function_name=function_name, attributes=new_attrs, is_kernel=True, debug=debug)
try:
generator.visit(fn.parse())
except Exception as e:
@@ -975,8 +1023,8 @@ def optimize_triton_ir(mod):
return mod

def ast_to_ttir(fn, signature, specialization, constants):
mod, _ = build_triton_ir(fn, signature, specialization, constants)
def ast_to_ttir(fn, signature, specialization, constants, debug=False):
mod, _ = build_triton_ir(fn, signature, specialization, constants, debug)
return optimize_triton_ir(mod)

@@ -991,18 +1039,15 @@ def optimize_ttgir(mod, num_stages, compute_capability):
pm = _triton.ir.pass_manager(mod.context)
pm.enable_debug()
pm.add_tritongpu_coalesce_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_accelerate_matmul_pass(compute_capability)
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_pipeline_pass(num_stages)
pm.add_tritongpu_prefetch_pass()
pm.add_tritongpu_fuse_transpositions_pass()
pm.add_tritongpu_optimize_dot_operands_pass()
pm.add_tritongpu_remove_layout_conversions_pass()
pm.add_tritongpu_decompose_conversions_pass()
if compute_capability // 10 == 7:
# The update_mma_for_volta pass helps to compute some information for MMA encoding specifically for MMAv1
# NOTE this pass should be placed after all the passes that modify the mma layout
pm.add_tritongpu_update_mma_for_volta_pass()
pm.add_tritongpu_reorder_instructions_pass()
pm.add_cse_pass()
pm.add_symbol_dce_pass()
@@ -1718,6 +1763,7 @@ def compile(fn, **kwargs):
num_warps = kwargs.get("num_warps", 4)
num_stages = kwargs.get("num_stages", 3 if capability >= 75 else 2)
extern_libs = kwargs.get("extern_libs", dict())
debug = kwargs.get("debug", False)
# build compilation stages
if torch.version.hip is not None:
if extern_libs is None:
@@ -1736,7 +1782,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1750,7 +1796,7 @@ def compile(fn, **kwargs):
stages = {
"ast": (lambda path: fn, None),
"ttir": (lambda path: parse_mlir_module(path, context),
lambda src: ast_to_ttir(src, signature, configs[0], constants)),
lambda src: ast_to_ttir(src, signature, configs[0], constants, debug)),
"ttgir": (lambda path: parse_mlir_module(path, context),
lambda src: optimize_ttgir(ttir_to_ttgir(src, num_warps), num_stages, capability)),
"llir": (lambda path: Path(path).read_text(),
@@ -1804,7 +1850,8 @@ def compile(fn, **kwargs):
with open(fn_cache_manager._make_path(f"{name}.json")) as f:
metadata = json.load(f)
else:
metadata = {"num_warps": num_warps, "num_stages": num_stages, "constants": _get_jsonable_constants(constants), "ctime": dict()}
metadata = {"num_warps": num_warps, "num_stages": num_stages,
"constants": _get_jsonable_constants(constants), "ctime": dict(), "debug": debug}
if ext == "ptx":
assert "shared" in kwargs, "ptx compilation must provide shared memory size"
metadata["shared"] = kwargs["shared"]
@@ -1935,7 +1982,8 @@ class CompiledKernel:
if self.shared > max_shared:
raise OutOfResources(self.shared, max_shared, "shared memory")
mod, func, n_regs, n_spills = cuda_utils.load_binary(self.metadata["name"], self.asm["cubin"], self.shared, device)
# print(self.shared, n_regs, n_spills)
self.n_spills = n_spills
self.n_regs = n_regs
self.cu_module = mod
self.cu_function = func

@@ -28,6 +28,8 @@ from .core import (
constexpr,
cos,
debug_barrier,
device_assert,
device_print,
dot,
dtype,
exp,
@@ -36,7 +38,8 @@ from .core import (
float16,
float32,
float64,
float8,
float8e4,
float8e5,
function_type,
int1,
int16,
@@ -54,7 +57,6 @@ from .core import (
num_programs,
pi32_t,
pointer_type,
printf,
program_id,
ravel,
reshape,
@@ -62,6 +64,8 @@ from .core import (
sin,
softmax,
sqrt,
static_assert,
static_print,
store,
sum,
swizzle2d,
@@ -118,6 +122,8 @@ __all__ = [
"constexpr",
"cos",
"debug_barrier",
"device_assert",
"device_print",
"dot",
"dtype",
"exp",
@@ -125,7 +131,8 @@ __all__ = [
"float16",
"float32",
"float64",
"float8",
"float8e4",
"float8e5",
"full",
"function_type",
"int1",
@@ -149,7 +156,6 @@ __all__ = [
"philox_impl",
"pi32_t",
"pointer_type",
"printf",
"program_id",
"rand",
"rand4x",
@@ -164,6 +170,8 @@ __all__ = [
"softmax",
"sqrt",
"static_range",
"static_assert",
"static_print",
"store",
"sum",
"swizzle2d",

@@ -39,8 +39,7 @@ def _to_tensor(x, builder):
class dtype:
SINT_TYPES = ['int1', 'int8', 'int16', 'int32', 'int64']
UINT_TYPES = ['uint8', 'uint16', 'uint32', 'uint64']
FP_TYPES = ['fp8', 'fp16', 'bf16', 'fp32', 'fp64']
CUSTOMIZED_FP_TYPES = ['fp8']
FP_TYPES = ['fp8e4', 'fp8e5', 'fp16', 'bf16', 'fp32', 'fp64']
STANDARD_FP_TYPES = ['fp16', 'bf16', 'fp32', 'fp64']
OTHER_TYPES = ['void']

@@ -60,9 +59,12 @@ class dtype:
self.int_bitwidth = int(name.split('int')[-1])
self.primitive_bitwidth = self.int_bitwidth
elif name in dtype.FP_TYPES:
if name == 'fp8':
if name == 'fp8e4':
self.fp_mantissa_width = 3
self.primitive_bitwidth = 8
elif name == 'fp8e5':
self.fp_mantissa_width = 2
self.primitive_bitwidth = 8
elif name == 'fp16':
self.fp_mantissa_width = 10
self.primitive_bitwidth = 16
@@ -75,11 +77,13 @@ class dtype:
elif name == 'fp64':
self.fp_mantissa_width = 53
self.primitive_bitwidth = 64
else:
raise RuntimeError(f'Unsupported floating-point type {name}')
elif name == 'void':
self.primitive_bitwidth = 0

def is_fp8(self):
return self.name == 'fp8'
return 'fp8' in self.name

def is_fp16(self):
return self.name == 'fp16'
@@ -123,9 +127,6 @@ class dtype:
def is_floating(self):
return self.name in dtype.FP_TYPES

def is_customized_floating(self):
return self.name in dtype.CUSTOMIZED_FP_TYPES

def is_standard_floating(self):
return self.name in dtype.STANDARD_FP_TYPES

@@ -181,8 +182,10 @@ class dtype:
return builder.get_int32_ty()
elif self.name in ('int64', 'uint64'):
return builder.get_int64_ty()
elif self.name == 'fp8':
return builder.get_fp8_ty()
elif self.name == 'fp8e5':
return builder.get_fp8e5_ty()
elif self.name == 'fp8e4':
return builder.get_fp8e4_ty()
elif self.name == 'fp16':
return builder.get_half_ty()
elif self.name == 'bf16':
@@ -314,7 +317,8 @@ uint8 = dtype('uint8')
uint16 = dtype('uint16')
uint32 = dtype('uint32')
uint64 = dtype('uint64')
float8 = dtype('fp8')
float8e5 = dtype('fp8e5')
float8e4 = dtype('fp8e4')
float16 = dtype('fp16')
bfloat16 = dtype('bf16')
float32 = dtype('fp32')
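The two fp8 variants split the byte differently: fp8e4 (e4m3) has 4 exponent and 3 mantissa bits, fp8e5 (e5m2) has 5 exponent and 2 mantissa bits, matching the fp_mantissa_width values above. As a rough host-side illustration, an e5m2 byte can be decoded like this (bias 15 as in fp16; the inf/NaN encodings at exponent 31 are ignored here):

def decode_fp8e5m2(byte: int) -> float:
    # illustration only; the real conversion is done on the device
    sign = -1.0 if (byte >> 7) & 1 else 1.0
    exp = (byte >> 2) & 0x1F    # 5 exponent bits
    man = byte & 0x3            # 2 mantissa bits
    if exp == 0:                # subnormal range
        return sign * (man / 4.0) * 2.0 ** -14
    return sign * (1.0 + man / 4.0) * 2.0 ** (exp - 15)

assert decode_fp8e5m2(0x3C) == 1.0   # 0 01111 00 -> +1.0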
@@ -1047,7 +1051,7 @@ def where(condition, x, y, _builder=None):
If you want to avoid unintended memory operations, use the :code:`mask` arguments in `triton.load` and `triton.store` instead.

The shape of :code:`x` and :code:`y` are both broadcast to the shape of :code:`condition`.
:code:`x` and :code:`y` must have the data type.
:code:`x` and :code:`y` must have the same data type.

:param condition: When True (nonzero), yield x, otherwise yield y.
:type condition: Block of triton.bool
@@ -1349,24 +1353,56 @@ def zeros(shape, dtype):
def zeros_like(input):
return zeros(input.shape, input.dtype)

# -----------------------
# Debugging functions
# -----------------------

@builtin
def printf(prefix, *args, _builder=None):
def static_print(*values, sep: str = " ", end: str = "\n", file=None, flush=False, _builder=None):
pass

@builtin
def static_assert(cond, msg="", _builder=None):
pass

@builtin
def device_print(prefix, *args, _builder=None):
import string
new_prefix = prefix
if isinstance(prefix, constexpr):
new_prefix = prefix.value
assert isinstance(new_prefix, str), f"{new_prefix} is not string"
prefix = _constexpr_to_value(prefix)
assert isinstance(prefix, str), f"{prefix} is not string"
b_ascii = True
for ch in new_prefix:
for ch in prefix:
if ch not in string.printable:
b_ascii = False
break
assert b_ascii, f"{new_prefix} is not an ascii string"
assert b_ascii, f"{prefix} is not an ascii string"
new_args = []
for arg in args:
new_args.append(_to_tensor(arg, _builder))
return semantic.printf(new_prefix, new_args, _builder)
return semantic.device_print(prefix, new_args, _builder)

@builtin
def device_assert(cond, msg="", _builder=None):
msg = _constexpr_to_value(msg)
import inspect
frame = inspect.currentframe()
module = inspect.getmodule(frame)
# The triton function module doesn't have the name attribute.
# We use this trick to find the caller.
while hasattr(module, "__name__"):
frame = frame.f_back
module = inspect.getmodule(frame)
func_name = frame.f_code.co_name
file_name = frame.f_back.f_code.co_filename
# TODO: The line number currently indicates the line
# where the triton function is called but not where the
# device_assert is called. Need to enhance this.
lineno = frame.f_back.f_lineno
return semantic.device_assert(_to_tensor(cond, _builder), msg, file_name, func_name, lineno, _builder)
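Both builtins bottom out in the semantic layer below (builder.create_print and builder.create_assert). From kernel code the intended usage is roughly the following sketch (device_assert is only compiled in when debug is enabled):

import torch
import triton
import triton.language as tl

@triton.jit
def debug_kernel(ptr, N: tl.constexpr):
    idx = tl.arange(0, N)
    x = tl.load(ptr + idx)
    tl.device_print("x: ", x)              # one line per lane at runtime
    tl.device_assert(x == x, "found NaN")  # x != x only holds for NaN

data = torch.randn(16, device='cuda')
debug_kernel[(1,)](data, N=16)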

# -----------------------
# Iterators

@@ -493,13 +493,22 @@ def arange(start: int, end: int, builder: ir.builder) -> tl.tensor:

def full(shape: List[int], value, dtype: tl.dtype, builder: ir.builder) -> tl.tensor:
if value == 0:
_value = builder.get_null_value(dtype.to_ir(builder))
if isinstance(value, tl.tensor):
assert value.numel.value == 1, "only accepts size-1 tensor"
value = cast(value, dtype, builder)
ret_ty = tl.block_type(value.dtype, shape)
return tl.tensor(builder.create_splat(value.handle, shape), ret_ty)
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
_value = get_value_fn(value)
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(_value, shape), ret_ty)
# scalar
if value == 0:
value = builder.get_null_value(dtype.to_ir(builder))
else:
get_value_fn = getattr(builder, f"get_{dtype.name}")
value = get_value_fn(value)
if dtype is None:
raise ValueError("dtype must be specified when value is not a tensor")
ret_ty = tl.block_type(dtype, shape)
return tl.tensor(builder.create_splat(value, shape), ret_ty)
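full now accepts either a Python scalar or a size-1 tensor for value: tensors are cast to the requested dtype and splatted, while scalars go through the builder's constant constructors. Both call shapes from kernel code, sketched:

import torch
import triton
import triton.language as tl

@triton.jit
def fill_kernel(out, val):
    idx = tl.arange(0, 128)
    a = tl.full((128,), 0, tl.int32)    # scalar zero: builder.get_null_value path
    b = tl.full((128,), val, tl.int32)  # runtime scalar is a size-1 tensor: cast + splat
    tl.store(out + idx, a + b)

out = torch.empty(128, dtype=torch.int32, device='cuda')
fill_kernel[(1,)](out, 2)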

@@ -640,7 +649,7 @@ def bitcast(input: tl.tensor,
builder: ir.builder) -> tl.tensor:
src_ty = input.type
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input
src_sca_ty = src_ty.scalar
@@ -664,7 +673,7 @@ def cast(input: tl.tensor,
if isinstance(dst_ty, tl.constexpr):
dst_ty = dst_ty.value
if src_ty.is_block():
dst_ty = tl.block_type(dst_ty, input.type.get_block_shapes())
dst_ty = tl.block_type(dst_ty.scalar, input.type.get_block_shapes())
if src_ty == dst_ty:
return input

@@ -672,8 +681,8 @@ def cast(input: tl.tensor,
dst_sca_ty = dst_ty.scalar

# Casting with customized floating types involved: fp8 <=> bf16, fp16, fp32, fp64
if (src_sca_ty.is_customized_floating() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_customized_floating()):
if (src_sca_ty.is_fp8() and dst_sca_ty.is_floating()) or \
(src_sca_ty.is_floating() and dst_sca_ty.is_fp8()):
return tl.tensor(builder.create_fp_to_fp(input.handle, dst_ty.to_ir(builder)),
dst_ty)

@@ -1147,6 +1156,10 @@ def reduce_impl(input: tl.tensor, axis: int, builder: ir.builder, name: str,

# get result type
shape = input.type.shape

rank = len(shape)
assert 0 <= axis < rank, f"axis (v={axis}) is out of range, should be within [0, {rank})"

ret_shape = []
for i, s in enumerate(shape):
if i != axis:
@@ -1248,8 +1261,12 @@ def debug_barrier(builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_barrier(), tl.void)

def printf(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
def device_print(prefix: str, args: List[tl.tensor], builder: ir.builder) -> tl.tensor:
new_args = []
for arg in args:
new_args.append(arg.handle)
return tl.tensor(builder.create_printf(prefix, new_args), tl.void)
return tl.tensor(builder.create_print(prefix, new_args), tl.void)

def device_assert(cond: tl.tensor, msg: str, file_name: str, func_name, lineno: int, builder: ir.builder) -> tl.tensor:
return tl.tensor(builder.create_assert(cond.handle, msg, file_name, func_name, lineno), tl.void)

@@ -63,7 +63,7 @@ def _fwd_kernel(
p *= l_rcp
acc *= (l_prev * l_rcp)[:, None]
# update acc
p = p.to(tl.float16)
p = p.to(Q.dtype.element_ty)
v = tl.load(v_ptrs)
acc += tl.dot(p, v)
# update m_i and l_i
@@ -167,7 +167,7 @@ def _bwd_kernel(
p = tl.exp(qk * sm_scale - m[:, None])
# compute dv
do = tl.load(do_ptrs)
dv += tl.dot(tl.trans(p.to(tl.float16)), do)
dv += tl.dot(tl.trans(p.to(Q.dtype.element_ty)), do)
# compute dp = dot(v, do)
Di = tl.load(D_ptrs + offs_m_curr)
dp = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - Di[:, None]
@@ -175,10 +175,10 @@ def _bwd_kernel(
# compute ds = p * (dp - delta[:, None])
ds = p * dp * sm_scale
# compute dk = dot(ds.T, q)
dk += tl.dot(tl.trans(ds.to(tl.float16)), q)
dk += tl.dot(tl.trans(ds.to(Q.dtype.element_ty)), q)
# compute dq
dq = tl.load(dq_ptrs)
dq += tl.dot(ds.to(tl.float16), k)
dq += tl.dot(ds.to(Q.dtype.element_ty), k)
tl.store(dq_ptrs, dq)
# increment pointers
dq_ptrs += BLOCK_M * stride_qm
@@ -198,7 +198,7 @@ class _attention(torch.autograd.Function):
# only support for Ampere now
capability = torch.cuda.get_device_capability()
if capability[0] < 8:
raise RuntimeError("Flash attention currently only supported for compute capability < 80")
raise RuntimeError("Flash attention currently only supported for compute capability >= 80")
BLOCK = 128
# shape constraints
Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
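Replacing the hard-coded tl.float16 casts with Q.dtype.element_ty is what makes the kernel dtype-generic: intermediates are cast back to whatever element type the Q pointer carries, fp16 or bf16. The pattern in isolation (an illustrative sketch, not the tutorial kernel itself):

import torch
import triton
import triton.language as tl

@triton.jit
def softmax_dot(Q, V, Out, BLOCK: tl.constexpr):
    rm = tl.arange(0, BLOCK)[:, None]
    rn = tl.arange(0, BLOCK)[None, :]
    q = tl.load(Q + rm * BLOCK + rn)
    v = tl.load(V + rm * BLOCK + rn)
    p = tl.exp(q.to(tl.float32))     # keep the intermediate in fp32
    p = p.to(Q.dtype.element_ty)     # cast back to the input dtype, not a fixed tl.float16
    tl.store(Out + rm * BLOCK + rn, tl.dot(p, v))

q = torch.randn(32, 32, dtype=torch.bfloat16, device='cuda')
v = torch.randn_like(q)
o = torch.empty_like(q)
softmax_dot[(1,)](q, v, o, BLOCK=32)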

@@ -15,7 +15,7 @@ class Autotuner(KernelInterface):
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'prune_num_stages_by'(optional): a function used to prune num_stages. It take configs:List[Config] as its input, and returns pruned configs.
'prune_num_stages_by'(optional): a function used to prune num_stages. It takes configs:List[Config] as its input, and returns pruned configs.
'''
if not configs:
self.configs = [Config({}, num_warps=4, num_stages=2)]
@@ -168,7 +168,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:note: When all the configurations are evaluated, the kernel will run multiple times.
This means that whatever value the kernel updates will be updated multiple times.
To avoid this undesired behavior, you can use the `reset_to_zero` argument, which
reset the value of the provided tensor to `zero` before running any configuration.
resets the value of the provided tensor to `zero` before running any configuration.
:param configs: a list of :code:`triton.Config` objects
:type configs: list[triton.Config]
:param key: a list of argument names whose change in value will trigger the evaluation of all provided configs.
@@ -176,7 +176,7 @@ def autotune(configs, key, prune_configs_by=None, reset_to_zero=None):
:param prune_configs_by: a dict of functions that are used to prune configs, fields:
'perf_model': performance model used to predict running time with different configs, returns running time
'top_k': number of configs to bench
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It take configs:List[Config] as its input, and returns pruned configs.
'early_config_prune'(optional): a function used to do early prune (eg, num_stages). It takes configs:List[Config] as its input, and returns pruned configs.
:param reset_to_zero: a list of argument names whose value will be reset to zero before evaluating any configs.
:type reset_to_zero: list[str]
"""
|
||||
|
||||
@@ -125,8 +125,6 @@ class JITFunction(KernelInterface[T]):
|
||||
elif isinstance(arg, int):
|
||||
if -2**31 <= arg and arg <= 2**31 - 1:
|
||||
return "i32"
|
||||
elif 2**31 <= arg and arg <= 2**32 - 1:
|
||||
return "u32"
|
||||
elif 2**63 <= arg and arg <= 2**64 - 1:
|
||||
return "u64"
|
||||
else:
|
||||
@@ -179,7 +177,8 @@ class JITFunction(KernelInterface[T]):
|
||||
triton.language.uint16: 'u16',
|
||||
triton.language.uint32: 'u32',
|
||||
triton.language.uint64: 'u64',
|
||||
triton.language.float8: 'fp8',
|
||||
triton.language.float8e5: 'fp8e5',
|
||||
triton.language.float8e4: 'fp8e4',
|
||||
triton.language.float16: 'fp16',
|
||||
triton.language.bfloat16: 'bf16',
|
||||
triton.language.float32: 'fp32',
|
||||
@@ -243,7 +242,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
sig_key = {sig_keys},
|
||||
constexpr_key = {f'{constexpr_keys},' if len(constexpr_keys) > 0 else ()}
|
||||
spec_key = {f'{spec_keys},' if len(spec_keys) > 0 else ()}
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages)
|
||||
key = (version_key, sig_key, constexpr_key, spec_key, num_warps, num_stages, self.debug)
|
||||
if not extern_libs is None:
|
||||
key = (key, tuple(extern_libs.items()))
|
||||
assert num_warps > 0 and (num_warps & (num_warps - 1)) == 0, "num_warps must be a power of 2"
|
||||
@@ -278,7 +277,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
if callable(arg):
|
||||
raise TypeError(f"Callable constexpr at index {{i}} is not supported")
|
||||
if not self._call_hook(key, signature, device, constants, num_warps, num_stages, extern_libs, configs):
|
||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs)
|
||||
bin = triton.compile(self, signature=signature, device=device, constants=constants, num_warps=num_warps, num_stages=num_stages, extern_libs=extern_libs, configs=configs, debug=self.debug)
|
||||
if not warmup:
|
||||
bin.c_wrapper(grid_0, grid_1, grid_2, bin.num_warps, bin.shared, stream, bin.cu_function, triton.compiler.CompiledKernel.launch_enter_hook, triton.compiler.CompiledKernel.launch_exit_hook, bin, *args)
|
||||
self.cache[device][key] = bin
|
||||
@@ -291,7 +290,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
exec(src, scope)
|
||||
return scope[self.fn.__name__]
|
||||
|
||||
def __init__(self, fn, version=None, do_not_specialize=None):
|
||||
def __init__(self, fn, version=None, do_not_specialize=None, debug=None):
|
||||
self.fn = fn
|
||||
self.module = fn.__module__
|
||||
self.version = version
|
||||
@@ -312,6 +311,7 @@ def {self.fn.__name__}({', '.join(self.arg_names)}, grid, num_warps=4, num_stage
|
||||
# when called with a grid using __getitem__
|
||||
self.kernel_decorators = []
|
||||
self.kernel = None
|
||||
self.debug = os.environ.get("TRITON_DEBUG", "0") == "1" if debug is None else debug
|
||||
# annotations
|
||||
self.annotations = {self.arg_names.index(name): ty for name, ty in fn.__annotations__.items()}
|
||||
self.__annotations__ = fn.__annotations__
|
||||
@@ -380,6 +380,7 @@ def jit(
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
debug: Optional[bool] = None,
|
||||
) -> Callable[[T], JITFunction[T]]:
|
||||
...
|
||||
|
||||
@@ -389,6 +390,7 @@ def jit(
|
||||
*,
|
||||
version=None,
|
||||
do_not_specialize: Optional[Iterable[int]] = None,
|
||||
debug: Optional[bool] = None,
|
||||
) -> Union[JITFunction[T], Callable[[T], JITFunction[T]]]:
|
||||
"""
|
||||
Decorator for JIT-compiling a function using the Triton compiler.
|
||||
@@ -413,6 +415,7 @@ def jit(
|
||||
fn,
|
||||
version=version,
|
||||
do_not_specialize=do_not_specialize,
|
||||
debug=debug,
|
||||
)
|
||||
|
||||
if fn is not None:
|
||||
|
||||
@@ -454,10 +454,12 @@ def get_max_simd_tflops(dtype: torch.dtype, backend=None, device=None):
|
||||
backend = _triton.runtime.backend.CUDA
|
||||
if not device:
|
||||
device = torch.cuda.current_device()
|
||||
num_subcores = _triton.runtime.num_sm(backend, device) * 4 # on recent GPUs
|
||||
clock_rate = _triton.runtime.clock_rate(backend, device) # in kHz
|
||||
cc = _triton.runtime.cc(backend, device)
|
||||
if cc < 80:
|
||||
|
||||
triton.compiler.init_cuda_utils()
|
||||
num_subcores = triton.compiler.cuda_utils.get_device_properties(device)["multiprocessor_count"] * 4
|
||||
clock_rate = triton.compiler.cuda_utils.get_device_properties(device)["sm_clock_rate"] # in kHz
|
||||
capability = torch.cuda.get_device_capability()
|
||||
if capability[0] < 8:
|
||||
if dtype == torch.float32:
|
||||
ops_per_sub_core = 32 # 2*16
|
||||
elif dtype == torch.float16:
|
||||
|
||||
@@ -93,9 +93,9 @@ func.func @insert_slice(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: %cst -> %cst
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
// CHECK-NEXT: %extracted_slice -> %cst
|
||||
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
// CHECK-NEXT: %0 -> %cst
|
||||
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
}
|
||||
|
||||
@@ -169,9 +169,9 @@ func.func @for_if(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B
|
||||
// CHECK-NEXT: %0#2 -> %cst,%cst_0
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
scf.if %i1 {
|
||||
%index = arith.constant 8 : index
|
||||
// CHECK-NEXT: %extracted_slice -> %cst,%cst_0
|
||||
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
|
||||
%index = arith.constant 8 : i32
|
||||
// CHECK-NEXT: %1 -> %cst,%cst_0
|
||||
%cst0 = triton_gpu.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
|
||||
@@ -200,8 +200,8 @@ func.func @insert_slice_async(%A : !tt.ptr<f16>, %i1 : i1) {
|
||||
func.func @extract_slice(%A : !tt.ptr<f16>) {
|
||||
// CHECK: offset = 0, size = 512
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
%cst1 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%cst1 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1,1,1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
return
|
||||
// CHECK-NEXT: size = 512
|
||||
}
|
||||
@@ -284,8 +284,8 @@ func.func @for_if_slice(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f1
|
||||
%c_shared_init = arith.constant dense<0.00e+00> : tensor<128x32xf16, #A_SHARED>
|
||||
%a_shared, %b_shared, %c_shared = scf.for %iv = %lb to %ub step %step iter_args(%a_shared = %a_shared_init, %b_shared = %b_shared_init, %c_shared = %c_shared_init) -> (tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>) {
|
||||
scf.if %i1 {
|
||||
%index = arith.constant 8 : index
|
||||
%cst0 = tensor.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
|
||||
%index = arith.constant 8 : i32
|
||||
%cst0 = triton_gpu.extract_slice %a_shared[%index, 0][1, 32][1, 1] : tensor<128x32xf16, #A_SHARED> to tensor<32xf16, #A_SHARED>
|
||||
scf.yield
|
||||
}
|
||||
scf.yield %b_shared, %a_shared, %a_shared : tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>, tensor<128x32xf16, #A_SHARED>
|
||||
|
||||
@@ -109,8 +109,8 @@ func.func @alloc() {
|
||||
// CHECK-LABEL: extract_slice
|
||||
func.func @extract_slice() {
|
||||
%cst0 = arith.constant dense<0.000000e+00> : tensor<1x16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : index
|
||||
%0 = tensor.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
%index = arith.constant 0 : i32
|
||||
%0 = triton_gpu.extract_slice %cst0[%index, 0, 0][1, 16, 16][1, 1, 1] : tensor<1x16x16xf16, #A_SHARED> to tensor<16x16xf16, #A_SHARED>
|
||||
// CHECK: gpu.barrier
|
||||
// CHECK-NEXT: triton_gpu.convert_layout
|
||||
%1 = triton_gpu.convert_layout %0 : (tensor<16x16xf16, #A_SHARED>) -> tensor<16x16xf16, #AL>
|
||||
|
||||
@@ -13,11 +13,13 @@ configure_lit_site_cfg(
|
||||
|
||||
set(TRITON_TEST_DEPENDS
|
||||
triton-opt
|
||||
FileCheck
|
||||
)
|
||||
|
||||
set(FILECHECK_PATH "${LLVM_LIBRARY_DIR}/../bin/FileCheck")
|
||||
set(LIT_ARGS "-Dfilecheck=${FILECHECK_PATH}")
|
||||
add_lit_testsuite(check-triton-lit-tests "Running the triton regression tests"
|
||||
${CMAKE_CURRENT_BINARY_DIR}
|
||||
ARGS ${LIT_ARGS}
|
||||
DEPENDS ${TRITON_TEST_DEPENDS}
|
||||
)
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
// RUN: (triton-opt %s -split-input-file --convert-triton-gpu-to-llvm --mlir-pass-pipeline-crash-reproducer=%t 2>/dev/null; true) | FileCheck --check-prefixes=CHECK,GCN %s
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i32, %arg1: !llvm.ptr<f16, 1>)
|
||||
// CHECK: llvm.func @test_empty_kernel(%arg0: i64, %arg1: !llvm.ptr<f16, 1>)
|
||||
// Here the 128 comes from the 4 in module attribute multiples 32
|
||||
// PTX: attributes {nvvm.kernel = 1 : ui1, nvvm.maxntid = [128 : i32]} {{.*}}
|
||||
func.func @test_empty_kernel(%lb : index, %A : !tt.ptr<f16>) {
|
||||
@@ -700,9 +700,9 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-NEXT: llvm.mul
|
||||
// CHECK-NEXT: llvm.add
|
||||
// CHECK-NEXT: llvm.getelementptr
|
||||
%index = arith.constant 1 : index
|
||||
%index = arith.constant 1 : i32
|
||||
%0 = triton_gpu.alloc_tensor : tensor<128x16x32xf32, #shared0>
|
||||
%1 = tensor.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
|
||||
%1 = triton_gpu.extract_slice %0[%index, 0, 0][1, 16, 32][1, 1, 1] : tensor<128x16x32xf32, #shared0> to tensor<16x32xf32, #shared0>
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1394,14 +1394,67 @@ module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
func.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "slt"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: atom.global.gpu.add.f32
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: atomic_add_f32_scalar
|
||||
func.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.atomicrmw fadd {{.*}} monotonic : !llvm.ptr<f32, 1>, f32
|
||||
// PTX: llvm.icmp "eq"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$3 atom.global.gpu.add.f32
|
||||
%0 = "tt.atomic_rmw" (%arg0, %arg2, %arg1) {atomic_rmw_op = 5 : i32} : (!tt.ptr<f32>, f32, i1) -> f32
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32
|
||||
func.func @store_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xf32, #blocked0>) {
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
// PTX: llvm.icmp "slt"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : tensor<256xf32, #blocked0>
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
// CHECK-LABEL: store_f32_scalar
|
||||
func.func @store_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : f32) {
|
||||
// GCN-NOT: llvm.inline_asm
|
||||
// GCN: llvm.store {{.*}} : !llvm.ptr<f32, 1>
|
||||
// PTX: llvm.icmp "slt"
|
||||
// PTX: llvm.inline_asm
|
||||
// PTX-SAME: @$2 st.global.b32
|
||||
tt.store %arg0, %arg1 : f32
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
|
||||
module attributes {"triton_gpu.num-warps" = 4 : i32} {
|
||||
|
||||
|
||||
@@ -55,49 +55,6 @@ func.func @remat(%arg0: i32) -> tensor<1024xi32, #layout1> {
|
||||
// CHECK: return %6 : tensor<1024xi32, [[$target_layout]]>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: remat_load_store
|
||||
func.func @remat_load_store(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout1>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout1>
|
||||
return
|
||||
}
|
||||
|
||||
// Don't rematerialize vectorized loads
|
||||
// CHECK-LABEL: remat_expensive
|
||||
func.func @remat_expensive(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout1>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout1>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout1>, tensor<64xi32, #layout1>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout1>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout1>) -> tensor<64xi32, #layout0>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout1>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout0>
|
||||
return
|
||||
}
|
||||
|
||||
// Don't rematerialize loads when original and target layouts are different
|
||||
// CHECK-LABEL: remat_multi_layout
|
||||
func.func @remat_multi_layout(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
%0 = tt.make_range {end = 64 : i32, start = 0 : i32} : tensor<64xi32, #layout0>
|
||||
%1 = tt.splat %arg : (!tt.ptr<i32>) -> tensor<64x!tt.ptr<i32>, #layout0>
|
||||
%2 = tt.addptr %1, %0 : tensor<64x!tt.ptr<i32>, #layout0>, tensor<64xi32, #layout0>
|
||||
%3 = tt.load %2 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<64xi32, #layout0>
|
||||
// CHECK: triton_gpu.convert_layout
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%4 = triton_gpu.convert_layout %3 : (tensor<64xi32, #layout0>) -> tensor<64xi32, #layout2>
|
||||
%5 = triton_gpu.convert_layout %2 : (tensor<64x!tt.ptr<i32>, #layout0>) -> tensor<64x!tt.ptr<i32>, #layout2>
|
||||
tt.store %5, %4 : tensor<64xi32, #layout2>
|
||||
return
|
||||
}
|
||||
|
||||
// Always rematerialize single value loads
|
||||
// CHECK-LABEL: remat_single_value
|
||||
func.func @remat_single_value(%arg: !tt.ptr<i32> {tt.divisibility = 16 : i32}) {
|
||||
@@ -998,3 +955,30 @@ func.func public @cmp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1: !
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
// Just make sure it doesn't crash on non-tensor types.
|
||||
// CHECK-LABEL: if_no_tensor
|
||||
func.func public @if_no_tensor(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg2: i32 {tt.divisibility = 16 : i32}, %arg3: !tt.ptr<i64> {tt.divisibility = 16 : i32}) {
|
||||
// CHECK-NOT: triton_gpu.convert_layout
|
||||
%c-1_i64 = arith.constant -1 : i64
|
||||
%cst = arith.constant 0.000000e+00 : f32
|
||||
%c-1_i32 = arith.constant -1 : i32
|
||||
%0 = tt.get_program_id {axis = 0 : i32} : i32
|
||||
%1 = tt.addptr %arg3, %0 : !tt.ptr<i64>, i32
|
||||
%2 = tt.load %1 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : i64
|
||||
%3 = arith.cmpi eq, %2, %c-1_i64 : i64
|
||||
%4 = arith.select %3, %c-1_i32, %arg2 : i32
|
||||
%5 = scf.if %3 -> (!tt.ptr<f32>) {
|
||||
scf.yield %arg0 : !tt.ptr<f32>
|
||||
} else {
|
||||
%10 = tt.addptr %arg0, %2 : !tt.ptr<f32>, i64
|
||||
scf.yield %10 : !tt.ptr<f32>
|
||||
}
|
||||
%6 = arith.extsi %4 : i32 to i64
|
||||
%7 = arith.cmpi slt, %2, %6 : i64
|
||||
%8 = tt.load %5, %7, %cst {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : f32
|
||||
%9 = tt.addptr %arg1, %0 : !tt.ptr<f32>, i32
|
||||
tt.store %9, %8 {cache = 1 : i32, evict = 1 : i32} : f32
|
||||
return
|
||||
}
|
||||
|
||||
@@ -29,26 +29,26 @@
// CHECK-DAG: %[[LOOP_COND_1_SPLAT_B:.*]] = tt.splat %[[LOOP_COND_1]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]], %[[LOOP_COND_1_SPLAT_B]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK: %[[arg_b0_dot_op_0:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: %[[arg_b0_dot_op_1:.*]] = arith.mulf %[[arg_b0_dot_op_0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op_1]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
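// The index arithmetic expected above implements double buffering: INSERT_IDX
// (pipeline index mod the buffer count) selects the slice the next async copy
// writes into, while EXTRACT_IDX (loop index mod the buffer count) selects the
// slice consumed on the following iteration; both counters advance by one per
// iteration.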
func.func @matmul_loop(%lb : index, %ub : index, %step : index,
                       %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                       %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
                       %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
  // A ptrs
  %a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
  %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -71,12 +71,15 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,

  %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

  %b_scale = arith.constant dense<4.> : tensor<32x128xf16, #B>

  scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
  %loop:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
    %a_ = tt.load %a_ptr {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
    %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
    %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
    %b__ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
    %b_ = triton_gpu.convert_layout %b__ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
    %b = arith.mulf %b_, %b_scale : tensor<32x128xf16, #B>

    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>

@@ -84,10 +87,9 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
    scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
  }
  return
  return %loop#2 : tensor<128x128xf32, #C>
}
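// Note: the loop now yields its results (%loop:3 ... return %loop#2),
// presumably so the pipelined computation stays live and cannot be folded away.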


// CHECK: func.func @matmul_loop_nested
// CHECK-DAG: %[[CONSTANT_0:.*]] = arith.constant 0 : i32
// CHECK-DAG: %[[CONSTANT_1:.*]] = arith.constant 1 : i32
@@ -101,27 +103,28 @@ func.func @matmul_loop(%lb : index, %ub : index, %step : index,
// CHECK: %[[A1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[A0:.*]] = tensor.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[A0:.*]] = triton_gpu.extract_slice %[[A1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_a0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_a0]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot %[[arg_a0_dot_op]], %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_A_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 2 : i32}
// CHECK: %[[NEXT_A:.*]] = tensor.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_A:.*]] = triton_gpu.extract_slice %[[NEXT_A_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_BUFFER]], %[[NEXT_B_BUFFER]], %[[NEXT_A]], %[[NEXT_B]], {{.*}}, {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
                              %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                              %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
  scf.for %iv0 = %lb to %ub step %step {
                              %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {

  %c_start = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>
  %loop1:1 = scf.for %iv0 = %lb to %ub step %step iter_args(%c_init = %c_start) -> (tensor<128x128xf32, #C>) {
    // A ptrs
    %a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
    %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -139,12 +142,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
    %a_other = arith.constant dense<0.00e+00> : tensor<128x32xf16, #AL>
    %b_mask = arith.constant dense<true> : tensor<32x128xi1, #BL>
    %b_other = arith.constant dense<0.00e+00> : tensor<32x128xf16, #BL>
    %c_init = arith.constant dense<0.00e+00> : tensor<128x128xf32, #C>

    %a_off = arith.constant dense<4> : tensor<128x32xi32, #AL>
    %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

    scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
    %loop2:3 = scf.for %iv = %lb to %ub step %step iter_args(%a_ptr = %a_ptr_init, %b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
      %a_ = tt.load %a_ptr, %a_mask, %a_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<128x32xf16, #AL>
      %a = triton_gpu.convert_layout %a_ : (tensor<128x32xf16, #AL>) -> tensor<128x32xf16, #A>
      %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
@@ -156,9 +158,11 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
      %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
      scf.yield %next_a_ptr, %next_b_ptr, %c : tensor<128x32x!tt.ptr<f16>, #AL>, tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
    }

    scf.yield %loop2#2 : tensor<128x128xf32, #C>
  }
  return
}
  return %loop1#0 : tensor<128x128xf32, #C>
}


// CHECK: func.func @matmul_loop_single_pipeline
@@ -170,22 +174,21 @@ func.func @matmul_loop_nested(%lb : index, %ub : index, %step : index,
// CHECK: %[[B0BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_0]]
// CHECK: %[[B1BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[CONSTANT_1]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[B0:.*]] = tensor.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: %[[B0:.*]] = triton_gpu.extract_slice %[[B1BUFFER]][0, 0, 0]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, {{.*}}, %[[arg_b0:.*]] = %[[B0]], {{.*}}, {{.*}}, %[[PIPELINE_IDX:.*]] = %[[CONSTANT_2]], %[[LOOP_IDX:.*]] = %[[CONSTANT_1]]
// CHECK: %[[arg_b0_dot_op:.*]] = triton_gpu.convert_layout %[[arg_b0]]
// CHECK: tt.dot {{.*}}, %[[arg_b0_dot_op]], {{.*}}
// CHECK-DAG: %[[INSERT_IDX:.*]] = arith.remsi %[[PIPELINE_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_INT:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.index_cast %[[EXTRACT_INT]] : i32 to index
// CHECK-DAG: %[[EXTRACT_IDX:.*]] = arith.remsi %[[LOOP_IDX]], %[[CONSTANT_3]]
// CHECK: %[[NEXT_B_BUFFER:.*]] = triton_gpu.insert_slice_async {{.*}}, {{.*}}, %[[INSERT_IDX]]
// CHECK: triton_gpu.async_wait {num = 1 : i32}
// CHECK: %[[NEXT_B:.*]] = tensor.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK: %[[NEXT_B:.*]] = triton_gpu.extract_slice %[[NEXT_B_BUFFER]][%[[EXTRACT_IDX]], 0, 0]
// CHECK-DAG: %[[NEXT_PIPELINE_IDX:.*]] = arith.addi %[[PIPELINE_IDX]], %[[CONSTANT_1]]
// CHECK-DAG: %[[NEXT_LOOP_IDX:.*]] = arith.addi %[[LOOP_IDX]], %[[CONSTANT_1]]
// CHECK: scf.yield {{.*}}, {{.*}}, %[[NEXT_B_BUFFER]], %[[NEXT_B]], {{.*}}, {{.*}}, %[[NEXT_PIPELINE_IDX]], %[[NEXT_LOOP_IDX]]
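// In this variant only the B operand is loaded inside the loop (A is loop
// invariant), so a single buffer is pipelined and each async_wait expects
// num = 1 outstanding copy.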
func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,
                                       %A : !tt.ptr<f16> {tt.divisibility = 16 : i32},
                                       %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) {
                                       %B : !tt.ptr<f16> {tt.divisibility = 16 : i32}) -> tensor<128x128xf32, #C> {
  // A ptrs
  %a_ptr_splat = tt.splat %A : (!tt.ptr<f16>) -> tensor<128x32x!tt.ptr<f16>, #AL>
  %a_tmp0 = tt.make_range {end = 32: i32, start = 0: i32} : tensor<32xi32, #ALs0>
@@ -211,12 +214,12 @@ func.func @matmul_loop_single_pipeline(%lb : index, %ub : index, %step : index,

  %b_off = arith.constant dense<4> : tensor<32x128xi32, #BL>

  scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
  %loop:2 = scf.for %iv = %lb to %ub step %step iter_args(%b_ptr = %b_ptr_init, %prev_c = %c_init) -> (tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>) {
    %b_ = tt.load %b_ptr, %b_mask, %b_other {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<32x128xf16, #BL>
    %b = triton_gpu.convert_layout %b_ : (tensor<32x128xf16, #BL>) -> tensor<32x128xf16, #B>
    %c = tt.dot %a, %b, %prev_c {allowTF32 = true, transA = false, transB = false} : tensor<128x32xf16, #A> * tensor<32x128xf16, #B> -> tensor<128x128xf32, #C>
    %next_b_ptr = tt.addptr %b_ptr, %b_off : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<32x128xi32, #BL>
    scf.yield %next_b_ptr, %c : tensor<32x128x!tt.ptr<f16>, #BL>, tensor<128x128xf32, #C>
  }
  return
}
  return %loop#1 : tensor<128x128xf32, #C>
}
@@ -12,20 +12,20 @@


// CHECK: func.func @matmul_loop
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[A0:.*]][0, 0] [128, 16]
// CHECK-DAG: %[[A0_PREFETCH:.*]] = triton_gpu.convert_layout %[[A0_PREFETCH_SMEM]]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = tensor.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice %[[B0:.*]][0, 0] [16, 128]
// CHECK-DAG: %[[B0_PREFETCH:.*]] = triton_gpu.convert_layout %[[B0_PREFETCH_SMEM]]
// CHECK: scf.for {{.*}} iter_args({{.*}}, {{.*}}, %[[arg_a0:.*]] = %[[A0]], %[[arg_b0:.*]] = %[[B0]], {{.*}}, %[[a0_prefetch:.*]] = %[[A0_PREFETCH]], %[[b0_prefetch:.*]] = %[[B0_PREFETCH]]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = tensor.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_a0]][0, 16] [128, 16]
// CHECK-DAG: %[[A_REM:.*]] = triton_gpu.convert_layout %[[A_REM_SMEM]]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = tensor.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM_SMEM:.*]] = triton_gpu.extract_slice %[[arg_b0]][16, 0] [16, 128]
// CHECK-DAG: %[[B_REM:.*]] = triton_gpu.convert_layout %[[B_REM_SMEM]]
// CHECK: %[[D_FIRST:.*]] = tt.dot %[[a0_prefetch]], %[[b0_prefetch:.*]], {{.*}}
// CHECK: tt.dot %[[A_REM]], %[[B_REM]], %[[D_FIRST:.*]]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [128, 16]
// CHECK-DAG: %[[NEXT_A_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_A_PREFETCH_SMEM]]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = tensor.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH_SMEM:.*]] = triton_gpu.extract_slice {{.*}}[0, 0] [16, 128]
// CHECK-DAG: %[[NEXT_B_PREFETCH:.*]] = triton_gpu.convert_layout %[[NEXT_B_PREFETCH_SMEM]]
// CHECK: scf.yield {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, %[[NEXT_A_PREFETCH]], %[[NEXT_B_PREFETCH]]
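// The prefetcher splits the K dimension (32) into two 16-wide halves: the
// first half of A (128x16) and of B (16x128) is converted to the dot-operand
// layout ahead of the loop, the dot runs on the prefetched half and then on
// the remainder slice at offset 16, and the next iteration's first half is
// prefetched before the yield.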
func.func @matmul_loop(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {

@@ -1,73 +0,0 @@
// RUN: triton-opt %s -split-input-file -tritongpu-fuse-transposition -tritongpu-update-mma-for-volta 2>&1 | FileCheck %s

// -----

// Check the UpdateMMAVersionMinorForVolta pattern.
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 4], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [4, 4]}>
// Here, the isMMAv1Row flags of $a's and $b's dot_operands do not match
// #mma0's versionMinor, so the pattern should update the versionMinor.
#dot_operand_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0, isMMAv1Row = true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0, isMMAv1Row = false}>
// The pattern creates a new MMA layout that fits $a's and $b's dot_operands
// and carries the right warpsPerCTA. The ID of this MMA instance should be 0.
// CHECK: [[$new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
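// Reading the expected values together: versionMinor for MMAv1 appears to pack
// the per-operand isMMAv1Row flags into the low bits (both true here, giving 3)
// and the MMA instance id into the bits from 4 upward, as the second test below
// (id = 1, versionMinor = 16 and 19) suggests.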
module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
  %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
  %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
  %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
  %CC = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma0>

  // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, %cst {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma]]>
  %D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0>
  %res = triton_gpu.convert_layout %D : (tensor<64x64xf32, #mma0>) -> tensor<64x64xf32, #blocked0>

  return %res : tensor<64x64xf32, #blocked0>
}
}


// -----
// Check ids across multiple MMA layout instances.

#blocked0 = #triton_gpu.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 4], warpsPerCTA = [4, 4], order = [1, 0]}>
#shared0 = #triton_gpu.shared<{vec = 1, perPhase = 2, maxPhase = 8, order = [1, 0]}>
#mma0 = #triton_gpu.mma<{versionMajor = 1, versionMinor = 0, warpsPerCTA = [4, 4]}>
// An MMA with id = 1 and all other boolean flags false should get a
// versionMinor of 16 (= 1 * (1 << 4)).
#mma1 = #triton_gpu.mma<{versionMajor = 1, versionMinor = 16, warpsPerCTA = [4, 4]}>

// The pass should still produce two distinct MMA layouts:
// CHECK-DAG: [[$new_mma:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 3, warpsPerCTA = [4, 2]}>
// CHECK-DAG: [[$new_mma1:#mma.*]] = #triton_gpu.mma<{versionMajor = 1, versionMinor = 19, warpsPerCTA = [4, 2]}>
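// Here 19 = (1 << 4) + 3: the id bits for instance 1 plus the two isMMAv1Row
// flag bits set by the pattern, so the instance ids survive the update.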

#dot_operand_a = #triton_gpu.dot_op<{opIdx = 0, parent = #mma0, isMMAv1Row = true}>
#dot_operand_b = #triton_gpu.dot_op<{opIdx = 1, parent = #mma0, isMMAv1Row = false}>
#dot_operand_a1 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma1, isMMAv1Row = true}>
#dot_operand_b1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma1, isMMAv1Row = false}>

module attributes {"triton_gpu.num-warps" = 16 : i32} {
// CHECK-LABEL: dot_mmav1
func.func @dot_mmav1(%A: tensor<64x64xf16, #blocked0>, %B: tensor<64x64xf16, #blocked0>) -> tensor<64x64xf32, #blocked0> {
  %C = arith.constant dense<0.000000e+00> : tensor<64x64xf32, #blocked0>
  %AA = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a>
  %BB = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b>
  %CC = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma0>

  %AA1 = triton_gpu.convert_layout %A : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_a1>
  %BB1 = triton_gpu.convert_layout %B : (tensor<64x64xf16, #blocked0>) -> tensor<64x64xf16, #dot_operand_b1>
  %CC1 = triton_gpu.convert_layout %C : (tensor<64x64xf32, #blocked0>) -> tensor<64x64xf32, #mma1>

  // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, {{.*}} {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma]]>
  // CHECK: {{.*}} = tt.dot {{.*}}, {{.*}}, {{.*}} {allowTF32 = true} : tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 0, parent = [[$new_mma1]], isMMAv1Row = true}>> * tensor<64x64xf16, #triton_gpu.dot_op<{opIdx = 1, parent = [[$new_mma1]], isMMAv1Row = true}>> -> tensor<64x64xf32, [[$new_mma1]]>
  %D = tt.dot %AA, %BB, %CC {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a> * tensor<64x64xf16, #dot_operand_b> -> tensor<64x64xf32, #mma0>
  %D1 = tt.dot %AA1, %BB1, %CC1 {allowTF32 = true} : tensor<64x64xf16, #dot_operand_a1> * tensor<64x64xf16, #dot_operand_b1> -> tensor<64x64xf32, #mma1>
  %res = triton_gpu.convert_layout %D : (tensor<64x64xf32, #mma0>) -> tensor<64x64xf32, #blocked0>
  %res1 = triton_gpu.convert_layout %D1 : (tensor<64x64xf32, #mma1>) -> tensor<64x64xf32, #blocked0>
  %sum = arith.addf %res, %res1 : tensor<64x64xf32, #blocked0>

  return %sum : tensor<64x64xf32, #blocked0>
}
}