29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
35#include "llvm/ADT/TypeSwitch.h"
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
// Hardware parameters used throughout the lowering: DPAS systolic depth and
// the subgroup execution (SIMD lane) width of the XeVM target.
49static constexpr int32_t systolicDepth{8};
50static constexpr int32_t executionSize{16};
// Lane indices into the i32 payload vector that encodes an nD tensor
// descriptor (base pointer, shape, pitch). Enumerator list is elided in this
// view; see the NdTdescOffset uses in the patterns below.
53enum class NdTdescOffset : uint32_t {
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
67 llvm_unreachable(
"Unknown XeGPU memory space");
71static bool isSharedMemRef(
const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
75 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() ==
static_cast<int>(xevm::AddrSpace::SHARED);
77 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
109 llvm_unreachable(
"Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
116 llvm_unreachable(
"Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
123 llvm_unreachable(
"Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
127 llvm_unreachable(
"Unsupported cache control.");
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
147 llvm_unreachable(
"Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
154 llvm_unreachable(
"Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
161 llvm_unreachable(
"Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
168 llvm_unreachable(
"Unsupported cache control.");
170 llvm_unreachable(
"Unsupported cache control.");
// Lowers xegpu.create_nd_tdesc. For a rank-1 descriptor the op is replaced by
// the i64 base address alone; for higher ranks an 8xi32 payload vector is
// built holding the base pointer (as 2 lanes via an i64 bitcast), base shape
// W/H and pitch at the NdTdescOffset lanes.
// NOTE(review): several statements are elided in this view; comments describe
// only what is visible.
182class CreateNdDescToXeVMPattern
183 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
184 using OpConversionPattern::OpConversionPattern;
186 matchAndRewrite(xegpu::CreateNdDescOp op,
187 xegpu::CreateNdDescOp::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
189 auto loc = op.getLoc();
190 auto source = op.getSource();
// Payload layout: 8 x i32, viewed as 4 x i64 when inserting the base pointer.
194 Type payloadElemTy = rewriter.getI32Type();
195 VectorType payloadTy = VectorType::get(8, payloadElemTy);
196 Type i64Ty = rewriter.getI64Type();
198 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
200 Value payload = arith::ConstantOp::create(
209 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
210 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
212 int64_t rank = mixedSizes.size();
213 auto sourceTy = source.getType();
214 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// Memref sources must be ranked so shape/stride queries are meaningful.
217 if (sourceMemrefTy) {
218 if (!sourceMemrefTy.hasRank()) {
219 return rewriter.notifyMatchFailure(op,
"Expected ranked Memref.");
223 baseAddr = adaptor.getSource();
225 baseAddr = adaptor.getSource();
// Widen sub-64-bit integer addresses to i64 (zero-extend: addresses are
// unsigned).
226 if (baseAddr.
getType() != i64Ty) {
228 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
// Rank-1 path (context elided here): the descriptor is just the address.
233 rewriter.replaceOp(op, baseAddr);
// Helper materializing size/stride entries as i32 values; body elided in
// this view.
237 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
238 unsigned idx) -> Value {
// W is the innermost (last) dim, H the outermost; pitch from stride 0.
244 baseShapeW = createOffset(mixedSizes, 1);
245 baseShapeH = createOffset(mixedSizes, 0);
247 Value basePitch = createOffset(mixedStrides, 0);
// Insert the 64-bit base pointer through an i64 view of the payload, then
// switch back to the i32 view for the remaining lanes.
250 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
252 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
253 static_cast<int>(NdTdescOffset::BasePtr));
254 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
256 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
257 static_cast<int>(NdTdescOffset::BaseShapeW));
259 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
260 static_cast<int>(NdTdescOffset::BaseShapeH));
262 vector::InsertOp::create(rewriter, loc, basePitch, payload,
263 static_cast<int>(NdTdescOffset::BasePitch));
264 rewriter.replaceOp(op, payload);
// Lowers the nD block ops xegpu.load_nd / store_nd / prefetch_nd to XeVM
// 2D block ops (rank-2 descriptors) or flat block load/store (rank-1).
// Sub-byte (4-bit) element types are handled by re-typing the tile to i8/i16
// and scaling widths/offsets by wScaleFactor.
// NOTE(review): many statements are elided in this view; comments describe
// only what is visible.
271 typename = std::enable_if_t<llvm::is_one_of<
272 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
273class LoadStorePrefetchNdToXeVMPattern :
public OpConversionPattern<OpType> {
274 using OpConversionPattern<OpType>::OpConversionPattern;
276 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
277 ConversionPatternRewriter &rewriter)
const override {
278 auto mixedOffsets = op.getMixedOffsets();
279 int64_t opOffsetsSize = mixedOffsets.size();
280 auto loc = op.getLoc();
281 auto ctxt = rewriter.getContext();
283 auto tdesc = adaptor.getTensorDesc();
284 auto tdescTy = op.getTensorDescType();
285 auto tileRank = tdescTy.getRank();
// One offset per descriptor dimension is required.
286 if (opOffsetsSize != tileRank)
287 return rewriter.notifyMatchFailure(
288 op,
"Expected offset rank to match descriptor rank.");
289 auto elemType = tdescTy.getElementType();
290 auto elemBitSize = elemType.getIntOrFloatBitWidth();
291 bool isSubByte = elemBitSize < 8;
// wScaleFactor > 1 means the tile was re-typed to a wider element and the
// W dimension / offsets must be divided by this factor.
292 uint64_t wScaleFactor = 1;
294 if (!isSubByte && (elemBitSize % 8 != 0))
295 return rewriter.notifyMatchFailure(
296 op,
"Expected element type bit width to be multiple of 8.");
297 auto tileW = tdescTy.getDimSize(tileRank - 1);
// Sub-byte handling: only 4-bit elements on 2D descriptors are supported.
300 if (elemBitSize != 4)
301 return rewriter.notifyMatchFailure(
302 op,
"Only sub byte types of 4bits are supported.");
304 return rewriter.notifyMatchFailure(
305 op,
"Sub byte types are only supported for 2D tensor descriptors.");
306 auto subByteFactor = 8 / elemBitSize;
307 auto tileH = tdescTy.getDimSize(0);
// Packed (VNNI) loads of a DPAS-shaped tile: view the tile as i8.
309 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
310 if (op.getPacked().value_or(
false)) {
312 if (tileH == systolicDepth * 4 &&
313 tileW == executionSize * subByteFactor) {
318 elemType = rewriter.getIntegerType(8);
319 tileW = executionSize;
320 wScaleFactor = subByteFactor;
// Otherwise try a 16-bit re-typing of the tile.
325 if (wScaleFactor == 1) {
326 auto sub16BitFactor = subByteFactor * 2;
327 if (tileW == executionSize * sub16BitFactor) {
331 elemType = rewriter.getIntegerType(16);
332 tileW = executionSize;
333 wScaleFactor = sub16BitFactor;
335 return rewriter.notifyMatchFailure(
336 op,
"Unsupported tile shape for sub byte types.");
340 elemBitSize = elemType.getIntOrFloatBitWidth();
344 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
345 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
349 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
// Rank-2 path: unpack base pointer / shape / pitch from the payload vector
// (pointer read through a 4xi64 view of the 8xi32 payload).
350 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
352 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
354 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
355 static_cast<int>(NdTdescOffset::BasePtr));
356 Value baseShapeW = vector::ExtractOp::create(
357 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeW));
358 Value baseShapeH = vector::ExtractOp::create(
359 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeH));
360 Value basePitch = vector::ExtractOp::create(
361 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BasePitch));
367 rewriter.getI32Type(), offsetW);
371 rewriter.getI32Type(), offsetH);
374 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// XeVM 2D block ops take width and pitch in bytes.
378 Value baseShapeWInBytes =
379 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
381 Value basePitchBytes =
382 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// Sub-byte re-typing: divide byte width, pitch and W offset by the scale
// factor (power of two, so implemented as a right shift).
384 if (wScaleFactor > 1) {
388 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
389 baseShapeWInBytes = arith::ShRSIOp::create(
390 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
391 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
392 wScaleFactorValLog2);
394 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
397 auto tileH = tdescTy.getDimSize(0);
399 int32_t vblocks = tdescTy.getArrayLength();
400 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
401 Value src = adaptor.getValue();
// Stores re-type the source vector to an integer vector of matching width
// before emitting the 2D block store.
407 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
409 return rewriter.notifyMatchFailure(
410 op,
"Expected store value to be a vector type.");
412 VectorType newSrcVecTy =
413 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
414 if (srcVecTy != newSrcVecTy)
415 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
416 auto storeCacheControl =
417 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
418 xevm::BlockStore2dOp::create(
419 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
420 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
421 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
422 rewriter.eraseOp(op);
424 auto loadCacheControl =
425 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
426 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
427 xevm::BlockPrefetch2dOp::create(
428 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
429 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
430 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
431 rewriter.eraseOp(op);
// Loads: result is loaded as i32 (VNNI/packed) or as a same-width integer
// vector, then bitcast back to the destination element type.
433 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
434 const bool vnni = op.getPacked().value_or(
false);
435 auto transposeValue = op.getTranspose();
437 transposeValue.has_value() && transposeValue.value()[0] == 1;
438 VectorType loadedTy = encodeVectorTypeTo(
439 dstVecTy, vnni ? rewriter.getI32Type()
440 : rewriter.getIntegerType(elemBitSize));
442 Value resultFlatVec = xevm::BlockLoad2dOp::create(
443 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
444 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
445 tileH, vblocks, transpose, vnni,
446 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
447 resultFlatVec = vector::BitCastOp::create(
449 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
451 rewriter.replaceOp(op, resultFlatVec);
// Rank-1 path: compute a flat byte address and use 1D block load/store;
// prefetch is not supported for rank-1 descriptors.
463 rewriter.getI64Type(), offset);
466 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
468 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
470 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
476 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
477 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
478 Value src = adaptor.getValue();
484 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
486 return rewriter.notifyMatchFailure(
487 op,
"Expected store value to be a vector type.");
489 VectorType newSrcVecTy =
490 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
491 if (srcVecTy != newSrcVecTy)
492 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
493 auto storeCacheControl =
494 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
495 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
496 op, finalPtrLLVM, src,
497 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
498 }
else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
499 auto loadCacheControl =
500 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
501 VectorType resTy = cast<VectorType>(op.getValue().getType());
502 VectorType loadedTy =
503 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
504 Value
load = xevm::BlockLoadOp::create(
505 rewriter, loc, loadedTy, finalPtrLLVM,
506 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
507 if (loadedTy != resTy)
508 load = vector::BitCastOp::create(rewriter, loc, resTy,
load);
509 rewriter.replaceOp(op,
load);
511 return rewriter.notifyMatchFailure(
512 op,
"Unsupported operation: xegpu.prefetch_nd with tensor "
513 "descriptor rank == 1");
522static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
526 rewriter, loc, baseAddr.
getType(), elemByteSize);
527 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
528 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
// Lowers xegpu.load (gather) / xegpu.store (scatter) with a scalar per-lane
// offset and a scalar per-lane mask to a masked LLVM load/store wrapped in an
// scf.if on the lane's mask bit. Cache hints are attached as a
// "cache_control" attribute on the generated LLVM memory op.
// NOTE(review): several statements are elided in this view; comments describe
// only what is visible.
532template <
typename OpType,
533 typename = std::enable_if_t<llvm::is_one_of<
534 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
535class LoadStoreToXeVMPattern :
public OpConversionPattern<OpType> {
536 using OpConversionPattern<OpType>::OpConversionPattern;
538 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const override {
540 Value offset = adaptor.getOffsets();
542 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
543 auto loc = op.getLoc();
544 auto ctxt = rewriter.getContext();
// Value/result type: load result type (converted) or store value type.
548 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
550 this->getTypeConverter()->convertType(op.getResult().getType());
552 valOrResTy = adaptor.getValue().getType();
553 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
554 bool hasScalarVal = !valOrResVecTy;
555 int64_t elemBitWidth =
557 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
559 if (elemBitWidth % 8 != 0)
560 return rewriter.notifyMatchFailure(
561 op,
"Expected element type bit width to be multiple of 8.");
562 int64_t elemByteSize = elemBitWidth / 8;
// Default to the global address space; refined below when the source/dest
// memref carries an explicit memory-space attribute.
564 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
565 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
568 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
569 basePtrI64 = adaptor.getSource();
570 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
571 auto addrSpace = memRefTy.getMemorySpaceAsInt();
573 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
576 basePtrI64 = adaptor.getDest();
577 if (
auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
578 auto addrSpace = memRefTy.getMemorySpaceAsInt();
580 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
// Widen sub-64-bit base addresses to i64 (unsigned addresses).
584 if (basePtrI64.
getType() != rewriter.getI64Type()) {
585 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
588 Value mask = adaptor.getMask();
// Per-lane lowering: offsets must already be scalar (one per lane).
589 if (dyn_cast<VectorType>(offset.
getType())) {
592 return rewriter.notifyMatchFailure(op,
"Expected offset to be a scalar.");
598 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
602 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
605 VectorType maskVecTy = dyn_cast<VectorType>(mask.
getType());
609 return rewriter.notifyMatchFailure(op,
"Expected mask to be a scalar.");
// Load: scf.if on the mask — then-branch performs the LLVM load, else-branch
// yields a zero of the result type.
612 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
613 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
614 maskForLane,
true,
true);
616 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
618 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
619 valOrResVecTy.getElementType());
621 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
624 "cache_control", xevm::LoadCacheControlAttr::get(
625 ctxt, translateLoadXeGPUCacheHint(
626 op.getL1Hint(), op.getL3Hint())));
627 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
628 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
630 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
633 eVal = FloatAttr::get(eTy, 0.0);
635 eVal = IntegerAttr::get(eTy, 0);
637 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
639 loaded = arith::ConstantOp::create(
641 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
642 rewriter.replaceOp(op, ifOp.getResult(0));
// Store: scf.if with no else; store happens only for active lanes.
645 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane,
false);
646 auto body = ifOp.getBody();
647 rewriter.setInsertionPointToStart(body);
649 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
651 storeOp.getOperation()->setAttr(
652 "cache_control", xevm::StoreCacheControlAttr::get(
653 ctxt, translateStoreXeGPUCacheHint(
654 op.getL1Hint(), op.getL3Hint())));
655 rewriter.eraseOp(op);
661class CreateMemDescOpPattern final
662 :
public OpConversionPattern<xegpu::CreateMemDescOp> {
664 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
666 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter)
const override {
669 rewriter.replaceOp(op, adaptor.getSource());
// Lowers xegpu.load_matrix / store_matrix against an SLM memory descriptor.
// The descriptor's mixed offsets are linearized into a byte address; the
// access is either a subgroup block IO (xevm.blockload/blockstore through an
// integer-typed vector) or a plain LLVM load/store.
// NOTE(review): several statements are elided in this view; comments describe
// only what is visible.
674template <
typename OpType,
675 typename = std::enable_if_t<llvm::is_one_of<
676 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
677class LoadStoreMatrixToXeVMPattern :
public OpConversionPattern<OpType> {
678 using OpConversionPattern<OpType>::OpConversionPattern;
680 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter)
const override {
683 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
685 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
687 auto loc = op.getLoc();
688 auto ctxt = rewriter.getContext();
// The memory descriptor was converted to a 32-bit SLM base address.
689 Value baseAddr32 = adaptor.getMemDesc();
690 Value mdescVal = op.getMemDesc();
693 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
694 Type resType = op.getResult().getType();
// Results are flattened to 1D; only one non-unit dimension is allowed.
697 if (
auto vecType = dyn_cast<VectorType>(resType)) {
698 assert(llvm::count_if(vecType.getShape(),
699 [](int64_t d) { return d != 1; }) <= 1 &&
700 "Expected either 1D vector or nD with unit dimensions");
701 resType = VectorType::get({vecType.getNumElements()},
702 vecType.getElementType());
706 dataTy = adaptor.getData().getType();
707 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
// Scalars are modeled as 1-element vectors for uniform handling below.
709 valOrResVecTy = VectorType::get(1, dataTy);
711 int64_t elemBitWidth =
712 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
714 if (elemBitWidth % 8 != 0)
715 return rewriter.notifyMatchFailure(
716 op,
"Expected element type bit width to be multiple of 8.");
717 int64_t elemByteSize = elemBitWidth / 8;
// Matrix ops always target shared local memory.
720 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
721 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
723 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.
getType());
// Linearize the nD offsets per the mem-desc layout, then form the byte
// address in i32 arithmetic.
725 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
726 linearOffset = arith::IndexCastUIOp::create(
727 rewriter, loc, rewriter.getI32Type(), linearOffset);
728 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
729 linearOffset, elemByteSize);
733 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
// Subgroup block IO path: go through a same-width integer vector, bitcasting
// at the edges when the element type differs.
735 if (op.getSubgroupBlockIoAttr()) {
739 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
740 VectorType intVecTy =
741 VectorType::get(valOrResVecTy.getShape(), intElemTy);
743 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
745 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
746 if (intVecTy != valOrResVecTy) {
748 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
750 rewriter.replaceOp(op, loadOp);
752 Value dataToStore = adaptor.getData();
753 if (valOrResVecTy != intVecTy) {
755 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
757 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
759 rewriter.eraseOp(op);
// Non-block path is gated on the target chip.
764 if (valOrResVecTy.getNumElements() >= 1) {
766 if (!chipOpt || (*chipOpt !=
"pvc" && *chipOpt !=
"bmg")) {
768 return rewriter.notifyMatchFailure(
769 op,
"The lowering is specific to pvc or bmg.");
// Plain LLVM load/store; single-element vectors load as scalars.
773 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
777 auto scalarTy = valOrResVecTy.getElementType();
779 if (valOrResVecTy.getNumElements() == 1)
780 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
783 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
784 rewriter.replaceOp(op, loadOp);
786 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
787 rewriter.eraseOp(op);
// Lowers xegpu.prefetch (scattered prefetch with scalar offset) to
// xevm.prefetch on a byte address computed from the converted source plus
// offset * element-byte-size, carrying the translated cache hints.
// NOTE(review): several statements are elided in this view; comments describe
// only what is visible.
793class PrefetchToXeVMPattern :
public OpConversionPattern<xegpu::PrefetchOp> {
794 using OpConversionPattern::OpConversionPattern;
796 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
797 ConversionPatternRewriter &rewriter)
const override {
798 auto loc = op.getLoc();
799 auto ctxt = rewriter.getContext();
800 Value basePtrI64 = adaptor.getSource();
// Widen sub-64-bit addresses to i64 (unsigned addresses).
802 if (basePtrI64.
getType() != rewriter.getI64Type())
803 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
805 Value offsets = adaptor.getOffsets();
// Per-lane lowering requires a scalar offset.
807 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.
getType());
810 return rewriter.notifyMatchFailure(op,
811 "Expected offsets to be a scalar.");
// Element byte size comes from the memref element type when available,
// otherwise from the op's offset_align_byte attribute.
813 int64_t elemBitWidth{0};
814 int64_t elemByteSize;
816 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
819 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
822 elemByteSize = *op.getOffsetAlignByte();
824 if (elemBitWidth != 0) {
825 if (elemBitWidth % 8 != 0)
826 return rewriter.notifyMatchFailure(
827 op,
"Expected element type bit width to be multiple of 8.");
828 elemByteSize = elemBitWidth / 8;
830 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
// Default to global address space; refine from the memref's space if set.
835 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
836 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
838 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
839 auto addrSpace = memRefTy.getMemorySpaceAsInt();
841 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
845 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
847 xevm::PrefetchOp::create(
848 rewriter, loc, ptrLLVM,
849 xevm::LoadCacheControlAttr::get(
850 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
851 rewriter.eraseOp(op);
856class FenceToXeVMPattern :
public OpConversionPattern<xegpu::FenceOp> {
857 using OpConversionPattern::OpConversionPattern;
859 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
860 ConversionPatternRewriter &rewriter)
const override {
861 auto loc = op.getLoc();
862 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
863 switch (op.getFenceScope()) {
864 case xegpu::FenceScope::Workgroup:
865 memScope = xevm::MemScope::WORKGROUP;
867 case xegpu::FenceScope::GPU:
868 memScope = xevm::MemScope::DEVICE;
871 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
872 switch (op.getMemoryKind()) {
873 case xegpu::MemorySpace::Global:
874 addrSpace = xevm::AddrSpace::GLOBAL;
876 case xegpu::MemorySpace::SLM:
877 addrSpace = xevm::AddrSpace::SHARED;
880 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
881 rewriter.eraseOp(op);
// Lowers xegpu.dpas to xevm.mma. Validates A/B/result element types against
// the target uArch's DPAS instruction description, encodes element-type
// precisions, shape-casts operands to flat vectors, and materializes a zero
// accumulator when none is given.
// NOTE(review): several statements (chip/uArch retrieval, some declarations)
// are elided in this view; comments describe only what is visible.
886class DpasToXeVMPattern :
public OpConversionPattern<xegpu::DpasOp> {
887 using OpConversionPattern::OpConversionPattern;
889 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
890 ConversionPatternRewriter &rewriter)
const override {
891 auto loc = op.getLoc();
892 auto ctxt = rewriter.getContext();
893 auto aTy = cast<VectorType>(op.getLhs().getType());
894 auto bTy = cast<VectorType>(op.getRhs().getType());
895 auto resultType = cast<VectorType>(op.getResultType());
// Target checks: chip, uArch, and DPAS instruction availability.
900 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
904 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
907 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
908 uArch->getInstruction(
909 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
911 return rewriter.notifyMatchFailure(op,
912 "DPAS not supported by target uArch");
// True if the vector's element type appears in the uArch's supported set
// for the given matrix-operand kind.
914 auto checkSupportedTypes = [&](VectorType vecTy,
916 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
917 return llvm::find(supported, vecTy.getElementType()) != supported.end();
920 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
921 return rewriter.notifyMatchFailure(
922 op,
"A-matrix element type not supported by target uArch");
923 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
924 return rewriter.notifyMatchFailure(
925 op,
"B-matrix element type not supported by target uArch");
927 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
928 return rewriter.notifyMatchFailure(
929 op,
"result/accumulator element type not supported by target uArch");
// Map an MLIR element type to the XeVM MMA precision enum.
931 auto encodePrecision = [&](Type type) -> xevm::ElemType {
932 if (type == rewriter.getBF16Type())
933 return xevm::ElemType::BF16;
934 else if (type == rewriter.getF16Type())
935 return xevm::ElemType::F16;
936 else if (type == rewriter.getTF32Type())
937 return xevm::ElemType::TF32;
938 else if (type.isInteger(8)) {
939 if (type.isUnsignedInteger())
940 return xevm::ElemType::U8;
941 return xevm::ElemType::S8;
942 }
else if (type == rewriter.getF32Type())
943 return xevm::ElemType::F32;
944 else if (type.isInteger(32))
945 return xevm::ElemType::S32;
946 llvm_unreachable(
"add more support for ElemType");
948 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
949 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
950 Value c = op.getAcc();
// No accumulator given: build a zero constant of the result element type.
952 auto elementTy = resultType.getElementType();
953 Attribute initValueAttr;
954 if (isa<FloatType>(elementTy))
955 initValueAttr = FloatAttr::get(elementTy, 0.0);
957 initValueAttr = IntegerAttr::get(elementTy, 0);
958 c = arith::ConstantOp::create(
962 Value aVec = op.getLhs();
963 Value bVec = op.getRhs();
964 auto cvecty = cast<VectorType>(c.
getType());
965 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
966 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
// xevm.mma operates on flat (1D) vectors; shape-cast in and out.
968 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
970 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
971 Value dpasRes = xevm::MMAOp::create(
972 rewriter, loc, cNty, aVec, bVec, c,
973 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
975 getNumOperandsPerDword(precATy)),
976 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
978 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
979 rewriter.replaceOp(op, dpasRes);
// Number of operand elements packed per 32-bit dword for the given
// precision; return values between the case labels are elided in this view.
984 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
986 case xevm::ElemType::TF32:
988 case xevm::ElemType::BF16:
989 case xevm::ElemType::F16:
991 case xevm::ElemType::U8:
992 case xevm::ElemType::S8:
995 llvm_unreachable(
"unsupported xevm::ElemType");
1000static std::optional<LLVM::AtomicBinOp>
1001matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1002 switch (arithKind) {
1003 case arith::AtomicRMWKind::addf:
1004 return LLVM::AtomicBinOp::fadd;
1005 case arith::AtomicRMWKind::addi:
1006 return LLVM::AtomicBinOp::add;
1007 case arith::AtomicRMWKind::assign:
1008 return LLVM::AtomicBinOp::xchg;
1009 case arith::AtomicRMWKind::maximumf:
1010 return LLVM::AtomicBinOp::fmax;
1011 case arith::AtomicRMWKind::maxs:
1012 return LLVM::AtomicBinOp::max;
1013 case arith::AtomicRMWKind::maxu:
1014 return LLVM::AtomicBinOp::umax;
1015 case arith::AtomicRMWKind::minimumf:
1016 return LLVM::AtomicBinOp::fmin;
1017 case arith::AtomicRMWKind::mins:
1018 return LLVM::AtomicBinOp::min;
1019 case arith::AtomicRMWKind::minu:
1020 return LLVM::AtomicBinOp::umin;
1021 case arith::AtomicRMWKind::ori:
1022 return LLVM::AtomicBinOp::_or;
1023 case arith::AtomicRMWKind::andi:
1024 return LLVM::AtomicBinOp::_and;
1026 return std::nullopt;
1030class AtomicRMWToXeVMPattern :
public OpConversionPattern<xegpu::AtomicRMWOp> {
1031 using OpConversionPattern::OpConversionPattern;
1033 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1034 ConversionPatternRewriter &rewriter)
const override {
1035 auto loc = op.getLoc();
1036 auto ctxt = rewriter.getContext();
1037 auto tdesc = op.getTensorDesc().getType();
1038 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1039 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1040 Value basePtrI64 = arith::IndexCastOp::create(
1041 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1043 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1044 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1045 VectorType srcOrDstFlatVecTy = VectorType::get(
1046 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1047 Value srcFlatVec = vector::ShapeCastOp::create(
1048 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1049 auto atomicKind = matchSimpleAtomicOp(op.getKind());
1050 assert(atomicKind.has_value());
1051 Value resVec = srcFlatVec;
1052 for (
int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1053 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1054 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1055 rewriter.getIndexAttr(i));
1057 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1058 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1060 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1061 val, LLVM::AtomicOrdering::seq_cst);
1062 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1064 rewriter.replaceOp(op, resVec);
1073struct ConvertXeGPUToXeVMPass
1077 void runOnOperation()
override {
1087 LowerToLLVMOptions
options(context);
1088 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1089 LLVMTypeConverter typeConverter(context,
options);
1091 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1092 Type i32Type = IntegerType::get(context, 32);
1093 typeConverter.addConversion([&](VectorType type) -> Type {
1094 auto elemType = typeConverter.convertType(type.getElementType());
1096 unsigned rank = type.getRank();
1097 if (rank == 0 || type.getNumElements() == 1)
1100 int64_t sum = llvm::product_of(type.getShape());
1101 return VectorType::get(sum, elemType);
1103 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1104 if (type.getRank() == 1)
1105 return xevmIndexType;
1106 return VectorType::get(8, i32Type);
1115 typeConverter.addConversion(
1116 [&](xegpu::MemDescType type) -> Type {
return i32Type; });
1118 typeConverter.addConversion([&](MemRefType type) -> Type {
1119 return isSharedMemRef(type) ? i32Type : xevmIndexType;
1129 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1131 Location loc) -> Value {
1132 if (inputs.size() != 1)
1134 auto input = inputs.front();
1135 if (
auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1136 unsigned rank = memrefTy.getRank();
1140 SmallVector<int64_t> intStrides;
1143 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1144 ShapedType::isStatic(intOffsets)) {
1145 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1147 offset = arith::ConstantOp::create(builder, loc,
1153 SmallVector<Type> resultTypes{
1154 MemRefType::get({}, memrefTy.getElementType(),
1155 MemRefLayoutAttrInterface(),
1156 memrefTy.getMemorySpace()),
1159 resultTypes.append(2 * rank, indexType);
1161 auto meta = memref::ExtractStridedMetadataOp::create(
1162 builder, loc, resultTypes, input);
1164 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1165 builder, loc, meta.getBaseBuffer());
1166 offset = meta.getOffset();
1170 arith::IndexCastUIOp::create(builder, loc, type, addr);
1172 arith::IndexCastUIOp::create(builder, loc, type, offset);
1175 auto byteSize = arith::ConstantOp::create(
1178 memrefTy.getElementTypeBitWidth() / 8));
1180 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1181 auto addrWithOffset =
1182 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1184 return addrWithOffset.getResult();
1193 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1195 Location loc) -> Value {
1196 if (inputs.size() != 1)
1198 auto input = inputs.front();
1201 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1203 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1213 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1215 Location loc) -> Value {
1216 if (inputs.size() != 1)
1218 auto input = inputs.front();
1221 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1223 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1233 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1235 Location loc) -> Value {
1236 if (inputs.size() != 1)
1238 auto input = inputs.front();
1239 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1240 if (
auto targetVecTy = dyn_cast<VectorType>(type)) {
1244 if (targetVecTy.getShape() != vecTy.getShape()) {
1245 cast = vector::ShapeCastOp::create(
1247 VectorType::get(targetVecTy.getShape(),
1248 vecTy.getElementType()),
1252 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1253 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
1265 auto vectorToSingleElementMaterializationCast =
1266 [](OpBuilder &builder, Type type,
ValueRange inputs,
1267 Location loc) -> Value {
1268 if (inputs.size() != 1)
1270 auto input = inputs.front();
1271 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1273 auto rank = vecTy.getRank();
1274 if (rank != 0 && vecTy.getNumElements() != 1)
1276 auto inElemTy = vecTy.getElementType();
1280 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1282 cast = vector::ExtractOp::create(builder, loc, cast,
1283 SmallVector<int64_t>(rank, 0))
1290 if (inElemTy.isIndex()) {
1291 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1293 }
else if (inElemTy != type) {
1294 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
1308 auto singleElementToVectorMaterializationCast =
1309 [](OpBuilder &builder, Type type,
ValueRange inputs,
1310 Location loc) -> Value {
1311 if (inputs.size() != 1)
1313 auto input = inputs.front();
1314 auto inTy = input.getType();
1315 if (!inTy.isIntOrFloat())
1319 if (
auto vecTy = dyn_cast<VectorType>(type)) {
1320 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1322 auto outElemTy = vecTy.getElementType();
1324 if (outElemTy.isIndex()) {
1325 cast = arith::IndexCastUIOp::create(builder, loc,
1328 }
else if (inTy != outElemTy) {
1329 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1332 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
1337 typeConverter.addSourceMaterialization(
1338 singleElementToVectorMaterializationCast);
1339 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1340 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1341 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1342 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1343 typeConverter.addTargetMaterialization(
1344 vectorToSingleElementMaterializationCast);
1345 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
1346 ConversionTarget
target(*context);
1347 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1348 vector::VectorDialect, arith::ArithDialect,
1349 memref::MemRefDialect, gpu::GPUDialect,
1350 index::IndexDialect>();
1351 target.addIllegalDialect<xegpu::XeGPUDialect>();
1353 RewritePatternSet patterns(context);
1357 if (
failed(applyPartialConversion(getOperation(),
target,
1358 std::move(patterns))))
1359 signalPassFailure();
1369 patterns.
add<CreateNdDescToXeVMPattern,
1370 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1371 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1372 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1374 patterns.
add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1375 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1376 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1378 patterns.
add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1379 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1380 CreateMemDescOpPattern>(typeConverter, patterns.
getContext());
1381 patterns.
add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)