29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
35#include "llvm/ADT/TypeSwitch.h"
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
// Fixed hardware-shape constants. Both are used below when matching sub-byte
// tile shapes; executionSize is also passed into the MMA shape attributes.
49static constexpr int32_t systolicDepth{8};
50static constexpr int32_t executionSize{16};
53enum class NdTdescOffset : uint32_t {
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
67 llvm_unreachable(
"Unknown XeGPU memory space");
71static bool isSharedMemRef(
const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
75 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() ==
static_cast<int>(xevm::AddrSpace::SHARED);
77 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
// Builds a VectorType with element type `toElemType` from `currentVecType`.
// The visible expression scales the source element count by
// currentBitWidth / newBitWidth, i.e. it re-expresses the same number of bits
// with a different element width.
// NOTE(review): the declarations of `toElemType`, `newBitWidth` and `size` are
// not visible in this excerpt; `newBitWidth` is presumably the bit width of
// `toElemType` and `size` the result of the expression below - confirm.
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
109 llvm_unreachable(
"Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
116 llvm_unreachable(
"Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
123 llvm_unreachable(
"Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
127 llvm_unreachable(
"Unsupported cache control.");
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
147 llvm_unreachable(
"Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
154 llvm_unreachable(
"Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
161 llvm_unreachable(
"Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
168 llvm_unreachable(
"Unsupported cache control.");
170 llvm_unreachable(
"Unsupported cache control.");
// Conversion pattern for xegpu::CreateNdDescOp.
// The visible code packs the base address, base shape (W and H) and base pitch
// into a vector<8 x i32> payload and replaces the op with that payload. On one
// path (its guarding condition is not visible in this excerpt) the op is
// instead replaced by the bare i64 base address.
182class CreateNdDescToXeVMPattern
183 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
184 using OpConversionPattern::OpConversionPattern;
186 matchAndRewrite(xegpu::CreateNdDescOp op,
187 xegpu::CreateNdDescOp::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
189 auto loc = op.getLoc();
190 auto source = op.getSource();
// The payload is 8 x i32; a 4 x i64 view of the same bits is used further
// down to write the 64-bit base pointer as a single element.
194 Type payloadElemTy = rewriter.getI32Type();
195 VectorType payloadTy = VectorType::get(8, payloadElemTy);
196 Type i64Ty = rewriter.getI64Type();
198 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
200 Value payload = arith::ConstantOp::create(
209 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
210 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
212 int64_t rank = mixedSizes.size();
213 auto sourceTy = source.getType();
214 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// A memref source must be ranked; otherwise the pattern does not apply.
217 if (sourceMemrefTy) {
218 if (!sourceMemrefTy.hasRank()) {
219 return rewriter.notifyMatchFailure(op,
"Expected ranked Memref.");
// In both visible assignments the base address is the already-converted
// source operand taken from the adaptor.
223 baseAddr = adaptor.getSource();
225 baseAddr = adaptor.getSource();
// Zero-extend the base address when its type is not already i64.
226 if (baseAddr.
getType() != i64Ty) {
228 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
// NOTE(review): the condition selecting this early replacement is not visible
// here; on this path the descriptor is just the i64 base address.
233 rewriter.replaceOp(op, baseAddr);
237 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
238 unsigned idx) -> Value {
// Width is taken from size index 1, height from size index 0, and the pitch
// from stride index 0.
244 baseShapeW = createOffset(mixedSizes, 1);
245 baseShapeH = createOffset(mixedSizes, 0);
247 Value basePitch = createOffset(mixedStrides, 0);
// Write the base pointer through the 4 x i64 view, then cast back to 8 x i32
// to insert the 32-bit shape and pitch fields.
250 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
252 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
253 static_cast<int>(NdTdescOffset::BasePtr));
254 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
256 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
257 static_cast<int>(NdTdescOffset::BaseShapeW));
259 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
260 static_cast<int>(NdTdescOffset::BaseShapeH));
262 vector::InsertOp::create(rewriter, loc, basePitch, payload,
263 static_cast<int>(NdTdescOffset::BasePitch));
264 rewriter.replaceOp(op, payload);
// Conversion pattern shared by xegpu::LoadNdOp, xegpu::StoreNdOp and
// xegpu::PrefetchNdOp (the enable_if below restricts OpType to these three).
// Two lowering paths are visible:
//  * a 2D block path emitting xevm::BlockLoad2dOp / BlockStore2dOp /
//    BlockPrefetch2dOp from the fields stored in the descriptor payload;
//  * a linear-address path emitting xevm::BlockLoadOp / BlockStoreOp, which
//    rejects prefetch with a "rank == 1" failure message.
// NOTE(review): the leading `template <typename OpType,` line and the
// conditions selecting between the two paths are not visible in this excerpt.
271 typename = std::enable_if_t<llvm::is_one_of<
272 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
273class LoadStorePrefetchNdToXeVMPattern :
public OpConversionPattern<OpType> {
274 using OpConversionPattern<OpType>::OpConversionPattern;
276 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
277 ConversionPatternRewriter &rewriter)
const override {
278 auto mixedOffsets = op.getMixedOffsets();
279 int64_t opOffsetsSize = mixedOffsets.size();
280 auto loc = op.getLoc();
281 auto ctxt = rewriter.getContext();
283 auto tdesc = adaptor.getTensorDesc();
284 auto tdescTy = op.getTensorDescType();
285 auto tileRank = tdescTy.getRank();
// The op must supply exactly one offset per descriptor dimension.
286 if (opOffsetsSize != tileRank)
287 return rewriter.notifyMatchFailure(
288 op,
"Expected offset rank to match descriptor rank.");
289 auto elemType = tdescTy.getElementType();
290 auto elemBitSize = elemType.getIntOrFloatBitWidth();
291 bool isSubByte = elemBitSize < 8;
// Factor by which the tile width is divided when sub-byte elements are
// re-expressed as wider integers; 1 means no re-expression.
292 uint64_t wScaleFactor = 1;
294 if (!isSubByte && (elemBitSize % 8 != 0))
295 return rewriter.notifyMatchFailure(
296 op,
"Expected element type bit width to be multiple of 8.");
297 auto tileW = tdescTy.getDimSize(tileRank - 1);
// Sub-byte handling: only 4-bit elements on 2D descriptors are accepted.
300 if (elemBitSize != 4)
301 return rewriter.notifyMatchFailure(
302 op,
"Only sub byte types of 4bits are supported.");
304 return rewriter.notifyMatchFailure(
305 op,
"Sub byte types are only supported for 2D tensor descriptors.");
306 auto subByteFactor = 8 / elemBitSize;
307 auto tileH = tdescTy.getDimSize(0);
// Packed loads of a (systolicDepth * 4) x (executionSize * subByteFactor)
// tile are re-expressed as i8 elements with a tile width of executionSize.
309 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
310 if (op.getPacked().value_or(
false)) {
312 if (tileH == systolicDepth * 4 &&
313 tileW == executionSize * subByteFactor) {
318 elemType = rewriter.getIntegerType(8);
319 tileW = executionSize;
320 wScaleFactor = subByteFactor;
// If the i8 form did not apply, try i16: the tile width must then equal
// executionSize * (2 * subByteFactor); any other shape is rejected.
325 if (wScaleFactor == 1) {
326 auto sub16BitFactor = subByteFactor * 2;
327 if (tileW == executionSize * sub16BitFactor) {
331 elemType = rewriter.getIntegerType(16);
332 tileW = executionSize;
333 wScaleFactor = sub16BitFactor;
335 return rewriter.notifyMatchFailure(
336 op,
"Unsupported tile shape for sub byte types.");
// Refresh the bit size because elemType may have been replaced above.
340 elemBitSize = elemType.getIntOrFloatBitWidth();
344 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
345 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
349 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
// Unpack the descriptor payload: the base pointer is read through a 4 x i64
// view, the shape and pitch fields directly from the incoming descriptor.
350 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
352 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
354 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
355 static_cast<int>(NdTdescOffset::BasePtr));
356 Value baseShapeW = vector::ExtractOp::create(
357 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeW));
358 Value baseShapeH = vector::ExtractOp::create(
359 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeH));
360 Value basePitch = vector::ExtractOp::create(
361 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BasePitch));
367 rewriter.getI32Type(), offsetW);
371 rewriter.getI32Type(), offsetH);
374 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// Surface width and pitch are multiplied by the element byte size.
378 Value baseShapeWInBytes =
379 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
381 Value basePitchBytes =
382 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// When elements were widened, width, pitch and the W offset are scaled down
// by shifting right by log2(wScaleFactor).
384 if (wScaleFactor > 1) {
388 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
389 baseShapeWInBytes = arith::ShRSIOp::create(
390 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
391 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
392 wScaleFactorValLog2);
394 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
397 auto tileH = tdescTy.getDimSize(0);
399 int32_t vblocks = tdescTy.getArrayLength();
// Store: bitcast the value to an integer vector of elemBitSize-wide elements
// if needed, emit the 2D block store, and erase the original op.
400 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
401 Value src = adaptor.getValue();
407 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
409 return rewriter.notifyMatchFailure(
410 op,
"Expected store value to be a vector type.");
412 VectorType newSrcVecTy =
413 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
414 if (srcVecTy != newSrcVecTy)
415 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
416 auto storeCacheControl =
417 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
418 xevm::BlockStore2dOp::create(
419 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
420 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
421 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
422 rewriter.eraseOp(op);
// Load and prefetch share the load cache-control translation.
424 auto loadCacheControl =
425 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
426 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
427 xevm::BlockPrefetch2dOp::create(
428 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
429 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
430 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
431 rewriter.eraseOp(op);
// Load: the result is loaded as i32 elements when packed (vnni), otherwise as
// elemBitSize-wide integers, and then bitcast to the op's element type.
433 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
434 const bool vnni = op.getPacked().value_or(
false);
435 auto transposeValue = op.getTranspose();
437 transposeValue.has_value() && transposeValue.value()[0] == 1;
438 VectorType loadedTy = encodeVectorTypeTo(
439 dstVecTy, vnni ? rewriter.getI32Type()
440 : rewriter.getIntegerType(elemBitSize));
442 Value resultFlatVec = xevm::BlockLoad2dOp::create(
443 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
444 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
445 tileH, vblocks, transpose, vnni,
446 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
447 resultFlatVec = vector::BitCastOp::create(
449 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
451 rewriter.replaceOp(op, resultFlatVec);
// Linear-address path (presumably rank-1 descriptors, judging by the failure
// message at the end - confirm): an offset is scaled by the element byte size
// and combined into a final i64 address.
463 rewriter.getI64Type(), offset);
466 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
468 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
470 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
476 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
477 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
478 Value src = adaptor.getValue();
484 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
486 return rewriter.notifyMatchFailure(
487 op,
"Expected store value to be a vector type.");
489 VectorType newSrcVecTy =
490 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
491 if (srcVecTy != newSrcVecTy)
492 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
493 auto storeCacheControl =
494 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
495 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
496 op, finalPtrLLVM, src,
497 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
498 }
else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
499 auto loadCacheControl =
500 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
501 VectorType resTy = cast<VectorType>(op.getValue().getType());
502 VectorType loadedTy =
503 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
504 Value
load = xevm::BlockLoadOp::create(
505 rewriter, loc, loadedTy, finalPtrLLVM,
506 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
// Cast back only when the loaded integer type differs from the result type.
507 if (loadedTy != resTy)
508 load = vector::BitCastOp::create(rewriter, loc, resTy,
load);
509 rewriter.replaceOp(op,
load);
// Prefetch has no lowering on this path.
511 return rewriter.notifyMatchFailure(
512 op,
"Unsupported operation: xegpu.prefetch_nd with tensor "
513 "descriptor rank == 1");
// Computes baseAddr + offset * byteSize using arith.muli / arith.addi.
// NOTE(review): the rest of the signature and the definition of `byteSize` are
// not visible in this excerpt; `byteSize` appears to be a constant of
// baseAddr's type holding `elemByteSize` - confirm.
522static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
526 rewriter, loc, baseAddr.
getType(), elemByteSize);
527 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
528 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
// Conversion pattern shared by xegpu::LoadGatherOp and xegpu::StoreScatterOp.
// The visible code computes a per-lane address (base + offset * element byte
// size), then emits an LLVM load or store guarded by an scf.if on the mask.
// Cache hints are attached to the load/store as a "cache_control" attribute.
532template <
typename OpType,
533 typename = std::enable_if_t<llvm::is_one_of<
534 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
535class LoadStoreToXeVMPattern :
public OpConversionPattern<OpType> {
536 using OpConversionPattern<OpType>::OpConversionPattern;
538 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
539 ConversionPatternRewriter &rewriter)
const override {
540 Value offset = adaptor.getOffsets();
542 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
543 auto loc = op.getLoc();
544 auto ctxt = rewriter.getContext();
// The data type is the converted result type for loads and the type of the
// converted value operand for stores.
548 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
550 this->getTypeConverter()->convertType(op.getResult().getType());
552 valOrResTy = adaptor.getValue().getType();
553 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
554 bool hasScalarVal = !valOrResVecTy;
555 int64_t elemBitWidth =
557 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
559 if (elemBitWidth % 8 != 0)
560 return rewriter.notifyMatchFailure(
561 op,
"Expected element type bit width to be multiple of 8.");
562 int64_t elemByteSize = elemBitWidth / 8;
// Start with a global-address-space pointer; a memref source/dest overrides
// it below with the memref's own integer memory space.
564 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
565 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
568 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
569 basePtrI64 = adaptor.getSource();
570 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
571 auto addrSpace = memRefTy.getMemorySpaceAsInt();
573 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
576 basePtrI64 = adaptor.getDest();
577 if (
auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
578 auto addrSpace = memRefTy.getMemorySpaceAsInt();
580 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
// Zero-extend the base pointer when it is not already i64.
584 if (basePtrI64.
getType() != rewriter.getI64Type()) {
585 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
588 Value mask = adaptor.getMask();
// Vector offsets are rejected: a scalar offset is required here.
589 if (dyn_cast<VectorType>(offset.
getType())) {
592 return rewriter.notifyMatchFailure(op,
"Expected offset to be a scalar.");
598 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
602 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
// Likewise the mask must be a scalar.
605 VectorType maskVecTy = dyn_cast<VectorType>(mask.
getType());
609 return rewriter.notifyMatchFailure(op,
"Expected mask to be a scalar.");
// Load: scf.if with a result; the then-region performs the load, the
// else-region yields a zero constant of the same type.
612 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
613 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
614 maskForLane,
true,
true);
616 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
618 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
619 valOrResVecTy.getElementType());
621 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
624 "cache_control", xevm::LoadCacheControlAttr::get(
625 ctxt, translateLoadXeGPUCacheHint(
626 op.getL1Hint(), op.getL3Hint())));
627 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
628 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
630 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
633 eVal = FloatAttr::get(eTy, 0.0);
635 eVal = IntegerAttr::get(eTy, 0);
637 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
639 loaded = arith::ConstantOp::create(
641 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
642 rewriter.replaceOp(op, ifOp.getResult(0));
// Store: scf.if without results wrapping the LLVM store; the op is erased.
645 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane,
false);
646 auto body = ifOp.getBody();
647 rewriter.setInsertionPointToStart(body);
649 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
651 storeOp.getOperation()->setAttr(
652 "cache_control", xevm::StoreCacheControlAttr::get(
653 ctxt, translateStoreXeGPUCacheHint(
654 op.getL1Hint(), op.getL3Hint())));
655 rewriter.eraseOp(op);
// Conversion pattern for xegpu::CreateMemDescOp: the op is replaced by its
// already-converted source operand, so no new operation is created.
661class CreateMemDescOpPattern final
662 :
public OpConversionPattern<xegpu::CreateMemDescOp> {
664 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
666 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter)
const override {
669 rewriter.replaceOp(op, adaptor.getSource());
// Conversion pattern shared by xegpu::LoadMatrixOp and xegpu::StoreMatrixOp.
// The address is the converted mem-desc value plus a linearised offset scaled
// by the element byte size, turned into an SLM-address-space LLVM pointer.
// With the subgroup-block-io attribute set, xevm block load/store ops are
// emitted; otherwise plain LLVM load/store ops are used.
674template <
typename OpType,
675 typename = std::enable_if_t<llvm::is_one_of<
676 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
677class LoadStoreMatrixToXeVMPattern :
public OpConversionPattern<OpType> {
678 using OpConversionPattern<OpType>::OpConversionPattern;
680 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter)
const override {
683 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
685 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
687 auto loc = op.getLoc();
688 auto ctxt = rewriter.getContext();
689 Value baseAddr32 = adaptor.getMemDesc();
690 Value mdescVal = op.getMemDesc();
// For loads, a vector result is flattened to 1-D; the assert documents the
// expectation that at most one dimension is larger than 1.
693 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
694 Type resType = op.getResult().getType();
697 if (
auto vecType = dyn_cast<VectorType>(resType)) {
698 assert(llvm::count_if(vecType.getShape(),
699 [](int64_t d) { return d != 1; }) <= 1 &&
700 "Expected either 1D vector or nD with unit dimensions");
701 resType = VectorType::get({vecType.getNumElements()},
702 vecType.getElementType());
706 dataTy = adaptor.getData().getType();
707 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
// NOTE(review): presumably executed when dataTy is not a vector, so scalar
// data is handled as a 1-element vector - the guard is not visible here.
709 valOrResVecTy = VectorType::get(1, dataTy);
711 int64_t elemBitWidth =
712 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
714 if (elemBitWidth % 8 != 0)
715 return rewriter.notifyMatchFailure(
716 op,
"Expected element type bit width to be multiple of 8.");
717 int64_t elemByteSize = elemBitWidth / 8;
720 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
721 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
723 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.
getType());
// The linear offset is cast to i32 before being added to the base address.
725 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
726 linearOffset = arith::IndexCastUIOp::create(
727 rewriter, loc, rewriter.getI32Type(), linearOffset);
728 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
729 linearOffset, elemByteSize);
733 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
// Block-IO path: data moves as integer vectors of the same element width,
// with bitcasts to and from the original vector type when they differ.
735 if (op.getSubgroupBlockIoAttr()) {
739 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
740 VectorType intVecTy =
741 VectorType::get(valOrResVecTy.getShape(), intElemTy);
743 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
745 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
746 if (intVecTy != valOrResVecTy) {
748 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
750 rewriter.replaceOp(op, loadOp);
752 Value dataToStore = adaptor.getData();
753 if (valOrResVecTy != intVecTy) {
755 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
757 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
759 rewriter.eraseOp(op);
// Non-block path: the target chip is checked against "pvc", "bmg" and "cri"
// before LLVM load/store ops are emitted.
764 if (valOrResVecTy.getNumElements() >= 1) {
767 (*chipOpt !=
"pvc" && *chipOpt !=
"bmg" && *chipOpt !=
"cri")) {
769 return rewriter.notifyMatchFailure(
770 op,
"The lowering is specific to pvc, bmg or cri.");
774 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
// A single-element vector is loaded as a scalar of the element type.
778 auto scalarTy = valOrResVecTy.getElementType();
780 if (valOrResVecTy.getNumElements() == 1)
781 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
784 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
785 rewriter.replaceOp(op, loadOp);
787 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
788 rewriter.eraseOp(op);
// Conversion pattern for xegpu::PrefetchOp: builds an LLVM pointer from the
// converted source plus a scaled scalar offset and emits xevm::PrefetchOp with
// the translated load cache-control hints; the original op is erased.
794class PrefetchToXeVMPattern :
public OpConversionPattern<xegpu::PrefetchOp> {
795 using OpConversionPattern::OpConversionPattern;
797 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
798 ConversionPatternRewriter &rewriter)
const override {
799 auto loc = op.getLoc();
800 auto ctxt = rewriter.getContext();
801 Value basePtrI64 = adaptor.getSource();
// Zero-extend the base pointer when it is not already i64.
803 if (basePtrI64.
getType() != rewriter.getI64Type())
804 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
806 Value offsets = adaptor.getOffsets();
808 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.
getType());
811 return rewriter.notifyMatchFailure(op,
812 "Expected offsets to be a scalar.");
// elemBitWidth stays 0 unless the source is a memref, in which case it is the
// memref element width. elemByteSize is also assigned from the op's
// offset-align-byte value.
// NOTE(review): the branch structure around that assignment is not visible in
// this excerpt - presumably it is the non-memref case; confirm.
814 int64_t elemBitWidth{0};
815 int64_t elemByteSize;
817 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
820 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
823 elemByteSize = *op.getOffsetAlignByte();
825 if (elemBitWidth != 0) {
826 if (elemBitWidth % 8 != 0)
827 return rewriter.notifyMatchFailure(
828 op,
"Expected element type bit width to be multiple of 8.");
829 elemByteSize = elemBitWidth / 8;
831 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
// Default to the global address space; a memref source overrides it with its
// own integer memory space.
836 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
837 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
839 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
840 auto addrSpace = memRefTy.getMemorySpaceAsInt();
842 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
846 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
848 xevm::PrefetchOp::create(
849 rewriter, loc, ptrLLVM,
850 xevm::LoadCacheControlAttr::get(
851 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
852 rewriter.eraseOp(op);
857class FenceToXeVMPattern :
public OpConversionPattern<xegpu::FenceOp> {
858 using OpConversionPattern::OpConversionPattern;
860 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
861 ConversionPatternRewriter &rewriter)
const override {
862 auto loc = op.getLoc();
863 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
864 switch (op.getFenceScope()) {
865 case xegpu::FenceScope::Workgroup:
866 memScope = xevm::MemScope::WORKGROUP;
868 case xegpu::FenceScope::GPU:
869 memScope = xevm::MemScope::DEVICE;
872 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
873 switch (op.getMemoryKind()) {
874 case xegpu::MemorySpace::Global:
875 addrSpace = xevm::AddrSpace::GLOBAL;
877 case xegpu::MemorySpace::SLM:
878 addrSpace = xevm::AddrSpace::SHARED;
881 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
882 rewriter.eraseOp(op);
887static auto encodePrecision = [](
Type type) -> xevm::ElemType {
889 return xevm::ElemType::BF16;
890 else if (type.isF16())
891 return xevm::ElemType::F16;
892 else if (type.isTF32())
893 return xevm::ElemType::TF32;
894 else if (type.isInteger(8)) {
895 if (type.isUnsignedInteger())
896 return xevm::ElemType::U8;
897 return xevm::ElemType::S8;
898 }
else if (type.isF32())
899 return xevm::ElemType::F32;
900 else if (type.isInteger(32))
901 return xevm::ElemType::S32;
902 else if (type.isF8E5M2())
903 return xevm::ElemType::BF8;
904 else if (type.isF8E4M3FN())
905 return xevm::ElemType::F8;
906 else if (mlir::isa<Float4E2M1FNType>(type))
907 return xevm::ElemType::E2M1;
908 llvm_unreachable(
"add more support for ElemType");
// Returns an operand count for the given element type; element types are
// grouped as {TF32}, {BF16, F16}, {U8, S8, F8, BF8} and {E2M1}, and any other
// type is treated as unreachable.
// NOTE(review): the value returned for each group is not visible in this
// excerpt; going by the function name it is presumably the number of elements
// that fit in one 32-bit dword - confirm against the full source.
911static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
913 case xevm::ElemType::TF32:
915 case xevm::ElemType::BF16:
916 case xevm::ElemType::F16:
918 case xevm::ElemType::U8:
919 case xevm::ElemType::S8:
920 case xevm::ElemType::F8:
921 case xevm::ElemType::BF8:
923 case xevm::ElemType::E2M1:
926 llvm_unreachable(
"unsupported xevm::ElemType");
// Conversion pattern for xegpu::DpasOp. After checking that the target uArch
// provides a SubgroupMatrixMultiplyAcc instruction supporting the A, B and
// result element types, it emits xevm::MMAOp on a flattened accumulator and
// shape-casts the result back to the op's result type.
930class DpasToXeVMPattern :
public OpConversionPattern<xegpu::DpasOp> {
931 using OpConversionPattern::OpConversionPattern;
933 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
934 ConversionPatternRewriter &rewriter)
const override {
935 auto loc = op.getLoc();
936 auto ctxt = rewriter.getContext();
937 auto aTy = cast<VectorType>(op.getLhs().getType());
938 auto bTy = cast<VectorType>(op.getRhs().getType());
939 auto resultType = cast<VectorType>(op.getResultType());
// Target checks: each failure below aborts the match with a diagnostic. The
// lookups that produce the chip / uArch values are not visible here.
944 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
948 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
951 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
952 uArch->getInstruction(
953 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
955 return rewriter.notifyMatchFailure(op,
956 "DPAS not supported by target uArch");
// True when the vector's element type is among the types the instruction
// reports as supported for the given operand kind.
958 auto checkSupportedTypes = [&](VectorType vecTy,
960 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
961 return llvm::find(supported, vecTy.getElementType()) != supported.end();
964 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
965 return rewriter.notifyMatchFailure(
966 op,
"A-matrix element type not supported by target uArch");
967 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
968 return rewriter.notifyMatchFailure(
969 op,
"B-matrix element type not supported by target uArch");
971 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
972 return rewriter.notifyMatchFailure(
973 op,
"result/accumulator element type not supported by target uArch");
975 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
976 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
977 Value c = op.getAcc();
// A zero constant of the result element type is prepared for the accumulator.
// NOTE(review): the guard is not visible; presumably this applies when the op
// has no accumulator operand - confirm.
979 auto elementTy = resultType.getElementType();
980 Attribute initValueAttr;
981 if (isa<FloatType>(elementTy))
982 initValueAttr = FloatAttr::get(elementTy, 0.0);
984 initValueAttr = IntegerAttr::get(elementTy, 0);
985 c = arith::ConstantOp::create(
// NOTE(review): the operands are read from `op`, not from `adaptor`.
989 Value aVec = op.getLhs();
990 Value bVec = op.getRhs();
991 auto cvecty = cast<VectorType>(c.
getType());
992 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
993 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
// The accumulator is shape-cast to a 1-D vector with the same element count,
// which is also the result type of the MMA op.
995 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
997 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
998 Value dpasRes = xevm::MMAOp::create(
999 rewriter, loc, cNty, aVec, bVec, c,
1000 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
1002 getNumOperandsPerDword(precATy)),
1003 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1005 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
1006 rewriter.replaceOp(op, dpasRes);
1011static std::optional<LLVM::AtomicBinOp>
1012matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1013 switch (arithKind) {
1014 case arith::AtomicRMWKind::addf:
1015 return LLVM::AtomicBinOp::fadd;
1016 case arith::AtomicRMWKind::addi:
1017 return LLVM::AtomicBinOp::add;
1018 case arith::AtomicRMWKind::assign:
1019 return LLVM::AtomicBinOp::xchg;
1020 case arith::AtomicRMWKind::maximumf:
1021 return LLVM::AtomicBinOp::fmax;
1022 case arith::AtomicRMWKind::maxs:
1023 return LLVM::AtomicBinOp::max;
1024 case arith::AtomicRMWKind::maxu:
1025 return LLVM::AtomicBinOp::umax;
1026 case arith::AtomicRMWKind::minimumf:
1027 return LLVM::AtomicBinOp::fmin;
1028 case arith::AtomicRMWKind::mins:
1029 return LLVM::AtomicBinOp::min;
1030 case arith::AtomicRMWKind::minu:
1031 return LLVM::AtomicBinOp::umin;
1032 case arith::AtomicRMWKind::ori:
1033 return LLVM::AtomicBinOp::_or;
1034 case arith::AtomicRMWKind::andi:
1035 return LLVM::AtomicBinOp::_and;
1037 return std::nullopt;
1041class AtomicRMWToXeVMPattern :
public OpConversionPattern<xegpu::AtomicRMWOp> {
1042 using OpConversionPattern::OpConversionPattern;
1044 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1045 ConversionPatternRewriter &rewriter)
const override {
1046 auto loc = op.getLoc();
1047 auto ctxt = rewriter.getContext();
1048 auto tdesc = op.getTensorDesc().getType();
1049 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1050 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1051 Value basePtrI64 = arith::IndexCastOp::create(
1052 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1054 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1055 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1056 VectorType srcOrDstFlatVecTy = VectorType::get(
1057 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1058 Value srcFlatVec = vector::ShapeCastOp::create(
1059 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1060 auto atomicKind = matchSimpleAtomicOp(op.getKind());
1061 assert(atomicKind.has_value());
1062 Value resVec = srcFlatVec;
1063 for (
int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1064 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1065 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1066 rewriter.getIndexAttr(i));
1068 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1069 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1071 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1072 val, LLVM::AtomicOrdering::seq_cst);
1073 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1075 rewriter.replaceOp(op, resVec);
// Conversion pattern for xegpu::DpasMxOp: emits xevm::MMAMxOp from the
// converted A/B operands, their scales and the accumulator, and replaces the
// op with its result. 4-bit A/B vectors are first bitcast to i8 vectors with
// half as many elements.
1080class DpasMxToXeVMPattern :
public OpConversionPattern<xegpu::DpasMxOp> {
1081 using OpConversionPattern::OpConversionPattern;
1083 matchAndRewrite(xegpu::DpasMxOp op, xegpu::DpasMxOp::Adaptor adaptor,
1084 ConversionPatternRewriter &rewriter)
const override {
1085 auto loc = op.getLoc();
1086 auto ctxt = rewriter.getContext();
1087 auto aTy = op.getA().getType();
1088 auto bTy = op.getB().getType();
// The result vector type comes from the type converter, not from the op.
1090 cast<VectorType>(getTypeConverter()->convertType(op.getType()));
// Target checks; the lookups producing the chip / uArch values are not
// visible in this excerpt.
1094 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
1098 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
// Precision of A and B is derived from the original (unconverted) types.
1102 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1103 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1104 Value c = adaptor.getAcc();
// A zero constant of the result element type is prepared for the accumulator.
// NOTE(review): the guard is not visible; presumably this applies when no
// accumulator operand is given - confirm.
1106 auto elementTy = resVecTy.getElementType();
1107 Attribute initValueAttr;
1108 if (isa<FloatType>(elementTy))
1109 initValueAttr = FloatAttr::get(elementTy, 0.0);
1111 initValueAttr = IntegerAttr::get(elementTy, 0);
1112 c = arith::ConstantOp::create(
1116 Value aVec = adaptor.getA();
1117 Value bVec = adaptor.getB();
1118 auto aVecTy = cast<VectorType>(aVec.
getType());
1119 auto bVecTy = cast<VectorType>(bVec.
getType());
// Two 4-bit elements are packed into each i8 element.
1120 if (aVecTy.getElementTypeBitWidth() == 4)
1121 aVec = vector::BitCastOp::create(
1123 VectorType::get(aVecTy.getNumElements() / 2, rewriter.getI8Type()),
1125 if (bVecTy.getElementTypeBitWidth() == 4)
1126 bVec = vector::BitCastOp::create(
1128 VectorType::get(bVecTy.getNumElements() / 2, rewriter.getI8Type()),
1130 auto cVecTy = cast<VectorType>(c.
getType());
1131 xevm::ElemType precCTy = encodePrecision(cVecTy.getElementType());
1132 xevm::ElemType precDTy = encodePrecision(resVecTy.getElementType());
1133 Value scaleA = adaptor.getScaleA();
1134 Value scaleB = adaptor.getScaleB();
1135 Value dpasMxRes = xevm::MMAMxOp::create(
1136 rewriter, loc, resVecTy, aVec, bVec, scaleA, scaleB, c,
1137 xevm::MMAShapeAttr::get(ctxt, cVecTy.getNumElements(), executionSize,
1139 getNumOperandsPerDword(precATy)),
1140 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1141 rewriter.replaceOp(op, dpasMxRes);
// Pass that converts XeGPU dialect operations to XeVM/LLVM-level operations.
// runOnOperation() builds an LLVMTypeConverter with XeGPU-specific type
// conversions and materialization casts, marks the XeGPU dialect illegal, and
// runs a partial conversion.
// NOTE(review): this listing has dropped many lines (parameter lists, early
// returns, local declarations, closing braces). Comments below describe only
// what the visible lines show; confirm details against the complete file.
1150struct ConvertXeGPUToXeVMPass
1151 :
public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
1154 void runOnOperation()
override {
// Index bitwidth is 64 or 32 depending on the pass's use64bitIndex option.
1164 LowerToLLVMOptions
options(context);
1165 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1166 LLVMTypeConverter typeConverter(context,
options);
1168 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1169 Type i32Type = IntegerType::get(context, 32);
// VectorType: multi-element vectors are flattened to a 1-D vector whose length
// is the product of the shape. Rank-0 / single-element vectors take a separate
// path whose result is not visible in this listing.
1170 typeConverter.addConversion([&](VectorType type) -> Type {
1171 auto elemType = typeConverter.convertType(type.getElementType());
1173 unsigned rank = type.getRank();
1174 if (rank == 0 || type.getNumElements() == 1)
1177 int64_t sum = llvm::product_of(type.getShape());
1178 return VectorType::get(sum, elemType);
// TensorDescType: rank-1 descriptors become the converted index type; all
// others become an 8-element vector of i32.
1180 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1181 if (type.getRank() == 1)
1182 return xevmIndexType;
1183 return VectorType::get(8, i32Type);
// MemDescType always converts to i32.
1192 typeConverter.addConversion(
1193 [&](xegpu::MemDescType type) -> Type {
return i32Type; });
// MemRefType: i32 for shared-memory memrefs, converted index type otherwise.
1195 typeConverter.addConversion([&](MemRefType type) -> Type {
1196 return isSharedMemRef(type) ? i32Type : xevmIndexType;
// Materialization: memref -> integer address. Takes the aligned base pointer
// as index plus the memref's offset (a constant when the strides/offset query
// succeeds and the offset is static, otherwise taken from
// ExtractStridedMetadataOp), casts both to the target type, scales the offset
// by the element size in bytes, and returns base + byte offset.
1206 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1208 Location loc) -> Value {
1209 if (inputs.size() != 1)
1211 auto input = inputs.front();
1212 if (
auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1213 unsigned rank = memrefTy.getRank();
1217 SmallVector<int64_t> intStrides;
1220 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1221 ShapedType::isStatic(intOffsets)) {
1222 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1224 offset = arith::ConstantOp::create(builder, loc,
// Result types for the metadata op: a rank-0 base memref with the same element
// type and memory space, followed (among entries not visible here) by 2 * rank
// index-typed results.
1230 SmallVector<Type> resultTypes{
1231 MemRefType::get({}, memrefTy.getElementType(),
1232 MemRefLayoutAttrInterface(),
1233 memrefTy.getMemorySpace()),
1236 resultTypes.append(2 * rank, indexType);
1238 auto meta = memref::ExtractStridedMetadataOp::create(
1239 builder, loc, resultTypes, input);
1241 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1242 builder, loc, meta.getBaseBuffer());
1243 offset = meta.getOffset();
1247 arith::IndexCastUIOp::create(builder, loc, type, addr);
1249 arith::IndexCastUIOp::create(builder, loc, type, offset);
// Element size in bytes is computed as bitwidth / 8 (integer division).
1252 auto byteSize = arith::ConstantOp::create(
1255 memrefTy.getElementTypeBitWidth() / 8));
1257 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1258 auto addrWithOffset =
1259 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1261 return addrWithOffset.getResult();
// Materialization: casts the single input to index via index::CastUOp and then
// to the requested type via arith::IndexCastUIOp. The input-type guard is not
// visible here; the name suggests ui64 -> i64 — TODO confirm.
1270 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1272 Location loc) -> Value {
1273 if (inputs.size() != 1)
1275 auto input = inputs.front();
1278 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1280 return arith::IndexCastUIOp::create(builder, loc, type, cast)
// Materialization: same two-step cast as above. The input-type guard is not
// visible here; the name suggests ui32 -> i32 — TODO confirm.
1290 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1292 Location loc) -> Value {
1293 if (inputs.size() != 1)
1295 auto input = inputs.front();
1298 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1300 return arith::IndexCastUIOp::create(builder, loc, type, cast)
// Materialization: vector -> vector. Applies a shape cast when the shapes
// differ and a bitcast when the element types differ.
1310 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1312 Location loc) -> Value {
1313 if (inputs.size() != 1)
1315 auto input = inputs.front();
1316 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1317 if (
auto targetVecTy = dyn_cast<VectorType>(type)) {
1321 if (targetVecTy.getShape() != vecTy.getShape()) {
1322 cast = vector::ShapeCastOp::create(
1324 VectorType::get(targetVecTy.getShape(),
1325 vecTy.getElementType()),
1329 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1330 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
// Materialization: rank-0 or single-element vector -> scalar. Extracts the
// element (empty position, or an all-zero position of length rank), then
// index-casts when the element type is index, or bitcasts when the element
// type differs from the requested type.
1342 auto vectorToSingleElementMaterializationCast =
1343 [](OpBuilder &builder, Type type,
ValueRange inputs,
1344 Location loc) -> Value {
1345 if (inputs.size() != 1)
1347 auto input = inputs.front();
1348 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1350 auto rank = vecTy.getRank();
1351 if (rank != 0 && vecTy.getNumElements() != 1)
1353 auto inElemTy = vecTy.getElementType();
1357 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1359 cast = vector::ExtractOp::create(builder, loc, cast,
1360 SmallVector<int64_t>(rank, 0))
1367 if (inElemTy.isIndex()) {
1368 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1370 }
else if (inElemTy != type) {
1371 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
// Materialization: int/float scalar -> rank-0 or single-element vector.
// Converts the scalar to the vector's element type (index cast or bitcast as
// needed) and broadcasts it to the vector type.
1385 auto singleElementToVectorMaterializationCast =
1386 [](OpBuilder &builder, Type type,
ValueRange inputs,
1387 Location loc) -> Value {
1388 if (inputs.size() != 1)
1390 auto input = inputs.front();
1391 auto inTy = input.getType();
1392 if (!inTy.isIntOrFloat())
1396 if (
auto vecTy = dyn_cast<VectorType>(type)) {
1397 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1399 auto outElemTy = vecTy.getElementType();
1401 if (outElemTy.isIndex()) {
1402 cast = arith::IndexCastUIOp::create(builder, loc,
1405 }
else if (inTy != outElemTy) {
1406 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1409 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
// Register the casts: two as source materializations, five as target
// materializations (the vector-to-vector cast is registered for both).
1414 typeConverter.addSourceMaterialization(
1415 singleElementToVectorMaterializationCast);
1416 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1417 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1418 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1419 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1420 typeConverter.addTargetMaterialization(
1421 vectorToSingleElementMaterializationCast);
1422 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
// Conversion target: XeVM, LLVM, vector, arith, memref, gpu and index dialects
// are legal; every XeGPU op must be converted.
1423 ConversionTarget
target(*context);
1424 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1425 vector::VectorDialect, arith::ArithDialect,
1426 memref::MemRefDialect, gpu::GPUDialect,
1427 index::IndexDialect>();
1428 target.addIllegalDialect<xegpu::XeGPUDialect>();
1430 RewritePatternSet patterns(context);
// The pass fails if the partial conversion fails. The lines that populate
// `patterns` are not visible in this listing.
1434 if (
failed(applyPartialConversion(getOperation(),
target,
1435 std::move(patterns))))
1436 signalPassFailure();
// Registers the XeGPU-to-XeVM conversion patterns on `patterns`, grouped by op
// family: nd-descriptor ops, scattered/atomic ops, matrix/mem-desc ops, fence
// and dpas, and dpas-mx.
// NOTE(review): the enclosing function header and some argument lines are not
// visible in this listing; presumably this is the body of
// populateXeGPUToXeVMConversionPatterns — TODO confirm against the full file.
1446 patterns.
add<CreateNdDescToXeVMPattern,
1447 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1448 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1449 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1451 patterns.
add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1452 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1453 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1455 patterns.
add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1456 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1457 CreateMemDescOpPattern>(typeConverter, patterns.
getContext());
1458 patterns.
add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1460 patterns.
add<DpasMxToXeVMPattern>(typeConverter, patterns.
getContext());
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)