29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
35#include "llvm/ADT/TypeSwitch.h"
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
// Systolic depth assumed by this lowering; packed 4-bit load_nd tiles are
// matched against a height of systolicDepth * 4.
static constexpr int32_t systolicDepth = 8;
// Execution size assumed by this lowering: it becomes the widened tile width
// for sub-byte block accesses and the second dimension of xevm MMA shapes.
static constexpr int32_t executionSize = 16;
// Field indices into the tensor-descriptor payload built by
// CreateNdDescToXeVMPattern. BasePtr indexes the payload viewed as a 4 x i64
// vector, while BaseShapeW, BaseShapeH and BasePitch index the 8 x i32 view.
53enum class NdTdescOffset : uint32_t {
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
67 llvm_unreachable(
"Unknown XeGPU memory space");
71static bool isSharedMemRef(
const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
75 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() ==
static_cast<int>(xevm::AddrSpace::SHARED);
77 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
109 llvm_unreachable(
"Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
116 llvm_unreachable(
"Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
123 llvm_unreachable(
"Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
127 llvm_unreachable(
"Unsupported cache control.");
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
147 llvm_unreachable(
"Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
154 llvm_unreachable(
"Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
161 llvm_unreachable(
"Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
168 llvm_unreachable(
"Unsupported cache control.");
170 llvm_unreachable(
"Unsupported cache control.");
// Lowers xegpu.create_nd_tdesc. The descriptor is encoded as an 8 x i32
// payload vector:
//   - the 64-bit base address is written through a 4 x i64 bit-cast view at
//     NdTdescOffset::BasePtr;
//   - the base shape width, base shape height and base pitch are written as
//     i32 fields at their NdTdescOffset indices.
// In some cases the op is instead replaced directly by the i64 base address
// (see the `rewriter.replaceOp(op, baseAddr)` call). NOTE(review): that is
// presumably the rank-1 case; confirm.
182class CreateNdDescToXeVMPattern
183 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
184 using OpConversionPattern::OpConversionPattern;
186 matchAndRewrite(xegpu::CreateNdDescOp op,
187 xegpu::CreateNdDescOp::Adaptor adaptor,
188 ConversionPatternRewriter &rewriter)
const override {
189 auto loc = op.getLoc();
190 auto source = op.getSource();
// Payload layout: 8 x i32. It is also viewed as 4 x i64 so the 64-bit base
// address can be inserted as a single element.
194 Type payloadElemTy = rewriter.getI32Type();
195 VectorType payloadTy = VectorType::get(8, payloadElemTy);
196 Type i64Ty = rewriter.getI64Type();
198 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
200 Value payload = arith::ConstantOp::create(
// Sizes and strides may each be static or dynamic (OpFoldResult).
209 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
210 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
212 int64_t rank = mixedSizes.size();
213 auto sourceTy = source.getType();
214 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
// Memref sources must be ranked. A base address that is not already i64 is
// zero-extended to i64.
217 if (sourceMemrefTy) {
218 if (!sourceMemrefTy.hasRank()) {
219 return rewriter.notifyMatchFailure(op,
"Expected ranked Memref.");
223 baseAddr = adaptor.getSource();
225 baseAddr = adaptor.getSource();
226 if (baseAddr.
getType() != i64Ty) {
228 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
233 rewriter.replaceOp(op, baseAddr);
// Materializes entry `idx` of a mixed size/stride list as a Value. The
// entries are written into the i32 payload, so presumably this yields i32
// values; confirm.
237 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
238 unsigned idx) -> Value {
// Width is the innermost dimension and height the next one. The pitch is the
// stride of the height dimension.
245 baseShapeW = createOffset(mixedSizes, rank - 1);
246 baseShapeH = createOffset(mixedSizes, rank - 2);
248 Value basePitch = createOffset(mixedStrides, rank - 2);
// Insert the base address through the i64 view, then switch back to the i32
// view for the remaining fields.
251 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
253 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
254 static_cast<int>(NdTdescOffset::BasePtr));
255 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
257 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
258 static_cast<int>(NdTdescOffset::BaseShapeW));
260 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
261 static_cast<int>(NdTdescOffset::BaseShapeH));
263 vector::InsertOp::create(rewriter, loc, basePitch, payload,
264 static_cast<int>(NdTdescOffset::BasePitch));
265 rewriter.replaceOp(op, payload);
// Lowers xegpu.load_nd, xegpu.store_nd and xegpu.prefetch_nd.
//
// The descriptor payload is unpacked:
//   - the base pointer is read through a 4 x i64 bit-cast view;
//   - the base shape width/height and pitch are read as i32 fields at their
//     NdTdescOffset indices.
// On the 2-D path, the width and pitch are converted to bytes and the op
// becomes an xevm 2-D block store, prefetch or load.
// On the flat-address path, an i64 address is formed from
// offset * (elemBitSize / 8) and the op becomes xevm::BlockStoreOp or
// xevm::BlockLoadOp. prefetch_nd is rejected there; per its failure message,
// this path handles rank-1 descriptors.
272 typename = std::enable_if_t<llvm::is_one_of<
273 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
274class LoadStorePrefetchNdToXeVMPattern :
public OpConversionPattern<OpType> {
275 using OpConversionPattern<OpType>::OpConversionPattern;
277 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
278 ConversionPatternRewriter &rewriter)
const override {
279 auto mixedOffsets = op.getMixedOffsets();
280 int64_t opOffsetsSize = mixedOffsets.size();
281 auto loc = op.getLoc();
282 auto ctxt = rewriter.getContext();
284 auto tdesc = adaptor.getTensorDesc();
285 auto tdescTy = op.getTensorDescType();
286 auto tileRank = tdescTy.getRank();
// One offset per descriptor dimension is required.
287 if (opOffsetsSize != tileRank)
288 return rewriter.notifyMatchFailure(
289 op,
"Expected offset rank to match descriptor rank.");
290 auto elemType = tdescTy.getElementType();
291 auto elemBitSize = elemType.getIntOrFloatBitWidth();
292 bool isSubByte = elemBitSize < 8;
// Number of original elements folded into one element of the (possibly
// widened) element type. It stays 1 unless a sub-byte tile is reinterpreted.
293 uint64_t wScaleFactor = 1;
295 if (!isSubByte && (elemBitSize % 8 != 0))
296 return rewriter.notifyMatchFailure(
297 op,
"Expected element type bit width to be multiple of 8.");
298 auto tileW = tdescTy.getDimSize(tileRank - 1);
// Sub-byte handling supports only 4-bit elements on 2-D descriptors. The tile
// is reinterpreted with wider integer elements:
//   - i8 for a packed load_nd whose tile is
//     (systolicDepth * 4) x (executionSize * subByteFactor);
//   - otherwise i16, when the width is executionSize * subByteFactor * 2.
// Any other sub-byte tile shape is rejected.
301 if (elemBitSize != 4)
302 return rewriter.notifyMatchFailure(
303 op,
"Only sub byte types of 4bits are supported.");
305 return rewriter.notifyMatchFailure(
306 op,
"Sub byte types are only supported for 2D tensor descriptors.");
307 auto subByteFactor = 8 / elemBitSize;
308 auto tileH = tdescTy.getDimSize(0);
310 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
311 if (op.getPacked().value_or(
false)) {
313 if (tileH == systolicDepth * 4 &&
314 tileW == executionSize * subByteFactor) {
319 elemType = rewriter.getIntegerType(8);
320 tileW = executionSize;
321 wScaleFactor = subByteFactor;
326 if (wScaleFactor == 1) {
327 auto sub16BitFactor = subByteFactor * 2;
328 if (tileW == executionSize * sub16BitFactor) {
332 elemType = rewriter.getIntegerType(16);
333 tileW = executionSize;
334 wScaleFactor = sub16BitFactor;
336 return rewriter.notifyMatchFailure(
337 op,
"Unsupported tile shape for sub byte types.");
341 elemBitSize = elemType.getIntOrFloatBitWidth();
345 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
346 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
350 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
// Unpack the descriptor payload.
351 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
353 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
355 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
356 static_cast<int>(NdTdescOffset::BasePtr));
357 Value baseShapeW = vector::ExtractOp::create(
358 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeW));
359 Value baseShapeH = vector::ExtractOp::create(
360 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeH));
361 Value basePitch = vector::ExtractOp::create(
362 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BasePitch));
368 mixedOffsets[tileRank - 1]);
370 rewriter.getI32Type(), offsetW);
372 mixedOffsets[tileRank - 2]);
374 rewriter.getI32Type(), offsetH);
377 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// The surface width and pitch are passed to the block ops in bytes.
381 Value baseShapeWInBytes =
382 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
384 Value basePitchBytes =
385 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// elemByteSize reflects the widened element type, while the width, pitch and
// W offset count original elements. Shifting right by log2(wScaleFactor)
// corrects for this.
387 if (wScaleFactor > 1) {
391 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
392 baseShapeWInBytes = arith::ShRSIOp::create(
393 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
394 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
395 wScaleFactorValLog2);
397 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
// Tile height is the second-innermost dimension.
400 auto tileH = tdescTy.getDimSize(tileRank - 2);
402 int32_t vblocks = tdescTy.getArrayLength();
// store_nd: bit-cast the value to elemBitSize-wide integers and emit a 2-D
// block store.
403 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
404 Value src = adaptor.getValue();
410 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
412 return rewriter.notifyMatchFailure(
413 op,
"Expected store value to be a vector type.");
415 VectorType newSrcVecTy =
416 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
417 if (srcVecTy != newSrcVecTy)
418 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
419 auto storeCacheControl =
420 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
421 xevm::BlockStore2dOp::create(
422 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
423 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
424 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
425 rewriter.eraseOp(op);
// load_nd and prefetch_nd share the load cache-control translation.
427 auto loadCacheControl =
428 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
429 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
430 xevm::BlockPrefetch2dOp::create(
431 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
432 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
433 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
434 rewriter.eraseOp(op);
436 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
437 bool vnni = op.getPacked().value_or(
false);
438 auto transposeValue = op.getTranspose();
440 transposeValue.has_value() && transposeValue.value()[0] == 1;
446 if (elemBitSize == 8 && tileW == 16 && tileH == 32 && !vnni &&
// Transposed loads of elements narrower than 32 bits are issued in 32-bit
// units. The W offset is divided by 32 / elemBitSize and the tile width is
// rescaled to match.
454 if (transpose && elemBitSize < 32) {
455 int32_t scale = 32 / elemBitSize;
457 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(scale));
458 offsetW = arith::ShRSIOp::create(rewriter, loc, offsetW, scaleLog2);
459 tileW = tileW * elemBitSize / 32;
// With vnni, the raw load is typed with i32 elements. The result is then
// bit-cast back to the destination element type.
462 VectorType loadedTy = encodeVectorTypeTo(
463 dstVecTy, vnni ? rewriter.getI32Type()
464 : rewriter.getIntegerType(elemBitSize));
466 Value resultFlatVec = xevm::BlockLoad2dOp::create(
467 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
468 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
469 tileH, vblocks, transpose, vnni,
470 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
471 resultFlatVec = vector::BitCastOp::create(
473 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
475 rewriter.replaceOp(op, resultFlatVec);
// Flat-address path: the offset is widened to i64 and scaled by
// elemBitSize / 8, then added (presumably to the base pointer) to form the
// final i64 address.
487 rewriter.getI64Type(), offset);
490 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
492 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
494 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
500 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
501 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
502 Value src = adaptor.getValue();
508 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
510 return rewriter.notifyMatchFailure(
511 op,
"Expected store value to be a vector type.");
513 VectorType newSrcVecTy =
514 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
515 if (srcVecTy != newSrcVecTy)
516 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
517 auto storeCacheControl =
518 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
519 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
520 op, finalPtrLLVM, src,
521 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
522 }
else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
523 auto loadCacheControl =
524 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
525 VectorType resTy = cast<VectorType>(op.getValue().getType());
526 VectorType loadedTy =
527 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
528 Value
load = xevm::BlockLoadOp::create(
529 rewriter, loc, loadedTy, finalPtrLLVM,
530 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
531 if (loadedTy != resTy)
532 load = vector::BitCastOp::create(rewriter, loc, resTy,
load);
533 rewriter.replaceOp(op,
load);
535 return rewriter.notifyMatchFailure(
536 op,
"Unsupported operation: xegpu.prefetch_nd with tensor "
537 "descriptor rank == 1");
546static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
550 rewriter, loc, baseAddr.
getType(), elemByteSize);
551 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
552 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
// Lowers xegpu.load (LoadGatherOp) and xegpu.store (StoreScatterOp) for a
// scalar offset and a scalar mask.
//   - Address: base + offset * element byte size, via addOffsetToBaseAddr.
//   - Access: an LLVM load/store carrying a "cache_control" attribute, wrapped
//     in an scf.if on the mask.
//   - A masked-off load yields a zero constant of the result type.
556template <
typename OpType,
557 typename = std::enable_if_t<llvm::is_one_of<
558 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
559class LoadStoreToXeVMPattern :
public OpConversionPattern<OpType> {
560 using OpConversionPattern<OpType>::OpConversionPattern;
562 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
563 ConversionPatternRewriter &rewriter)
const override {
564 Value offset = adaptor.getOffsets();
566 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
567 auto loc = op.getLoc();
568 auto ctxt = rewriter.getContext();
// Value type: the converted result type for loads, and the stored value's
// type for stores.
572 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
574 this->getTypeConverter()->convertType(op.getResult().getType());
576 valOrResTy = adaptor.getValue().getType();
577 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
578 bool hasScalarVal = !valOrResVecTy;
579 int64_t elemBitWidth =
581 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
583 if (elemBitWidth % 8 != 0)
584 return rewriter.notifyMatchFailure(
585 op,
"Expected element type bit width to be multiple of 8.");
586 int64_t elemByteSize = elemBitWidth / 8;
// The pointer type defaults to the global address space. A memref base may
// replace it with the memref's numeric memory space.
588 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
589 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
592 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
593 basePtrI64 = adaptor.getSource();
594 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
595 auto addrSpace = memRefTy.getMemorySpaceAsInt();
597 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
600 basePtrI64 = adaptor.getDest();
601 if (
auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
602 auto addrSpace = memRefTy.getMemorySpaceAsInt();
604 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
// A base address that is not already i64 is zero-extended to i64.
608 if (basePtrI64.
getType() != rewriter.getI64Type()) {
609 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
612 Value mask = adaptor.getMask();
613 if (dyn_cast<VectorType>(offset.
getType())) {
616 return rewriter.notifyMatchFailure(op,
"Expected offset to be a scalar.");
622 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
626 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
629 VectorType maskVecTy = dyn_cast<VectorType>(mask.
getType());
633 return rewriter.notifyMatchFailure(op,
"Expected mask to be a scalar.");
// Gather: the then-region performs the load, tagged with the translated load
// cache hints. The else-region yields a zero constant of the same type.
636 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
637 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
638 maskForLane,
true,
true);
640 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
642 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
643 valOrResVecTy.getElementType());
645 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
648 "cache_control", xevm::LoadCacheControlAttr::get(
649 ctxt, translateLoadXeGPUCacheHint(
650 op.getL1Hint(), op.getL3Hint())));
651 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
652 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
654 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
657 eVal = FloatAttr::get(eTy, 0.0);
659 eVal = IntegerAttr::get(eTy, 0);
661 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
663 loaded = arith::ConstantOp::create(
665 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
666 rewriter.replaceOp(op, ifOp.getResult(0));
// Scatter: the store, tagged with the translated store cache hints, executes
// only when the mask is true.
669 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane,
false);
670 auto body = ifOp.getBody();
671 rewriter.setInsertionPointToStart(body);
673 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
675 storeOp.getOperation()->setAttr(
676 "cache_control", xevm::StoreCacheControlAttr::get(
677 ctxt, translateStoreXeGPUCacheHint(
678 op.getL1Hint(), op.getL3Hint())));
679 rewriter.eraseOp(op);
685class CreateMemDescOpPattern final
686 :
public OpConversionPattern<xegpu::CreateMemDescOp> {
688 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
690 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
691 ConversionPatternRewriter &rewriter)
const override {
693 rewriter.replaceOp(op, adaptor.getSource());
// Lowers xegpu.load_matrix and xegpu.store_matrix.
//   - Address: the converted mem_desc value is the base address. The linear
//     element offset from MemDescType::getLinearOffsets is cast to i32, scaled
//     by the element byte size and added to it. The result is used as a
//     pointer in the SLM address space.
//   - When getSubgroupBlockIoAttr() is set, the access becomes an xevm block
//     load/store on integer-typed vectors.
//   - Otherwise a plain LLVM load/store is emitted. That path is rejected
//     unless the target chip is pvc, bmg or cri (see the check below).
698template <
typename OpType,
699 typename = std::enable_if_t<llvm::is_one_of<
700 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
701class LoadStoreMatrixToXeVMPattern :
public OpConversionPattern<OpType> {
702 using OpConversionPattern<OpType>::OpConversionPattern;
704 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
705 ConversionPatternRewriter &rewriter)
const override {
707 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
709 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
711 auto loc = op.getLoc();
712 auto ctxt = rewriter.getContext();
713 Value baseAddr32 = adaptor.getMemDesc();
714 Value mdescVal = op.getMemDesc();
// Load results given as nD vectors with at most one non-unit dimension are
// flattened to 1-D.
717 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
718 Type resType = op.getResult().getType();
721 if (
auto vecType = dyn_cast<VectorType>(resType)) {
722 assert(llvm::count_if(vecType.getShape(),
723 [](int64_t d) { return d != 1; }) <= 1 &&
724 "Expected either 1D vector or nD with unit dimensions");
725 resType = VectorType::get({vecType.getNumElements()},
726 vecType.getElementType());
730 dataTy = adaptor.getData().getType();
731 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
// A scalar value is handled as a single-element vector.
733 valOrResVecTy = VectorType::get(1, dataTy);
735 int64_t elemBitWidth =
736 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
738 if (elemBitWidth % 8 != 0)
739 return rewriter.notifyMatchFailure(
740 op,
"Expected element type bit width to be multiple of 8.");
741 int64_t elemByteSize = elemBitWidth / 8;
744 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
745 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
747 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.
getType());
749 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
750 linearOffset = arith::IndexCastUIOp::create(
751 rewriter, loc, rewriter.getI32Type(), linearOffset);
752 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
753 linearOffset, elemByteSize);
757 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
759 if (op.getSubgroupBlockIoAttr()) {
// Block I/O works on integer element types; bit-cast to and from them when
// the value type differs.
763 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
764 VectorType intVecTy =
765 VectorType::get(valOrResVecTy.getShape(), intElemTy);
767 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
769 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
770 if (intVecTy != valOrResVecTy) {
772 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
774 rewriter.replaceOp(op, loadOp);
776 Value dataToStore = adaptor.getData();
777 if (valOrResVecTy != intVecTy) {
779 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
781 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
783 rewriter.eraseOp(op);
788 if (valOrResVecTy.getNumElements() >= 1) {
791 (*chipOpt !=
"pvc" && *chipOpt !=
"bmg" && *chipOpt !=
"cri")) {
793 return rewriter.notifyMatchFailure(
794 op,
"The lowering is specific to pvc, bmg or cri.");
// Plain LLVM load/store. A single-element vector is loaded as a scalar of the
// element type.
798 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
802 auto scalarTy = valOrResVecTy.getElementType();
804 if (valOrResVecTy.getNumElements() == 1)
805 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
808 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
809 rewriter.replaceOp(op, loadOp);
811 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
812 rewriter.eraseOp(op);
// Lowers xegpu.prefetch to xevm.prefetch.
//   - A base address that is not already i64 is zero-extended to i64.
//   - The offset (expected to be a scalar) is scaled by an element byte size
//     and added to the base.
//   - The result is used as a pointer in the global address space, unless a
//     memref source supplies its own numeric memory space.
//   - The load cache hints become the prefetch's cache-control attribute.
818class PrefetchToXeVMPattern :
public OpConversionPattern<xegpu::PrefetchOp> {
819 using OpConversionPattern::OpConversionPattern;
821 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
822 ConversionPatternRewriter &rewriter)
const override {
823 auto loc = op.getLoc();
824 auto ctxt = rewriter.getContext();
825 Value basePtrI64 = adaptor.getSource();
827 if (basePtrI64.
getType() != rewriter.getI64Type())
828 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
830 Value offsets = adaptor.getOffsets();
832 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.
getType());
835 return rewriter.notifyMatchFailure(op,
836 "Expected offsets to be a scalar.");
// Element byte size: taken from the memref element type when the source is a
// memref (its bit width must be a multiple of 8). Otherwise the value read
// from getOffsetAlignByte() is used.
// NOTE(review): confirm the optional returned by getOffsetAlignByte() is
// checked before it is dereferenced.
838 int64_t elemBitWidth{0};
839 int64_t elemByteSize;
841 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
844 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
847 elemByteSize = *op.getOffsetAlignByte();
849 if (elemBitWidth != 0) {
850 if (elemBitWidth % 8 != 0)
851 return rewriter.notifyMatchFailure(
852 op,
"Expected element type bit width to be multiple of 8.");
853 elemByteSize = elemBitWidth / 8;
855 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
860 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
861 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
863 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
864 auto addrSpace = memRefTy.getMemorySpaceAsInt();
866 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
870 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
872 xevm::PrefetchOp::create(
873 rewriter, loc, ptrLLVM,
874 xevm::LoadCacheControlAttr::get(
875 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
876 rewriter.eraseOp(op);
881class FenceToXeVMPattern :
public OpConversionPattern<xegpu::FenceOp> {
882 using OpConversionPattern::OpConversionPattern;
884 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
885 ConversionPatternRewriter &rewriter)
const override {
886 auto loc = op.getLoc();
887 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
888 switch (op.getFenceScope()) {
889 case xegpu::FenceScope::Workgroup:
890 memScope = xevm::MemScope::WORKGROUP;
892 case xegpu::FenceScope::GPU:
893 memScope = xevm::MemScope::DEVICE;
896 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
897 switch (op.getMemoryKind()) {
898 case xegpu::MemorySpace::Global:
899 addrSpace = xevm::AddrSpace::GLOBAL;
901 case xegpu::MemorySpace::SLM:
902 addrSpace = xevm::AddrSpace::SHARED;
905 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
906 rewriter.eraseOp(op);
911static auto encodePrecision = [](
Type type) -> xevm::ElemType {
913 return xevm::ElemType::BF16;
914 else if (type.isF16())
915 return xevm::ElemType::F16;
916 else if (type.isTF32())
917 return xevm::ElemType::TF32;
918 else if (type.isInteger(8)) {
919 if (type.isUnsignedInteger())
920 return xevm::ElemType::U8;
921 return xevm::ElemType::S8;
922 }
else if (type.isF32())
923 return xevm::ElemType::F32;
924 else if (type.isInteger(32))
925 return xevm::ElemType::S32;
926 else if (type.isF8E5M2())
927 return xevm::ElemType::BF8;
928 else if (type.isF8E4M3FN())
929 return xevm::ElemType::F8;
930 else if (mlir::isa<Float4E2M1FNType>(type))
931 return xevm::ElemType::E2M1;
932 llvm_unreachable(
"add more support for ElemType");
935static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
937 case xevm::ElemType::TF32:
939 case xevm::ElemType::BF16:
940 case xevm::ElemType::F16:
942 case xevm::ElemType::U8:
943 case xevm::ElemType::S8:
944 case xevm::ElemType::F8:
945 case xevm::ElemType::BF8:
947 case xevm::ElemType::E2M1:
950 llvm_unreachable(
"unsupported xevm::ElemType");
// Lowers xegpu.dpas to xevm.mma.
//   - The target uArch must provide the SubgroupMatrixMultiplyAcc
//     instruction and must support the A, B and result element types.
//   - A zero constant of the result type is materialized as the accumulator,
//     presumably when the op has no accumulator operand; confirm.
//   - The accumulator is shape-cast to a 1-D vector for the MMA op, and the
//     MMA result is shape-cast back to the original result type.
954class DpasToXeVMPattern :
public OpConversionPattern<xegpu::DpasOp> {
955 using OpConversionPattern::OpConversionPattern;
957 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
958 ConversionPatternRewriter &rewriter)
const override {
959 auto loc = op.getLoc();
960 auto ctxt = rewriter.getContext();
961 auto aTy = cast<VectorType>(op.getLhs().getType());
962 auto bTy = cast<VectorType>(op.getRhs().getType());
963 auto resultType = cast<VectorType>(op.getResultType());
// The target chip and uArch must be known, and the uArch must support DPAS.
968 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
972 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
975 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
976 uArch->getInstruction(
977 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
979 return rewriter.notifyMatchFailure(op,
980 "DPAS not supported by target uArch");
// Returns true if the element type of vecTy is in the uArch's supported-type
// list for the given MMA operand kind.
982 auto checkSupportedTypes = [&](VectorType vecTy,
984 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
985 return llvm::find(supported, vecTy.getElementType()) != supported.end();
988 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
989 return rewriter.notifyMatchFailure(
990 op,
"A-matrix element type not supported by target uArch");
991 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
992 return rewriter.notifyMatchFailure(
993 op,
"B-matrix element type not supported by target uArch");
995 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
996 return rewriter.notifyMatchFailure(
997 op,
"result/accumulator element type not supported by target uArch");
999 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1000 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1001 Value c = op.getAcc();
// Zero accumulator of the result element type (float 0.0 or integer 0).
1003 auto elementTy = resultType.getElementType();
1004 Attribute initValueAttr;
1005 if (isa<FloatType>(elementTy))
1006 initValueAttr = FloatAttr::get(elementTy, 0.0);
1008 initValueAttr = IntegerAttr::get(elementTy, 0);
1009 c = arith::ConstantOp::create(
1013 Value aVec = op.getLhs();
1014 Value bVec = op.getRhs();
1015 auto cvecty = cast<VectorType>(c.
getType());
1016 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
1017 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
// The MMA op works on a flattened (1-D) accumulator.
1019 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
1021 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
1022 Value dpasRes = xevm::MMAOp::create(
1023 rewriter, loc, cNty, aVec, bVec, c,
1024 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
1026 getNumOperandsPerDword(precATy)),
1027 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1029 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
1030 rewriter.replaceOp(op, dpasRes);
1035static std::optional<LLVM::AtomicBinOp>
1036matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1037 switch (arithKind) {
1038 case arith::AtomicRMWKind::addf:
1039 return LLVM::AtomicBinOp::fadd;
1040 case arith::AtomicRMWKind::addi:
1041 return LLVM::AtomicBinOp::add;
1042 case arith::AtomicRMWKind::assign:
1043 return LLVM::AtomicBinOp::xchg;
1044 case arith::AtomicRMWKind::maximumf:
1045 return LLVM::AtomicBinOp::fmax;
1046 case arith::AtomicRMWKind::maxs:
1047 return LLVM::AtomicBinOp::max;
1048 case arith::AtomicRMWKind::maxu:
1049 return LLVM::AtomicBinOp::umax;
1050 case arith::AtomicRMWKind::minimumf:
1051 return LLVM::AtomicBinOp::fmin;
1052 case arith::AtomicRMWKind::mins:
1053 return LLVM::AtomicBinOp::min;
1054 case arith::AtomicRMWKind::minu:
1055 return LLVM::AtomicBinOp::umin;
1056 case arith::AtomicRMWKind::ori:
1057 return LLVM::AtomicBinOp::_or;
1058 case arith::AtomicRMWKind::andi:
1059 return LLVM::AtomicBinOp::_and;
1061 return std::nullopt;
1065class AtomicRMWToXeVMPattern :
public OpConversionPattern<xegpu::AtomicRMWOp> {
1066 using OpConversionPattern::OpConversionPattern;
1068 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1069 ConversionPatternRewriter &rewriter)
const override {
1070 auto loc = op.getLoc();
1071 auto ctxt = rewriter.getContext();
1072 auto tdesc = op.getTensorDesc().getType();
1073 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1074 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1075 Value basePtrI64 = arith::IndexCastOp::create(
1076 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1078 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1079 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1080 VectorType srcOrDstFlatVecTy = VectorType::get(
1081 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1082 Value srcFlatVec = vector::ShapeCastOp::create(
1083 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1084 auto atomicKind = matchSimpleAtomicOp(op.getKind());
1085 assert(atomicKind.has_value());
1086 Value resVec = srcFlatVec;
1087 for (
int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1088 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1089 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1090 rewriter.getIndexAttr(i));
1092 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1093 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1095 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1096 val, LLVM::AtomicOrdering::seq_cst);
1097 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1099 rewriter.replaceOp(op, resVec);
// Lowers xegpu.dpas_mx to xevm::MMAMxOp.
//   - Forwards the converted A/B operands, the scale operands (scaleA,
//     scaleB) and the accumulator.
//   - The result type comes from the type converter.
//   - A zero accumulator of that type is materialized, presumably when the
//     op has no accumulator operand; confirm.
//   - 4-bit A/B element vectors are bit-cast to i8 vectors with half as many
//     elements.
1104class DpasMxToXeVMPattern :
public OpConversionPattern<xegpu::DpasMxOp> {
1105 using OpConversionPattern::OpConversionPattern;
1107 matchAndRewrite(xegpu::DpasMxOp op, xegpu::DpasMxOp::Adaptor adaptor,
1108 ConversionPatternRewriter &rewriter)
const override {
1109 auto loc = op.getLoc();
1110 auto ctxt = rewriter.getContext();
1111 auto aTy = op.getA().getType();
1112 auto bTy = op.getB().getType();
1114 cast<VectorType>(getTypeConverter()->convertType(op.getType()));
1118 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
1122 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
1126 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
1127 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
1128 Value c = adaptor.getAcc();
// Zero accumulator of the converted result element type.
1130 auto elementTy = resVecTy.getElementType();
1131 Attribute initValueAttr;
1132 if (isa<FloatType>(elementTy))
1133 initValueAttr = FloatAttr::get(elementTy, 0.0);
1135 initValueAttr = IntegerAttr::get(elementTy, 0);
1136 c = arith::ConstantOp::create(
1140 Value aVec = adaptor.getA();
1141 Value bVec = adaptor.getB();
1142 auto aVecTy = cast<VectorType>(aVec.
getType());
1143 auto bVecTy = cast<VectorType>(bVec.
getType());
// 4-bit operands: reinterpret as i8 vectors with half as many elements.
1144 if (aVecTy.getElementTypeBitWidth() == 4)
1145 aVec = vector::BitCastOp::create(
1147 VectorType::get(aVecTy.getNumElements() / 2, rewriter.getI8Type()),
1149 if (bVecTy.getElementTypeBitWidth() == 4)
1150 bVec = vector::BitCastOp::create(
1152 VectorType::get(bVecTy.getNumElements() / 2, rewriter.getI8Type()),
1154 auto cVecTy = cast<VectorType>(c.
getType());
1155 xevm::ElemType precCTy = encodePrecision(cVecTy.getElementType());
1156 xevm::ElemType precDTy = encodePrecision(resVecTy.getElementType());
1157 Value scaleA = adaptor.getScaleA();
1158 Value scaleB = adaptor.getScaleB();
1159 Value dpasMxRes = xevm::MMAMxOp::create(
1160 rewriter, loc, resVecTy, aVec, bVec, scaleA, scaleB, c,
1161 xevm::MMAShapeAttr::get(ctxt, cVecTy.getNumElements(), executionSize,
1163 getNumOperandsPerDword(precATy)),
1164 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
1165 rewriter.replaceOp(op, dpasMxRes);
// Element count an arith.extf result vector (see isXeVMExtf) or an
// arith.truncf source vector (see isXeVMTruncf) must have for the op to be
// lowered to xevm.extf / xevm.truncf; other sizes are left untouched.
1184 static constexpr int64_t kXeVMExtfTruncfNumElems = 16;
1187static std::optional<xevm::ExtfSrcElemTypes> getExtfNarrowType(
Type etype) {
1188 if (isa<Float8E5M2Type>(etype))
1189 return xevm::ExtfSrcElemTypes::BF8;
1190 if (isa<Float8E4M3FNType>(etype))
1191 return xevm::ExtfSrcElemTypes::F8;
1192 if (isa<Float4E2M1FNType>(etype))
1193 return xevm::ExtfSrcElemTypes::E2M1;
1194 return std::nullopt;
1198static std::optional<xevm::TruncfDstElemTypes> getTruncfNarrowType(
Type etype) {
1199 if (isa<Float8E5M2Type>(etype))
1200 return xevm::TruncfDstElemTypes::BF8;
1201 if (isa<Float8E4M3FNType>(etype))
1202 return xevm::TruncfDstElemTypes::F8;
1203 if (isa<Float4E2M1FNType>(etype))
1204 return xevm::TruncfDstElemTypes::E2M1;
1205 return std::nullopt;
1210static bool isXeVMExtf(arith::ExtFOp op) {
1211 auto srcTy = dyn_cast<VectorType>(op.getIn().getType());
1212 auto dstTy = dyn_cast<VectorType>(op.getType());
1213 if (!srcTy || !dstTy || srcTy.getRank() != 1 || dstTy.getRank() != 1)
1215 if (dstTy.getNumElements() != kXeVMExtfTruncfNumElems)
1217 Type dstETy = dstTy.getElementType();
1220 return getExtfNarrowType(srcTy.getElementType()).has_value();
1226static bool isXeVMTruncf(arith::TruncFOp op) {
1227 auto srcTy = dyn_cast<VectorType>(op.getIn().getType());
1228 auto dstTy = dyn_cast<VectorType>(op.getType());
1229 if (!srcTy || !dstTy || srcTy.getRank() != 1 || dstTy.getRank() != 1)
1231 if (srcTy.getNumElements() != kXeVMExtfTruncfNumElems)
1233 Type srcETy = srcTy.getElementType();
1236 return getTruncfNarrowType(dstTy.getElementType()).has_value();
1239class ExtfToXeVMPattern :
public OpConversionPattern<arith::ExtFOp> {
1240 using OpConversionPattern::OpConversionPattern;
1242 matchAndRewrite(arith::ExtFOp op, OpAdaptor adaptor,
1243 ConversionPatternRewriter &rewriter)
const override {
1244 if (!isXeVMExtf(op))
1245 return rewriter.notifyMatchFailure(op,
"not a xevm.extf compatible extf");
1246 Location loc = op.getLoc();
1247 MLIRContext *ctx = op.getContext();
1248 auto srcVecTy = cast<VectorType>(op.getIn().getType());
1249 auto dstVecTy = cast<VectorType>(op.getType());
1250 xevm::ExtfSrcElemTypes srcEnum =
1251 *getExtfNarrowType(srcVecTy.getElementType());
1252 xevm::ExtfDstElemTypes dstEnum = dstVecTy.getElementType().isF16()
1253 ? xevm::ExtfDstElemTypes::F16
1254 : xevm::ExtfDstElemTypes::BF16;
1258 Value src = adaptor.getIn();
1259 auto convSrcTy = cast<VectorType>(src.
getType());
1260 if (convSrcTy.getElementTypeBitWidth() == 4)
1261 src = vector::BitCastOp::create(
1263 VectorType::get(convSrcTy.getNumElements() / 2, rewriter.getI8Type()),
1265 Type resTy = getTypeConverter()->convertType(dstVecTy);
1266 Value res = xevm::ExtfOp::create(
1267 rewriter, loc, resTy, src, xevm::ExtfSrcElemTypeAttr::get(ctx, srcEnum),
1268 xevm::ExtfDstElemTypeAttr::get(ctx, dstEnum));
1269 rewriter.replaceOp(op, res);
1274class TruncfToXeVMPattern :
public OpConversionPattern<arith::TruncFOp> {
1275 using OpConversionPattern::OpConversionPattern;
1277 matchAndRewrite(arith::TruncFOp op, OpAdaptor adaptor,
1278 ConversionPatternRewriter &rewriter)
const override {
1279 if (!isXeVMTruncf(op))
1280 return rewriter.notifyMatchFailure(op,
1281 "not a xevm.truncf compatible truncf");
1282 Location loc = op.getLoc();
1283 MLIRContext *ctx = op.getContext();
1284 auto srcVecTy = cast<VectorType>(op.getIn().getType());
1285 auto dstVecTy = cast<VectorType>(op.getType());
1286 xevm::TruncfSrcElemTypes srcEnum = srcVecTy.getElementType().isF16()
1287 ? xevm::TruncfSrcElemTypes::F16
1288 : xevm::TruncfSrcElemTypes::BF16;
1289 xevm::TruncfDstElemTypes dstEnum =
1290 *getTruncfNarrowType(dstVecTy.getElementType());
1292 int64_t numNarrowBits =
1293 dstVecTy.getNumElements() * dstVecTy.getElementTypeBitWidth();
1294 Type packedTy = VectorType::get(numNarrowBits / 8, rewriter.getI8Type());
1296 xevm::TruncfOp::create(rewriter, loc, packedTy, adaptor.getIn(),
1297 xevm::TruncfSrcElemTypeAttr::get(ctx, srcEnum),
1298 xevm::TruncfDstElemTypeAttr::get(ctx, dstEnum));
1300 Type resTy = getTypeConverter()->convertType(dstVecTy);
1302 res = vector::BitCastOp::create(rewriter, loc, resTy, res);
1303 rewriter.replaceOp(op, res);
1312struct ConvertXeGPUToXeVMPass
1313 :
public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
1316 void runOnOperation()
override {
1326 LowerToLLVMOptions
options(context);
1327 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1328 LLVMTypeConverter typeConverter(context,
options);
1330 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1331 Type i32Type = IntegerType::get(context, 32);
1332 typeConverter.addConversion([&](VectorType type) -> Type {
1333 auto elemType = typeConverter.convertType(type.getElementType());
1335 unsigned rank = type.getRank();
1336 if (rank == 0 || type.getNumElements() == 1)
1339 int64_t sum = llvm::product_of(type.getShape());
1340 return VectorType::get(sum, elemType);
1342 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1343 if (type.getRank() == 1)
1344 return xevmIndexType;
1345 return VectorType::get(8, i32Type);
1354 typeConverter.addConversion(
1355 [&](xegpu::MemDescType type) -> Type {
return i32Type; });
1357 typeConverter.addConversion([&](MemRefType type) -> Type {
1358 return isSharedMemRef(type) ? i32Type : xevmIndexType;
1368 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1370 Location loc) -> Value {
1371 if (inputs.size() != 1)
1373 auto input = inputs.front();
1374 if (
auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1375 unsigned rank = memrefTy.getRank();
1379 SmallVector<int64_t> intStrides;
1382 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1383 ShapedType::isStatic(intOffsets)) {
1384 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1386 offset = arith::ConstantOp::create(builder, loc,
1392 SmallVector<Type> resultTypes{
1393 MemRefType::get({}, memrefTy.getElementType(),
1394 MemRefLayoutAttrInterface(),
1395 memrefTy.getMemorySpace()),
1398 resultTypes.append(2 * rank, indexType);
1400 auto meta = memref::ExtractStridedMetadataOp::create(
1401 builder, loc, resultTypes, input);
1403 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1404 builder, loc, meta.getBaseBuffer());
1405 offset = meta.getOffset();
1409 arith::IndexCastUIOp::create(builder, loc, type, addr);
1411 arith::IndexCastUIOp::create(builder, loc, type, offset);
1414 auto byteSize = arith::ConstantOp::create(
1417 memrefTy.getElementTypeBitWidth() / 8));
1419 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1420 auto addrWithOffset =
1421 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1423 return addrWithOffset.getResult();
1432 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1434 Location loc) -> Value {
1435 if (inputs.size() != 1)
1437 auto input = inputs.front();
1440 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1442 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1452 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1454 Location loc) -> Value {
1455 if (inputs.size() != 1)
1457 auto input = inputs.front();
1460 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1462 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1472 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1474 Location loc) -> Value {
1475 if (inputs.size() != 1)
1477 auto input = inputs.front();
1478 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1479 if (
auto targetVecTy = dyn_cast<VectorType>(type)) {
1483 if (targetVecTy.getShape() != vecTy.getShape()) {
1484 cast = vector::ShapeCastOp::create(
1486 VectorType::get(targetVecTy.getShape(),
1487 vecTy.getElementType()),
1491 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1492 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
1504 auto vectorToSingleElementMaterializationCast =
1505 [](OpBuilder &builder, Type type,
ValueRange inputs,
1506 Location loc) -> Value {
1507 if (inputs.size() != 1)
1509 auto input = inputs.front();
1510 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1512 auto rank = vecTy.getRank();
1513 if (rank != 0 && vecTy.getNumElements() != 1)
1515 auto inElemTy = vecTy.getElementType();
1519 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1521 cast = vector::ExtractOp::create(builder, loc, cast,
1522 SmallVector<int64_t>(rank, 0))
1529 if (inElemTy.isIndex()) {
1530 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1532 }
else if (inElemTy != type) {
1533 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
1547 auto singleElementToVectorMaterializationCast =
1548 [](OpBuilder &builder, Type type,
ValueRange inputs,
1549 Location loc) -> Value {
1550 if (inputs.size() != 1)
1552 auto input = inputs.front();
1553 auto inTy = input.getType();
1554 if (!inTy.isIntOrFloat())
1558 if (
auto vecTy = dyn_cast<VectorType>(type)) {
1559 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1561 auto outElemTy = vecTy.getElementType();
1563 if (outElemTy.isIndex()) {
1564 cast = arith::IndexCastUIOp::create(builder, loc,
1567 }
else if (inTy != outElemTy) {
1568 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1571 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
1576 typeConverter.addSourceMaterialization(
1577 singleElementToVectorMaterializationCast);
1578 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1579 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1580 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1581 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1582 typeConverter.addTargetMaterialization(
1583 vectorToSingleElementMaterializationCast);
1584 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
1585 ConversionTarget
target(*context);
1586 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1587 vector::VectorDialect, arith::ArithDialect,
1588 memref::MemRefDialect, gpu::GPUDialect,
1589 index::IndexDialect>();
1590 target.addIllegalDialect<xegpu::XeGPUDialect>();
1593 target.addDynamicallyLegalOp<arith::ExtFOp>(
1594 [](arith::ExtFOp op) {
return !isXeVMExtf(op); });
1595 target.addDynamicallyLegalOp<arith::TruncFOp>(
1596 [](arith::TruncFOp op) {
return !isXeVMTruncf(op); });
1598 RewritePatternSet patterns(context);
1602 if (
failed(applyPartialConversion(getOperation(),
target,
1603 std::move(patterns))))
1604 signalPassFailure();
1614 patterns.
add<CreateNdDescToXeVMPattern,
1615 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1616 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1617 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1619 patterns.
add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1620 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1621 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1623 patterns.
add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1624 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1625 CreateMemDescOpPattern>(typeConverter, patterns.
getContext());
1626 patterns.
add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1628 patterns.
add<DpasMxToXeVMPattern>(typeConverter, patterns.
getContext());
1629 patterns.
add<ExtfToXeVMPattern, TruncfToXeVMPattern>(typeConverter,
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)