#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"

#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};
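// Offsets into the 8xi32 payload vector that carries a 2-D nd tensor
// descriptor: the 64-bit base pointer occupies the first two i32 lanes
// (written through a 4xi64 view), followed by the base surface shape and the
// tile offsets.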
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,       // 64-bit base address (lanes 0-1 of the i32 payload).
  BaseShapeW = 2,    // Width of the base 2-D surface, in elements.
  BaseShapeH = 3,    // Height of the base 2-D surface, in elements.
  TensorOffsetW = 4, // Tile offset along the width dimension.
  TensorOffsetH = 5  // Tile offset along the height dimension.
};
static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
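// XeGPU only exposes L1 and L3 cache hints; the XeVM cache-control encodings
// chosen below always leave L2 uncached (L2UC), and missing hints default to
// UNCACHED.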
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // Payload is an 8xi32 vector, also viewed as 4xi64 to hold the base
    // address.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;
    Value offsetW;
    Value offsetH;

    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    // Descriptor shape is expected to be 1-D or 2-D.
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // A memref source is materialized as a pointer-sized integer by the type
    // converter; a raw pointer source may still be i32 and is widened here.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank())
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty)
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
    }
    // A 1-D descriptor is just the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    // Convert an OpFoldResult into an i32 value.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      return getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
    };
    // Descriptor offsets are not supported; the tile starts at (0, 0).
    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    // Write the base address through the i64 view, then shape and offsets
    // through the i32 view.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetW, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetW));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetH, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetH));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
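// Lower xegpu.load_nd / xegpu.store_nd / xegpu.prefetch_nd. 2-D tiles map onto
// the XeVM 2-D block ops (blockload2d / blockstore2d / blockprefetch2d) using
// the payload decoded back into base pointer, pitch, and offsets; 1-D tiles
// map onto flat blockload / blockstore through an LLVM pointer.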
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");

    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // 2-D path: decode the 8xi32 payload produced by create_nd_tdesc.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      // Offsets come from the op and are cast to i32.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // Surface pitch in bytes.
      Value surfaceW =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      // Static tile sizes and array length come from the descriptor type.
      auto tileW = tdescTy.getDimSize(tileRank - 1);
      auto tileH = tdescTy.getDimSize(0);
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // The store payload must be a flat vector; re-type it to an integer
        // vector of the element bit width expected by the block store.
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
            offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW,
              offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          // With VNNI packing the hardware returns 32-bit lanes.
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));
          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
              surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
              transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1-D path: the converted descriptor is the raw i64 base address.
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      Value finalAddrI64 =
          rewriter.createOrFold<arith::AddIOp>(loc, tdesc, byteOffset);
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(rewriter, loc,
                                                baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
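// Lower xegpu.load (gather) / xegpu.store (scatter) with scalar per-lane
// offset and mask: the access becomes a plain LLVM load/store guarded by an
// scf.if on the lane mask; a masked-off load yields a zero constant.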
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    // Result type for loads, stored value type for stores.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default to the global address space; refine it from the descriptor or
    // the source/dest memref when available.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    Value basePtrI64;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // Base pointer may arrive as i32; widen to i64.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // Only a single scalar offset per lane is supported here.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    }
    // Fold the element offset into the base address as a byte offset.
    basePtrI64 =
        addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // Only a single mask bit per lane is supported here.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    }
    Value maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // Lane is active: emit the load in the then-region.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Attach cache-control metadata to the generated LLVM load.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      // Lane is inactive: yield a zero of the result type.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (isa<FloatType>(eTy))
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // Store: only executed when the lane mask is set.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
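// xegpu.create_mem_desc carries no payload of its own: the descriptor is just
// the underlying SLM buffer, so forward the converted source value.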
class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};
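// Lower xegpu.load_matrix / xegpu.store_matrix on SLM memory descriptors. With
// the subgroup_block_io attribute the access maps to xevm.blockload /
// xevm.blockstore; otherwise it becomes a plain LLVM load/store of a scalar or
// 1-D vector at the linearized offset.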
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // Load result or store value used as the data payload.
    Value data;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>)
      data = op.getResult();
    else
      data = adaptor.getData();
    VectorType valOrResVecTy = dyn_cast<VectorType>(data.getType());
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, data.getType());
    if (valOrResVecTy.getShape().size() != 1)
      return rewriter.notifyMatchFailure(op, "Expected 1D data vector.");

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;

    // Matrix descriptors always live in shared local memory.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
    // Linearize the offsets and turn them into a byte offset from the SLM
    // base address.
    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);

    if (op.getSubgroupBlockIoAttr()) {
      // Subgroup block IO operates on integer element types.
      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);
      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }

    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xevm::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        // The non-block-IO lowering is only supported on PVC and BMG.
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // Scalar loads avoid a single-element vector round trip.
      auto scalarTy = valOrResVecTy.getElementType();
      Value loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};
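// Lower xegpu.prefetch to xevm.prefetch on an LLVM pointer, folding the
// optional scalar offset into the base address first.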
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // Base pointer may arrive as i32; widen to i64.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // Vectors of offsets are not supported.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      }
      int64_t elemBitWidth{0};
      int64_t elemByteSize;
      // The element byte size comes from the descriptor or memref element
      // type; otherwise fall back to the op's offset alignment.
      if (tdescTy) {
        elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
      } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
        elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
      } else {
        elemByteSize = *op.getOffsetAlignByte();
      }
      if (elemBitWidth != 0) {
        if (elemBitWidth % 8 != 0)
          return rewriter.notifyMatchFailure(
              op, "Expected element type bit width to be multiple of 8.");
        elemByteSize = elemBitWidth / 8;
      }
      basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                       elemByteSize);
    }
    // Default to global; refine from the descriptor or the source memref.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    else if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};
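// Lower xegpu.fence to xevm.memfence, mapping the fence scope to an XeVM
// memory scope and the memory kind to an address space.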
class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};
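// Lower xegpu.dpas to xevm.mma. Operands are flattened to 1-D vectors; the MMA
// shape is derived from the accumulator length, executionSize, and
// systolicDepth scaled by the number of A-operand elements per dword. A
// missing accumulator is replaced by a zero-initialized vector.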
class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());
    // Map an MLIR element type onto the XeVM MMA element-type enum.
    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      // No accumulator provided: start from zero.
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }
    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    // xevm.mma operates on flat (1-D) vectors.
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (resultType != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }
  // Number of values of a given element type packed into one 32-bit operand.
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}
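// Lower xegpu.atomic_rmw by unrolling the value vector and issuing one
// llvm.atomicrmw (seq_cst) per element at consecutive addresses.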
class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};
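// The conversion pass: registers type conversions (n-D vectors flatten to 1-D,
// tensor/memory descriptors and memrefs become integer addresses or the i32
// payload), the materializations needed to glue converted values together, and
// the rewrite patterns defined above.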
struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    // n-D vectors are flattened to 1-D; index elements become i64; single
    // element vectors become scalars.
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      if (rank < 1 || type.getNumElements() == 1)
        return elemType;
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered and 1-D nd descriptors lower to an i64 address; 2-D nd
      // descriptors lower to the 8xi32 payload.
      if (type.isScattered())
        return IntegerType::get(&getContext(), 64);
      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
      return IntegerType::get(&getContext(), 32);
    });
    typeConverter.addConversion([&](MemRefType type) -> Type {
      // SLM memrefs (address space 3) lower to a 32-bit address, all other
      // memrefs to a 64-bit address.
      if (type.getMemorySpaceAsInt() == 3)
        return IntegerType::get(&getContext(), 32);
      return IntegerType::get(&getContext(), 64);
    });
    // Materialization casts used to glue converted values together.
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        // Use the aligned base pointer of the memref as the integer address.
        Value addr =
            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
        return arith::IndexCastUIOp::create(builder, loc, type, addr)
            .getResult();
      }
      return {};
    };
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs, Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs, Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
          // Single-element vector: extract the scalar, cast if types differ.
          Value cast =
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
          if (type != cast.getType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          // Same rank: bitcast. Same element type, different rank: shape cast.
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType())
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
        }
      }
      return {};
    };
    // Materialize a scalar back into a single-element vector via broadcast.
    auto singleElementVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType().isIntOrIndexOrFloat()) {
        if (auto vecTy = dyn_cast<VectorType>(type)) {
          if (vecTy.getNumElements() == 1) {
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementVectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
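// Populate the XeGPU -> XeVM conversion patterns used by the pass above.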
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}