#include "llvm/Support/FormatVariadic.h"
#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"

static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};
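
// Slot indices into the flat i32 payload vector that backs a lowered nd
// tensor descriptor. Based on the inserts/extracts in the patterns below, the
// payload carries the base pointer (an i64, written through a bitcast i64
// view of the payload), the base shape width/height, and the current tensor
// offset width/height.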
enum class NdTdescOffset : uint32_t {

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);

static VectorType encodeVectorTypeTo(VectorType currentVecType,
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
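
// Translate XeGPU L1/L3 cache hints into the combined XeVM cache-control
// encodings. XeGPU only exposes L1 and L3 hints; in every mapping below the
// L2 level is emitted as uncached (L2UC).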
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
    llvm_unreachable("Unsupported cache control.");

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
      llvm_unreachable("Unsupported cache control.");
    llvm_unreachable("Unsupported cache control.");
class CreateNdDescToXeVMPattern
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto source = op.getSource();
    Value payload = arith::ConstantOp::create(
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasStaticShape()) {
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
      baseAddr = adaptor.getSource();
                            unsigned idx) -> Value {
    if (mixedOffsets.size() == 2) {
      offsetW = createOffset(mixedOffsets, 1);
      offsetH = createOffset(mixedOffsets, 0);
    } else if (mixedOffsets.size() == 0) {
                                       "Expected 2D offsets or no offsets.");
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    if (sourceMemrefTy) {
      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
    } else if (baseAddr.getType() != i64Ty) {
      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetW, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetW));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetH, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetH));
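
// Lowers xegpu.update_nd_offset by adding the new offsets to the
// TensorOffsetH/TensorOffsetW slots of the descriptor payload. Only 2D
// offsets are supported.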
class UpdateNdOffsetToXeVMPattern
  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
                  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 2)
    auto payload = adaptor.getTensorDesc();
    auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
          vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
      Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
      return vector::InsertOp::create(rewriter, loc, newOffset, payload,
    payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
    payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
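
// Shared pattern for xegpu.load_nd / xegpu.store_nd / xegpu.prefetch_nd on 2D
// descriptors. It unpacks the payload (base pointer, base shape, offsets),
// computes the surface width in bytes, and emits the corresponding
// xevm.blockload2d / xevm.blockstore2d / xevm.blockprefetch2d operation with
// the translated cache controls.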
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    if (tdescTy.getRank() != 2)
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
          op, "Expected element type bit width to be multiple of 8.");
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
    Value basePtr = vector::ExtractOp::create(
        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
    Value baseShapeW = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
    Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    if (opOffsetsSize != 0 && opOffsetsSize != 2)
                                       "Expected 2D offsets or no offsets.");
      offsetW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
      offsetH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
    auto tileW = tdescTy.getDimSize(1);
    auto tileH = tdescTy.getDimSize(0);
    int32_t vblocks = tdescTy.getArrayLength();
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
            op, "Expected store value to be a vector type.");
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      xevm::BlockStore2dOp::create(
          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
          offsetH, elemBitSize, tileW, tileH, src,
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
        xevm::BlockPrefetch2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
            offsetH, elemBitSize, tileW, tileH, vblocks,
        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
        const bool vnni = op.getPacked().value_or(false);
        auto transposeValue = op.getTranspose();
            transposeValue.has_value() && transposeValue.value()[0] == 1;
        VectorType loadedTy = encodeVectorTypeTo(
        Value resultFlatVec = xevm::BlockLoad2dOp::create(
            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
        resultFlatVec = vector::BitCastOp::create(
            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
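
// Computes baseAddr + offset * elemByteSize on i64 values; used by the
// scattered (gather/scatter) patterns below to turn an element offset into a
// byte-addressed pointer value.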
                       Value baseAddr, Value offset, int64_t elemByteSize) {
      rewriter, loc, rewriter.getI64Type(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
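
// Lowers xegpu.create_tdesc for scattered access: the per-lane descriptor is
// the base address advanced by the lane offset scaled by the element byte
// size (via addOffset above).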
class CreateDescToXeVMPattern
  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
    auto eTy = op.getTensorDescType().getElementType();
    auto eBw = eTy.getIntOrFloatBitWidth();
          op, "Expected element type bit width to be multiple of 8.");
    auto loc = op.getLoc();
    auto offsets = adaptor.getOffsets();
    Value addr = adaptor.getSource();
      addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
    auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
class UpdateOffsetToXeVMPattern
  matchAndRewrite(xegpu::UpdateOffsetOp op,
                  xegpu::UpdateOffsetOp::Adaptor adaptor,
    auto eTy = op.getTensorDescType().getElementType();
    auto eBw = eTy.getIntOrFloatBitWidth();
          op, "Expected element type bit width to be multiple of 8.");
    auto loc = op.getLoc();
    Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
                                adaptor.getOffsets(), eBw / 8);
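
// Shared pattern for xegpu.load (gather) and xegpu.store (scatter). The
// per-lane address is formed from the source/dest base (memref or integer)
// plus the scalar lane offset, and the access is emitted as an llvm.load or
// llvm.store guarded by an scf.if on the lane mask, with the translated
// cache controls attached as attributes.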
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdescTy = op.getTensorDescType();
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy = op.getResult().getType();
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
            : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value offsets = adaptor.getOffsets();
    Value mask = adaptor.getMask();
      if (dyn_cast<VectorType>(offsets.getType())) {
            "Expected offsets to be a scalar.");
            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
              valOrResVecTy.getElementType());
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
              ctxt, translateLoadXeGPUCacheHint(
                        op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
        loaded = arith::ConstantOp::create(
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      storeOp.getOperation()->setAttr(
          ctxt, translateStoreXeGPUCacheHint(
                    op.getL1Hint(), op.getL3Hint())));
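
// Lowers xegpu.prefetch to xevm.prefetch. The element byte size is derived
// from the tensor descriptor or memref element bit width when one is
// available, otherwise from op.getOffsetAlignByte(), and is used to scale the
// scalar offset into a byte offset before forming the pointer.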
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value offsets = adaptor.getOffsets();
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
            "Expected offsets to be a scalar.");
    int64_t elemBitWidth{0};
    int64_t elemByteSize;
      elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
    } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
      elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
      elemByteSize = *op.getOffsetAlignByte();
    if (elemBitWidth != 0) {
      if (elemBitWidth % 8 != 0)
            op, "Expected element type bit width to be multiple of 8.");
      elemByteSize = elemBitWidth / 8;
          addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
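
// Lowers xegpu.fence to xevm.memfence, mapping the XeGPU fence scope
// (Workgroup/GPU) to the XeVM memory scope (WORKGROUP/DEVICE) and the memory
// kind (Global/SLM) to the XeVM address space (GLOBAL/SHARED).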
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
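
// Lowers xegpu.dpas to xevm.mma. Operand element types are encoded as
// xevm::ElemType values; when no accumulator operand is given, one appears to
// be materialized with arith.constant of the result element type. The MMA
// shape is assembled from the executionSize/systolicDepth constants and
// getNumOperandsPerDword(precATy), and the result is shape-cast back to the
// original result type.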
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());
    auto encodePrecision = [&](Type type) -> xevm::ElemType {
        return xevm::ElemType::BF16;
        return xevm::ElemType::F16;
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
      auto elementTy = resultType.getElementType();
      if (isa<FloatType>(elementTy))
      c = arith::ConstantOp::create(
    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
            getNumOperandsPerDword(precATy)),
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);

  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    case xevm::ElemType::TF32:
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
    llvm_unreachable("unsupported xevm::ElemType");
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
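
// Lowers xegpu.atomic_rmw on a scattered descriptor to a sequence of
// llvm.atomicrmw operations, one per vector element, each addressed through
// llvm.getelementptr and using seq_cst ordering. matchSimpleAtomicOp above
// maps the arith::AtomicRMWKind to the corresponding LLVM::AtomicBinOp.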
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdesc = op.getTensorDesc().getType();
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
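
// Pass definition. The type converter flattens multi-dimensional vectors,
// registers conversions for xegpu::TensorDescType and MemRefType, and adds
// source/target materialization casts (memref address extraction, ui32/ui64
// <-> index casts, and vector bit/shape casts) so partially converted IR
// stays consistent while the partial conversion runs against the declared
// legal/illegal dialects.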
struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  void runOnOperation() override {
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      if (llvm::isa<IndexType>(elemType))
      if (rank < 1 || type.getNumElements() == 1)
          std::accumulate(type.getShape().begin(), type.getShape().end(),
                          int64_t{1}, std::multiplies<int64_t>());
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      if (type.isScattered())
    typeConverter.addConversion([&](MemRefType type) -> Type {
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
        return arith::IndexCastUIOp::create(builder, loc, type, addr)
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
    typeConverter.addSourceMaterialization(memrefMaterializationCast);
    typeConverter.addSourceMaterialization(ui64MaterializationCast);
    typeConverter.addSourceMaterialization(ui32MaterializationCast);
    typeConverter.addSourceMaterialization(vectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();
      signalPassFailure();
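
// Registers all XeGPU-to-XeVM conversion patterns: nd descriptor
// creation/update and 2D block load/store/prefetch, scattered
// create/update/load/store/prefetch and atomics, plus fence and dpas.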
  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
               AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,