#include "llvm/Support/FormatVariadic.h"

#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"

static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};
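// Offsets of the fields packed into the nd-tensor-descriptor payload vector:
// base pointer, base shape (W/H), and tensor offsets (W/H). These are the
// positions used by the create-nd-desc and load/store/prefetch patterns below.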
enum class NdTdescOffset : uint32_t {

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
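// Build a vector type with a new element type, keeping the total bit width
// constant by scaling the element count by the bit-width ratio.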
static VectorType encodeVectorTypeTo(VectorType currentVecType,
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
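// Combine the XeGPU L1/L3 cache hints into a single XeVM LoadCacheControl.
// L2 is always uncached in the produced control; unsupported hint
// combinations hit llvm_unreachable.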
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
class CreateNdDescToXeVMPattern
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
    if (mixedOffsets.size() != 0)
    auto loc = op.getLoc();
    auto source = op.getSource();
    Value payload = arith::ConstantOp::create(
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasStaticShape()) {
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
      baseAddr = adaptor.getSource();
                            unsigned idx) -> Value {
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    if (sourceMemrefTy) {
      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
    } else if (baseAddr.getType() != i64Ty) {
      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetW, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetW));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetH, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetH));
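// Shared lowering for xegpu.load_nd / store_nd / prefetch_nd: unpack the
// payload built above and emit the matching XeVM 2D block op
// (BlockLoad2d / BlockStore2d / BlockPrefetch2d) with translated cache hints.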
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    if (opOffsetsSize != 2)
    auto loc = op.getLoc();
    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    if (tdescTy.getRank() != 2)
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
          op, "Expected element type bit width to be multiple of 8.");
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
    Value basePtr = vector::ExtractOp::create(
        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
    Value baseShapeW = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
    Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
    auto tileW = tdescTy.getDimSize(1);
    auto tileH = tdescTy.getDimSize(0);
    int32_t vblocks = tdescTy.getArrayLength();
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
            op, "Expected store value to be a vector type.");
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      xevm::BlockStore2dOp::create(
          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
          offsetH, elemBitSize, tileW, tileH, src,
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
        xevm::BlockPrefetch2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
            offsetH, elemBitSize, tileW, tileH, vblocks,
        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
        const bool vnni = op.getPacked().value_or(false);
        auto transposeValue = op.getTranspose();
            transposeValue.has_value() && transposeValue.value()[0] == 1;
        VectorType loadedTy = encodeVectorTypeTo(
        Value resultFlatVec = xevm::BlockLoad2dOp::create(
            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
        resultFlatVec = vector::BitCastOp::create(
            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
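// Advance a base address by offset * elemByteSize; used by the scattered
// load/store and prefetch lowerings below.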
                       Value baseAddr, Value offset, int64_t elemByteSize) {
      rewriter, loc, rewriter.getI64Type(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
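// Shared lowering for xegpu.load (gather) / xegpu.store (scatter): compute a
// per-lane address from the base pointer and offset, then guard a plain
// llvm.load / llvm.store with an scf.if on the mask.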
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
    Value offset = adaptor.getOffsets();
    auto loc = op.getLoc();
    auto tdescTy = op.getTensorDescType();
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy = op.getResult().getType();
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
          valOrResVecTy.getElementType());
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
          ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
        loaded = arith::ConstantOp::create(
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      storeOp.getOperation()->setAttr(
          ctxt, translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
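// Lower xegpu.prefetch: resolve the base pointer and optional offset, then
// emit xevm.prefetch with the translated load cache hints.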
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value offsets = adaptor.getOffsets();
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
            "Expected offsets to be a scalar.");
    int64_t elemBitWidth{0};
    int64_t elemByteSize;
      elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
    } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
      elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
      elemByteSize = *op.getOffsetAlignByte();
    if (elemBitWidth != 0) {
      if (elemBitWidth % 8 != 0)
            op, "Expected element type bit width to be multiple of 8.");
      elemByteSize = elemBitWidth / 8;
        addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
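// Lower xegpu.fence to xevm.memfence, mapping the fence scope to an XeVM
// memory scope and the memory kind to an XeVM address space.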
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
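// Lower xegpu.dpas to xevm.mma: encode operand/accumulator/result element
// types, shape-cast the accumulator to the flat form expected by the MMA op,
// and shape-cast the result back to the original result type.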
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());
    auto encodePrecision = [&](Type type) -> xevm::ElemType {
        return xevm::ElemType::BF16;
        return xevm::ElemType::F16;
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    auto elementTy = resultType.getElementType();
    if (isa<FloatType>(elementTy))
      c = arith::ConstantOp::create(
    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        getNumOperandsPerDword(precATy)),
    dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);

  // Number of operand elements packed into one 32-bit dword.
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}
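// Lower xegpu.atomic_rmw lane by lane: compute each element's pointer with
// llvm.getelementptr and issue llvm.atomicrmw with seq_cst ordering,
// inserting the returned values into the result vector.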
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
    auto loc = op.getLoc();
    auto tdesc = op.getTensorDesc().getType();
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
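// The conversion pass: configures the LLVM type converter (flattening
// multi-dimensional vectors, handling index element types, and mapping
// tensor descriptors and memrefs), registers source/target materialization
// casts, marks the target dialects legal, and populates the patterns above.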
struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {

  void runOnOperation() override {
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      if (llvm::isa<IndexType>(elemType))
      if (rank < 1 || type.getNumElements() == 1)
          std::accumulate(type.getShape().begin(), type.getShape().end(),
                          int64_t{1}, std::multiplies<int64_t>());
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      if (type.isScattered())
    typeConverter.addConversion([&](MemRefType type) -> Type {
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
          memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
        return arith::IndexCastUIOp::create(builder, loc, type, addr)
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
          index::CastUOp::create(builder, loc, builder.getIndexType(), input)
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
          index::CastUOp::create(builder, loc, builder.getIndexType(), input)
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
      if (inputs.size() != 1)
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
    typeConverter.addSourceMaterialization(memrefMaterializationCast);
    typeConverter.addSourceMaterialization(ui64MaterializationCast);
    typeConverter.addSourceMaterialization(ui32MaterializationCast);
    typeConverter.addSourceMaterialization(vectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();
    patterns.add<CreateNdDescToXeVMPattern,
                 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
                 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
                 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
        typeConverter, patterns.getContext());
    patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
                 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
                 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
        typeConverter, patterns.getContext());
    patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,