MLIR 23.0.0git
IntelGpuXe2.h
Go to the documentation of this file.
1//===--- IntelGpuXe2.h ------------------------------------------*- C++ -*-===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// \file
10// Xe2 uArch definition. Xe2 is the second generation of Intel Xe GPUs.
11// This file defines the uArch details for Xe2 and its derived architectures.
12// This includes Ponte Vecchio (PVC) and Battlemage (BMG) architectures.
13//
14//===----------------------------------------------------------------------===//
15#ifndef MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
16#define MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
17
21#include "llvm/ADT/SmallVector.h"
22#include "llvm/Support/DebugLog.h"
23#include <map>
24#include <string>
25
26using namespace mlir;
27using namespace mlir::xegpu::uArch;
28
29namespace mlir {
30namespace xegpu {
31namespace uArch {
32
33struct Xe2Plus : public uArch {
34 Xe2Plus(StringRef archName, StringRef archDescription,
36 const XeCoreInfo &xeCore)
37 : uArch(archName, archDescription, instructionRegistry), xeCore(xeCore) {}
38 int getSubgroupSize() const override { return 16; }
39 unsigned getGeneralPackedFormatBitSize() const override { return 32; }
40
41protected:
43};
44
45//===----------------------------------------------------------------------===//
46// uArch instructions
47//===----------------------------------------------------------------------===//
52 static bool classof(const Instruction *B) {
53 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockStore;
54 }
55 // Source :
56 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
57 std::optional<
58 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
60 const static int kHeight[] = {1, 2, 4, 8};
61 const static int kWidth16[] = {16};
62 const static int kWidth32[] = {16};
63 const static int kCount[] = {1};
64 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
65 if (elemByteSize == 1)
66 return std::make_tuple(llvm::ArrayRef<int>(kWidth32),
67 llvm::ArrayRef<int>(kHeight),
68 llvm::ArrayRef<int>(kCount));
69 else if (elemByteSize == 2 || elemByteSize == 4)
70 return std::make_tuple(llvm::ArrayRef<int>(kWidth16),
71 llvm::ArrayRef<int>(kHeight),
72 llvm::ArrayRef<int>(kCount));
73 return std::nullopt;
74 }
75
76 int32_t getPackedFormatBitSize() const { return 16; }
77};
78
83 static bool classof(const Instruction *B) {
84 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockLoad;
85 }
86
87 // Source :
88 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_2d_block_io.html#_add_a_new_section_5_2_x_cl_intel_subgroup_2d_block_io
89 std::optional<
90 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
91 getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose,
92 bool upConv = false) const {
93 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
94 static const int kHeightAtLeast8[] = {8, 16, 32};
95 static const int kHeightAtLeast16[] = {16, 32};
96 static const int kHeightAtLeast32[] = {32};
97
98 static const int kWidth32[] = {32};
99 static const int kWidth16[] = {16};
100 static const int kWidth8[] = {8};
101
102 static const int32_t kCount1[] = {1};
103 static const int32_t kCount2[] = {1, 2};
104 static const int32_t kCount4[] = {1, 2, 4};
105 static const int32_t kCount4Only[] = {4};
106 // (elemBytes, transform, transpose, upConvert)
107 using Key = std::tuple<int, uint8_t, uint8_t, uint8_t>;
108 // (widths, heights, counts)
109 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
111 static const llvm::DenseMap<Key, Value> kMap = {
112 {{1, false, false, false}, {kWidth32, kHeightAtLeast1, kCount2}},
113 {{1, false, false, true}, {kWidth16, kHeightAtLeast8, kCount4Only}},
114 {{2, false, false, false}, {kWidth16, kHeightAtLeast1, kCount2}},
115 {{4, false, false, false}, {kWidth16, kHeightAtLeast1, kCount1}},
116 // Block Loads with Transform:
117 {{1, true, false, false}, {kWidth16, kHeightAtLeast32, kCount4}},
118 {{2, true, false, false}, {kWidth16, kHeightAtLeast16, kCount2}},
119 // Block Loads with Transpose:
120 {{4, false, true, false}, {kWidth8, kHeightAtLeast16, kCount1}},
121 };
122 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
123 auto it = kMap.find({elemByteSize, hasTransform, hasTranspose, upConv});
124 if (it != kMap.end())
125 return it->second;
126 return std::nullopt;
127 }
128
129 int32_t getPackedFormatBitSize() const { return 16; }
130};
131
136 static bool classof(const Instruction *B) {
137 return B->getInstructionKind() == InstructionKind::Subgroup2DBlockPrefetch;
138 }
139 // Source :
140 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_buffer_prefetch.html#_add_a_new_section_6_15_x_sub_group_prefetch_functions
141 std::optional<
142 std::tuple<llvm::ArrayRef<int>, llvm::ArrayRef<int>, llvm::ArrayRef<int>>>
144 static const int kHeightAtLeast1[] = {1, 2, 4, 8, 16, 32};
145
146 static const int kWidth32[] = {32};
147 static const int kWidth16[] = {16};
148
149 static const int32_t kCount1[] = {1};
150 static const int32_t kCount2[] = {1, 2};
151 // elemBytes
152 using Key = int;
153 // (widths, heights, counts)
154 using Value = std::tuple<llvm::ArrayRef<int32_t>, llvm::ArrayRef<int32_t>,
156 static const llvm::DenseMap<Key, Value> kMap = {
157 {1, {kWidth32, kHeightAtLeast1, kCount2}},
158 {2, {kWidth16, kHeightAtLeast1, kCount2}},
159 {4, {kWidth16, kHeightAtLeast1, kCount1}},
160 };
161 const int elemByteSize = elemTy.getIntOrFloatBitWidth() / 8;
162 auto it = kMap.find(elemByteSize);
163 if (it != kMap.end())
164 return it->second;
165 return std::nullopt;
166 }
167 int32_t getPackedFormatBitSize() const { return 16; }
168};
169
178 static bool classof(const Instruction *B) {
179 return B->getInstructionKind() ==
181 }
182 // Source:
183 // https://registry.khronos.org/OpenCL/extensions/intel/cl_intel_subgroup_matrix_multiply_accumulate.html
184
185 // Override all virtuals from MatrixOpInterface
187 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
189 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
190 virtual bool
191 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
192 std::pair<uint32_t, uint32_t> BShape,
193 std::pair<uint32_t, uint32_t> CShape,
194 std::pair<uint32_t, uint32_t> DShape, Type AType,
195 Type BType, Type CType, Type DType) override;
196 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
197 Type DType) override;
198 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
199 std::pair<uint32_t, uint32_t> BShape,
200 std::pair<uint32_t, uint32_t> CShape,
201 std::pair<uint32_t, uint32_t> DShape, Type AType,
202 Type BType, Type CType, Type DType) override;
204 getSupportedM(Type type) const override;
206 getSupportedK(Type type) const override;
208 getSupportedN(Type type) const override;
209
212 bool isLaneLayoutRowMajorOrder() const override { return true; }
213
214protected:
215 const unsigned packedFormatBitSizeA;
216 const unsigned packedFormatBitSizeB;
217};
218
227 static bool classof(const Instruction *B) {
228 return B->getInstructionKind() ==
230 }
231 // Source:
232 // https://github.com/intel/llvm/blob/sycl/sycl/doc/design/spirv-extensions/SPV_INTEL_subgroup_scaled_matrix_multiply_accumulate.asciidoc
233
234 // Override all virtuals from MatrixOpInterface
236 getSupportedShapes(Type dataType, MMAOpndKind matrixType) override;
238 getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override;
239 virtual bool
240 checkSupportedShapesAndTypes(std::pair<uint32_t, uint32_t> AShape,
241 std::pair<uint32_t, uint32_t> BShape,
242 std::pair<uint32_t, uint32_t> CShape,
243 std::pair<uint32_t, uint32_t> DShape, Type AType,
244 Type BType, Type CType, Type DType) override;
245 virtual bool checkSupportedTypes(Type AType, Type BType, Type CType,
246 Type DType) override;
247 virtual bool validate(std::pair<uint32_t, uint32_t> AShape,
248 std::pair<uint32_t, uint32_t> BShape,
249 std::pair<uint32_t, uint32_t> CShape,
250 std::pair<uint32_t, uint32_t> DShape, Type AType,
251 Type BType, Type CType, Type DType) override;
253 getSupportedM(Type type) const override;
255 getSupportedK(Type type) const override;
257 getSupportedN(Type type) const override;
258
261 bool isLaneLayoutRowMajorOrder() const override { return true; }
262
263protected:
264 const unsigned packedFormatBitSizeA;
265 const unsigned packedFormatBitSizeB;
266};
267
// Per-lane load-size cap: a fixed 16, independent of the requested bitWidth.
// NOTE(review): the scrape dropped this struct's declaration line; the
// cross-reference index names it `SpirvLoadGatherInstruction`, but its base
// class and constructor are not visible here — confirm against the original
// header before relying on this fragment.
269 int32_t getMaxLaneLoadSize(int32_t bitWidth) const override { return 16; }
270};
271
// Per-lane store-size cap: a fixed 16, independent of the requested bitWidth.
// NOTE(review): the scrape dropped this struct's declaration line; the
// cross-reference index names it `SpirvStoreScatterInstruction`, but its base
// class and constructor are not visible here — confirm against the original
// header before relying on this fragment.
273 int32_t getMaxLaneStoreSize(int32_t bitWidth) const override { return 16; }
274};
275
276//===----------------------------------------------------------------------===//
277// uArch instances
278//===----------------------------------------------------------------------===//
279
280struct PVCuArch final : public Xe2Plus {
282 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
283 static const Subgroup2DBlockLoadInstruction loadNdInst;
284 static const Subgroup2DBlockStoreInstruction storeNdInst;
285 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
286 static const SpirvStoreScatterInstruction storeScatterInst;
287 static const SpirvLoadGatherInstruction loadGatherInst;
288 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
289 &storeNdInst, &prefetchNdInst,
290 &storeScatterInst, &loadGatherInst};
291 return arr;
292 }
293
295 : Xe2Plus("pvc", // archName
296 "Ponte Vecchio Architecture", // archDescription
298 XeCoreInfo(8, SharedMemory(512 * 1024, 4), 8, 8) // xeCore
299 ) {}
300 static const uArch *getInstance() {
301 static const PVCuArch instance;
302 return reinterpret_cast<const uArch *>(&instance);
303 }
304};
305
306struct BMGuArch : public Xe2Plus {
308 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
309 static const Subgroup2DBlockLoadInstruction loadNdInst;
310 static const Subgroup2DBlockStoreInstruction storeNdInst;
311 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
312 static const SpirvStoreScatterInstruction storeScatterInst;
313 static const SpirvLoadGatherInstruction loadGatherInst;
314 static const Instruction *arr[] = {&dpasInst, &loadNdInst,
315 &storeNdInst, &prefetchNdInst,
316 &storeScatterInst, &loadGatherInst};
317 return arr;
318 }
319
321 : Xe2Plus("bmg", // archName
322 "Battlemage Architecture", // archDescription
324 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
325 ) {}
326 static const uArch *getInstance() {
327 static const BMGuArch instance;
328 return reinterpret_cast<const uArch *>(&instance);
329 }
330};
331
332struct CRIuArch : public Xe2Plus {
334 static const SubgroupMatrixMultiplyAcc dpasInst{16, 32};
335 static const SubgroupScaledMatrixMultiplyAcc dpasMxInst{16, 32};
336 static const Subgroup2DBlockLoadInstruction loadNdInst;
337 static const Subgroup2DBlockStoreInstruction storeNdInst;
338 static const Subgroup2DBlockPrefetchInstruction prefetchNdInst;
339 static const SpirvStoreScatterInstruction storeScatterInst;
340 static const SpirvLoadGatherInstruction loadGatherInst;
341 static const Instruction *arr[] = {
342 &dpasInst, &dpasMxInst, &loadNdInst, &storeNdInst,
343 &prefetchNdInst, &storeScatterInst, &loadGatherInst};
344 return arr;
345 }
346
348 : Xe2Plus("cri", // archName
349 "Crescent Island Architecture", // archDescription
351 // Using bmg config as placeholder
352 // TODO: Update to actual XeCore and SharedMemory config
353 XeCoreInfo(8, SharedMemory(256 * 1024, 4), 8, 8) // xeCore
354 ) {}
355 static const uArch *getInstance() {
356 static const CRIuArch instance;
357 return reinterpret_cast<const uArch *>(&instance);
358 }
359};
360
361inline const uArch *getUArch(llvm::StringRef archName) {
362 if (archName.equals_insensitive("pvc"))
363 return PVCuArch::getInstance();
364 if (archName.equals_insensitive("bmg"))
365 return BMGuArch::getInstance();
366 if (archName.equals_insensitive("cri"))
367 return CRIuArch::getInstance();
368 return nullptr;
369}
370
371} // namespace uArch
372} // namespace xegpu
373} // namespace mlir
374
375//===----------------------------------------------------------------------===//
376// Instruction implementations
377//===----------------------------------------------------------------------===//
378
381 MMAOpndKind matrixType) {
382 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
384 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
386 for (unsigned x : a) {
387 for (unsigned y : b) {
388 result.emplace_back(x, y);
389 }
390 }
391 return result;
392 };
393
394 auto M = getSupportedM(dataType);
395 auto K = getSupportedK(dataType);
396 auto N = getSupportedN(dataType);
398
399 switch (matrixType) {
401 resultMatrix = combineVectors(M, K);
402 break;
404 resultMatrix = combineVectors(K, N);
405 break;
407 resultMatrix = combineVectors(M, N);
408 break;
410 resultMatrix = combineVectors(M, N);
411 break;
412 }
413 return resultMatrix;
414}
415
418 MMAOpndKind matrixType) {
419 Type bf16Type = BFloat16Type::get(&context);
420 Type f16Type = Float16Type::get(&context);
421 Type tf32Type = FloatTF32Type::get(&context);
422 Type f32Type = Float32Type::get(&context);
423
424 switch (matrixType) {
426 return {bf16Type, f16Type, tf32Type};
428 return {bf16Type, f16Type, tf32Type};
430 return {bf16Type, f16Type, f32Type};
432 return {bf16Type, f16Type, f32Type};
433 }
434 return {};
435}
436
438 Type BType,
439 Type CType,
440 Type DType) {
441 if (AType.isF16() || BType.isF16()) {
442 if (AType != BType || (CType && (!CType.isF32() && !CType.isF16())) ||
443 (!DType.isF32() && !DType.isF16())) {
444 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
445 return false;
446 }
447 } else if (AType.isBF16() || BType.isBF16()) {
448 if (AType != BType || (CType && (!CType.isF32() && !CType.isBF16())) ||
449 (!DType.isF32() && !DType.isBF16())) {
450 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
451 return false;
452 }
453 } else if (AType.isTF32() || BType.isTF32()) {
454 if (AType != BType || (CType && (!CType.isF32() && !DType.isF32())) ||
455 (!DType.isF32())) {
456 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
457 return false;
458 }
459 } else if (!(AType.isInteger(2) || AType.isInteger(4) ||
460 AType.isInteger(8)) &&
461 !(BType.isInteger(2) || BType.isInteger(4) ||
462 BType.isInteger(8))) {
463 LDBG() << "Unsupported dpas combinations of Dst, Acc, A and B matrices.";
464 return false;
465 }
466
467 return true;
468}
469
471 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
472 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
473 Type AType, Type BType, Type CType, Type DType) {
474 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
475 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
476 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
477 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
478 return llvm::is_contained(supportedAShapes, AShape) &&
479 llvm::is_contained(supportedBShapes, BShape) &&
480 llvm::is_contained(supportedCShapes, CShape) &&
481 llvm::is_contained(supportedDShapes, DShape) &&
482 checkSupportedTypes(AType, BType, CType, DType);
483}
484
486 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
487 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
488 Type AType, Type BType, Type CType, Type DType) {
489 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
490 BType, CType, DType);
491}
492
495 return {1, 2, 3, 4, 5, 6, 7, 8};
496}
497
500 // assert if data type is not int or float type
501 assert(type.isIntOrFloat() && "Matrix type must be int or float");
502 auto bitWidth = type.getIntOrFloatBitWidth();
503 uint32_t kSize = 0;
504 switch (bitWidth) {
505 case 2:
506 kSize = 64;
507 break;
508 case 4:
509 kSize = 64;
510 break;
511 case 8:
512 kSize = 32;
513 break;
514 case 16:
515 kSize = 16;
516 break;
517 case 32:
518 kSize = 8;
519 break;
520 default:
521 llvm_unreachable("Invalid int or float");
522 }
523 return {kSize};
524}
525
528 return {16};
529}
530
531//===----------------------------------------------------------------------===//
532// SubgroupScaledMatrixMultiplyAcc implementations
533//===----------------------------------------------------------------------===//
534
537 MMAOpndKind matrixType) {
538 auto combineVectors = [](const llvm::SmallVector<uint32_t, 8> &a,
540 -> llvm::SmallVector<std::pair<uint32_t, uint32_t>, 16> {
542 for (unsigned x : a) {
543 for (unsigned y : b) {
544 result.emplace_back(x, y);
545 }
546 }
547 return result;
548 };
549
550 // Avoid calling getSupportedK for C/D types (which are f32/bf16
551 // and not valid for the K-dimension bit-width calculation).
552 switch (matrixType) {
554 return combineVectors(getSupportedM(dataType), getSupportedK(dataType));
556 return combineVectors(getSupportedK(dataType), getSupportedN(dataType));
559 return combineVectors(getSupportedM(dataType), getSupportedN(dataType));
560 }
561 return {};
562}
563
566 MMAOpndKind matrixType) {
567 Type f8E4M3FNType = Float8E4M3FNType::get(&context);
568 Type f8E5M2Type = Float8E5M2Type::get(&context);
569 Type f4E2M1FNType = Float4E2M1FNType::get(&context);
570 Type bf16Type = BFloat16Type::get(&context);
571 Type f32Type = Float32Type::get(&context);
572
573 switch (matrixType) {
575 return {f8E4M3FNType, f8E5M2Type, f4E2M1FNType};
577 return {f8E4M3FNType, f8E5M2Type, f4E2M1FNType};
579 return {bf16Type, f32Type};
581 return {bf16Type, f32Type};
582 }
583 return {};
584}
585
587 Type BType,
588 Type CType,
589 Type DType) {
590 auto isSupportedLowPrecision = [](Type t) {
591 return t.isF8E4M3FN() || t.isF8E5M2() || llvm::isa<Float4E2M1FNType>(t);
592 };
593 auto isSupportedAccum = [](Type t) { return t.isF32() || t.isBF16(); };
594
595 if (!isSupportedLowPrecision(AType) || !isSupportedLowPrecision(BType)) {
596 LDBG() << "Unsupported scaled dpas: A and B must be FP8 or FP4 types.";
597 return false;
598 }
599
600 // A and B must have the same bit width for K dimension compatibility.
601 if (AType.getIntOrFloatBitWidth() != BType.getIntOrFloatBitWidth()) {
602 LDBG() << "Unsupported scaled dpas: A and B must have the same bit width.";
603 return false;
604 }
605
606 if (CType && !isSupportedAccum(CType)) {
607 LDBG() << "Unsupported scaled dpas: C must be f32 or bf16.";
608 return false;
609 }
610
611 if (!isSupportedAccum(DType)) {
612 LDBG() << "Unsupported scaled dpas: D must be f32 or bf16.";
613 return false;
614 }
615
616 return true;
617}
618
620 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
621 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
622 Type AType, Type BType, Type CType, Type DType) {
623 auto supportedAShapes = getSupportedShapes(AType, MMAOpndKind::MatrixA);
624 auto supportedBShapes = getSupportedShapes(BType, MMAOpndKind::MatrixB);
625 auto supportedCShapes = getSupportedShapes(CType, MMAOpndKind::MatrixC);
626 auto supportedDShapes = getSupportedShapes(DType, MMAOpndKind::MatrixD);
627 return llvm::is_contained(supportedAShapes, AShape) &&
628 llvm::is_contained(supportedBShapes, BShape) &&
629 llvm::is_contained(supportedCShapes, CShape) &&
630 llvm::is_contained(supportedDShapes, DShape) &&
631 checkSupportedTypes(AType, BType, CType, DType);
632}
633
635 std::pair<uint32_t, uint32_t> AShape, std::pair<uint32_t, uint32_t> BShape,
636 std::pair<uint32_t, uint32_t> CShape, std::pair<uint32_t, uint32_t> DShape,
637 Type AType, Type BType, Type CType, Type DType) {
638 return checkSupportedShapesAndTypes(AShape, BShape, CShape, DShape, AType,
639 BType, CType, DType);
640}
641
644 return {8};
645}
646
649 assert(type.isIntOrFloat() && "Matrix type must be int or float");
650 auto bitWidth = type.getIntOrFloatBitWidth();
651 uint32_t kSize = 0;
652 switch (bitWidth) {
653 case 4:
654 kSize = 64; // FP4: scale K by 4 (base 16-bit K=16 -> 64)
655 break;
656 case 8:
657 kSize = 32; // FP8: scale K by 2 (base 16-bit K=16 -> 32)
658 break;
659 default:
660 // Scaled dpas only supports FP8 (8-bit) and FP4 (4-bit) types for A/B
661 // matrices. Return empty so callers can gracefully reject unsupported
662 // types instead of aborting.
663 return {};
664 }
665 return {kSize};
666}
667
670 return {16};
671}
672
673#endif // MLIR_DIALECT_XEGPU_UARCH_INTELGPUXE2_H
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isTF32() const
Definition Types.cpp:39
bool isF32() const
Definition Types.cpp:40
bool isInteger() const
Return true if this is an integer type (with the specified width).
Definition Types.cpp:58
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
Definition Types.cpp:118
bool isF16() const
Definition Types.cpp:38
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
Definition Types.cpp:124
bool isBF16() const
Definition Types.cpp:37
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
const uArch * getUArch(llvm::StringRef archName)
Include the generated interface declarations.
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
static const uArch * getInstance()
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
Instruction(InstructionKind kind, InstructionScope scope)
Definition uArchBase.h:56
static llvm::ArrayRef< const Instruction * > getInstructionRegistryArr()
static const uArch * getInstance()
int32_t getMaxLaneLoadSize(int32_t bitWidth) const override
int32_t getMaxLaneStoreSize(int32_t bitWidth) const override
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy, bool hasTransform, bool hasTranspose, bool upConv=false) const
Definition IntelGpuXe2.h:91
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:83
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
static bool classof(const Instruction *B)
Definition IntelGpuXe2.h:52
std::optional< std::tuple< llvm::ArrayRef< int >, llvm::ArrayRef< int >, llvm::ArrayRef< int > > > getBlockWidthHeightCount(Type elemTy) const
Definition IntelGpuXe2.h:59
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
SubgroupMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
static bool classof(const Instruction *B)
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< Type, 8 > getSupportedTypes(MLIRContext &context, MMAOpndKind matrixType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedK(Type type) const override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedN(Type type) const override
virtual bool checkSupportedTypes(Type AType, Type BType, Type CType, Type DType) override
virtual llvm::SmallVector< uint32_t, 8 > getSupportedM(Type type) const override
SubgroupScaledMatrixMultiplyAcc(unsigned packedFormatBitSizeA, unsigned packedFormatBitSizeB)
virtual bool validate(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
static bool classof(const Instruction *B)
virtual llvm::SmallVector< std::pair< uint32_t, uint32_t >, 16 > getSupportedShapes(Type dataType, MMAOpndKind matrixType) override
virtual bool checkSupportedShapesAndTypes(std::pair< uint32_t, uint32_t > AShape, std::pair< uint32_t, uint32_t > BShape, std::pair< uint32_t, uint32_t > CShape, std::pair< uint32_t, uint32_t > DShape, Type AType, Type BType, Type CType, Type DType) override
unsigned getGeneralPackedFormatBitSize() const override
Definition IntelGpuXe2.h:39
int getSubgroupSize() const override
Definition IntelGpuXe2.h:38
Xe2Plus(StringRef archName, StringRef archDescription, llvm::ArrayRef< const Instruction * > instructionRegistry, const XeCoreInfo &xeCore)
Definition IntelGpuXe2.h:34
llvm::SmallDenseMap< InstructionKind, const Instruction *, 32 > instructionRegistry
Definition uArchBase.h:183
uArch(StringRef name, StringRef description, llvm::ArrayRef< const Instruction * > instructionRegistry)
Definition uArchBase.h:156