#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/ADT/TypeSwitch.h"

#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Field indices into the 8xi32 nd tensor descriptor payload.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,    // i64 base pointer (occupies lanes 0-1)
  BaseShapeW = 2, // i32 base shape width
  BaseShapeH = 3, // i32 base shape height
  BasePitch = 4   // i32 base pitch
};

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
  llvm_unreachable("Unknown XeGPU memory space");
}
static bool isSharedMemRef(const MemRefType &memrefTy) {
  Attribute attr = memrefTy.getMemorySpace();
  if (!attr)
    return false;
  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
    return intAttr.getInt() == static_cast<int>(xevm::AddrSpace::SHARED);
  if (auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
    return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
  return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
}
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
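
// Illustrative example (not taken from a test): re-encoding vector<16xf16>
// with an i32 target element type yields vector<8xi32>, since
// 16 elements * 16 bits / 32 bits = 8 elements.

// The translators below map an (L1, L3) XeGPU cache-hint pair onto a single
// XeVM cache-control value. Note that L2 is always uncached (L2UC) in the
// encodings used here; unsupported combinations hit llvm_unreachable.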
static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // Descriptor payload: 8 x i32, also viewed as 4 x i64 for the base
    // pointer.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // Source is either a (ranked) memref or a raw pointer passed as an
    // integer.
    Value baseAddr;
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      }
      baseAddr = adaptor.getSource();
    } else {
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // Pointer sources may arrive as a narrower integer; widen to i64.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      }
    }
    // For 1D descriptors the payload degenerates to the base address.
    if (rank == 1) {
      rewriter.replaceOp(op, baseAddr);
      return success();
    }
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      return getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
    };
    Value baseShapeW = createOffset(mixedSizes, 1);
    Value baseShapeH = createOffset(mixedSizes, 0);
    Value basePitch = createOffset(mixedStrides, 0);
    // Populate the payload: base pointer via the i64 view, shape and pitch
    // via the i32 view.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload =
        vector::InsertOp::create(rewriter, loc, basePitch, payload,
                                 static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
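
// Lowers xegpu.load_nd / store_nd / prefetch_nd. For rank-2 descriptors the
// payload is decoded and the op maps to xevm.blockload2d / blockstore2d /
// blockprefetch2d, with shape and pitch expressed in bytes; sub-byte (4-bit)
// element types are reinterpreted as i8 or i16 tiles and the width-related
// quantities are rescaled. For rank-1 descriptors the descriptor is already
// an i64 address and the op maps to xevm.blockload / blockstore
// (prefetch_nd is rejected in that case).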
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadNdOp, xegpu::StoreNdOp,
              xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    uint64_t wScaleFactor = 1;
    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    if (isSubByte) {
      // Sub-byte elements are handled by reinterpreting the tile with wider
      // elements and rescaling the width-related quantities.
      if (elemBitSize != 4)
        return rewriter.notifyMatchFailure(
            op, "Only sub byte types of 4bits are supported.");
      if (tileRank != 2)
        return rewriter.notifyMatchFailure(
            op, "Sub byte types are only supported for 2D tensor descriptors.");
      auto subByteFactor = 8 / elemBitSize;
      auto tileH = tdescTy.getDimSize(0);
      if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        if (op.getPacked().value_or(false)) {
          // Packed loads of a full systolic tile are reinterpreted as i8.
          if (tileH == systolicDepth * 4 &&
              tileW == executionSize * subByteFactor) {
            elemType = rewriter.getIntegerType(8);
            tileW = executionSize;
            wScaleFactor = subByteFactor;
          }
        }
      }
      if (wScaleFactor == 1) {
        // Otherwise reinterpret the tile as i16 elements.
        auto sub16BitFactor = subByteFactor * 2;
        if (tileW == executionSize * sub16BitFactor) {
          elemType = rewriter.getIntegerType(16);
          tileW = executionSize;
          wScaleFactor = sub16BitFactor;
        } else {
          return rewriter.notifyMatchFailure(
              op, "Unsupported tile shape for sub byte types.");
        }
      }
      elemBitSize = elemType.getIntOrFloatBitWidth();
    }
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (tileRank == 2) {
      // Element byte size as an i32 value for shape/pitch scaling.
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
      VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
      Value payLoadAsI64 =
          vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
      Value basePtr =
          vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                                    static_cast<int>(NdTdescOffset::BasePtr));
      Value baseShapeW = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
      Value baseShapeH = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
      Value basePitch = vector::ExtractOp::create(
          rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
      // Offsets are provided by the op in (H, W) order.
      Value offsetW =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
      offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetW);
      Value offsetH =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                                rewriter.getI32Type(), offsetH);
      Value basePtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
      // Shape and pitch are expressed in bytes for the 2D block ops.
      Value baseShapeWInBytes =
          arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
      Value basePitchBytes =
          arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
      if (wScaleFactor > 1) {
        // Sub-byte tiles were reinterpreted with wider elements; rescale the
        // width-related quantities accordingly.
        Value wScaleFactorValLog2 = arith::ConstantIntOp::create(
            rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
        baseShapeWInBytes = arith::ShRSIOp::create(
            rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
        basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                                wScaleFactorValLog2);
        offsetW =
            arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
      }
      auto tileH = tdescTy.getDimSize(0);
      int32_t vblocks = tdescTy.getArrayLength();
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // Store value is expected to be a vector; bitcast it to integers of
        // the element bit width if needed.
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        xevm::BlockStore2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
        rewriter.eraseOp(op);
      } else {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
          xevm::BlockPrefetch2dOp::create(
              rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
              basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
              vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          rewriter.eraseOp(op);
        } else {
          VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
          const bool vnni = op.getPacked().value_or(false);
          auto transposeValue = op.getTranspose();
          bool transpose =
              transposeValue.has_value() && transposeValue.value()[0] == 1;
          VectorType loadedTy = encodeVectorTypeTo(
              dstVecTy, vnni ? rewriter.getI32Type()
                             : rewriter.getIntegerType(elemBitSize));
          Value resultFlatVec = xevm::BlockLoad2dOp::create(
              rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
              baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
              tileH, vblocks, transpose, vnni,
              xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
          resultFlatVec = vector::BitCastOp::create(
              rewriter, loc,
              encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
              resultFlatVec);
          rewriter.replaceOp(op, resultFlatVec);
        }
      }
    } else {
      // 1D case: the converted tensor descriptor is the i64 base address.
      Value offset =
          getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
      offset = getValueOrCreateCastToIndexLike(rewriter, loc,
                                               rewriter.getI64Type(), offset);
      Value elemByteSize = arith::ConstantIntOp::create(
          rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
      Value byteOffset =
          rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
      Value finalAddrI64 =
          rewriter.createOrFold<arith::AddIOp>(loc, tdesc, byteOffset);
      Value finalPtrLLVM =
          LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
      if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
        Value src = adaptor.getValue();
        // Store value is expected to be a vector; bitcast it to integers of
        // the element bit width if needed.
        VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
        if (!srcVecTy)
          return rewriter.notifyMatchFailure(
              op, "Expected store value to be a vector type.");
        VectorType newSrcVecTy =
            encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
        if (srcVecTy != newSrcVecTy)
          src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
        auto storeCacheControl =
            translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
            op, finalPtrLLVM, src,
            xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
        auto loadCacheControl =
            translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
        VectorType resTy = cast<VectorType>(op.getValue().getType());
        VectorType loadedTy =
            encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
        Value load = xevm::BlockLoadOp::create(
            rewriter, loc, loadedTy, finalPtrLLVM,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        if (loadedTy != resTy)
          load = vector::BitCastOp::create(rewriter, loc, resTy, load);
        rewriter.replaceOp(op, load);
      } else {
        return rewriter.notifyMatchFailure(
            op, "Unsupported operation: xegpu.prefetch_nd with tensor "
                "descriptor rank == 1");
      }
    }
    return success();
  }
};
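
// Adds an element-count offset to an integer base address, scaling by the
// element byte size. Illustrative arithmetic: offset 5 with elemByteSize 4
// advances the address by 20 bytes.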
static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
                                 Location loc, Value baseAddr, Value offset,
                                 int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, baseAddr.getType(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
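
// Lowers xegpu.load (gather) and xegpu.store (scatter) with scalar offset and
// mask to a per-lane llvm.load / llvm.store wrapped in scf.if on the mask.
// Cache hints are attached as a "cache_control" attribute; the masked-off
// load path yields a zero constant of the result type.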
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    // Result type (load) or stored value type (store).
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy =
          this->getTypeConverter()->convertType(op.getResult().getType());
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default to the global address space; refine from the descriptor or the
    // source/dest memref if available.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    Value basePtrI64;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // Offsets must be scalar here; single-element vectors are converted to
      // scalars by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // The offset counts elements; scale by the element byte size and fold
      // it into the base address.
      basePtrI64 =
          addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // Mask must be scalar here; single-element vectors are converted to
      // scalars by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else {
      maskForLane = mask;
    }
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is set, perform the load.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Attach the cache control attribute to the generated load.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // Otherwise yield a zero value of the result type.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (isa<FloatType>(eTy))
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is set, perform the store; otherwise do nothing.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
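
// xegpu.create_mem_desc carries no extra payload after conversion: the
// converted source (the 32-bit SLM base address) is forwarded directly.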
class CreateMemDescOpPattern final
    : public OpConversionPattern<xegpu::CreateMemDescOp> {
  using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    rewriter.replaceOp(op, adaptor.getSource());
    return success();
  }
};
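
// Lowers xegpu.load_matrix / store_matrix on SLM. The mixed offsets are
// linearized via MemDescType::getLinearOffsets and folded into the 32-bit SLM
// address. With the subgroup_block_io attribute the access maps to
// xevm.blockload / blockstore on integer vectors; otherwise it maps to plain
// llvm.load / llvm.store (currently restricted to the "pvc" and "bmg" chips).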
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    if (offsets.empty())
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    // Load result / store value may be a scalar or a vector.
    Type dataTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // nD results with unit dimensions are flattened to 1D.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      }
      dataTy = resType;
    } else {
      dataTy = adaptor.getData().getType();
    }
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
    if (!valOrResVecTy)
      valOrResVecTy = VectorType::get(1, dataTy);

    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Matrix descriptors always live in shared local memory (SLM).
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));

    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
    // Linearize the mixed offsets and fold the byte offset into the 32-bit
    // SLM base address.
    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
    if (op.getSubgroupBlockIoAttr()) {
      // Subgroup block I/O: use the xevm block load/store ops, which operate
      // on integer element types.
      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);

      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        Value loadOp =
            xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          loadOp =
              vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        }
        rewriter.replaceOp(op, loadOp);
      } else {
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          dataToStore =
              vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        }
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
                                   /*cache_control=*/nullptr);
        rewriter.eraseOp(op);
      }
      return success();
    }
    if (valOrResVecTy.getNumElements() >= 1) {
      auto chipOpt = xevm::getChipStr(op);
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
      }
    }

    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      // Load as a scalar when there is a single element, otherwise as a flat
      // vector.
      auto scalarTy = valOrResVecTy.getElementType();
      Value loadOp;
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
      else
        loadOp =
            LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
    } else {
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
    }
    return success();
  }
};
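
// Lowers xegpu.prefetch to xevm.prefetch. The base pointer is widened to i64
// if needed, an optional scalar offset is folded in (scaled by the element
// byte size or the offset_align_byte attribute), and the address space is
// taken from the descriptor or the source memref.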
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as an integer; widen to i64 if needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // Offsets must be scalar here; single-element vectors are converted
        // to scalars by the type converter.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size comes from the descriptor or memref element
        // type when available, otherwise from the offset_align_byte attribute.
        if (tdescTy) {
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
                                         elemByteSize);
      }
    }
    // Default to the global address space; refine from the descriptor or the
    // source memref if available.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};
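
// Lowers xegpu.fence to xevm.memfence, mapping the fence scope to an XeVM
// memory scope (Workgroup/GPU -> WORKGROUP/DEVICE) and the memory kind to an
// address space (Global/SLM -> GLOBAL/SHARED).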
class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};
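
// Lowers xegpu.dpas to xevm.mma. The MMA shape is built from the accumulator
// vector length, executionSize, and systolicDepth * operands-per-dword;
// e.g. (illustrative) for f16 operands the last factor is 8 * 2 = 16.
// A zero accumulator is synthesized when none is provided.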
class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      // No accumulator provided: use a zero-initialized accumulator of the
      // result type.
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (resultType != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}
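
// Lowers xegpu.atomic_rmw element-wise: each vector lane becomes an
// llvm.atomicrmw (seq_cst) on a GEP-advanced pointer, and the returned old
// values are reassembled into a vector.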
class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};
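
// Pass driver. Type conversions used below: N-D vectors are flattened to 1-D
// (index elements become i64, single-element vectors become scalars), nd
// tensor descriptors become vector<8xi32>, scattered and rank-1 descriptors
// become i64, mem descriptors and SLM memrefs become i32, and memrefs
// materialize as their (offset-adjusted) base address.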
struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // Index elements are converted to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // Rank-0 and single-element vectors become scalars.
      if (rank < 1 || type.getNumElements() == 1)
        return elemType;
      // Otherwise flatten to a 1-D vector.
      int64_t sum = llvm::product_of(type.getShape());
      return VectorType::get(sum, elemType);
    });
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      // Scattered and rank-1 descriptors lower to a plain i64 address; nd
      // descriptors lower to the 8xi32 payload vector.
      if (type.isScattered())
        return IntegerType::get(&getContext(), 64);
      if (type.getRank() == 1)
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
      // Matrix descriptors lower to a 32-bit SLM address.
      return IntegerType::get(&getContext(), 32);
    });
    typeConverter.addConversion([&](MemRefType type) -> Type {
      return IntegerType::get(&getContext(),
                              isSharedMemRef(type) ? 32 : 64);
    });
    // Materialization to convert a memref into its base address (with the
    // static or dynamic offset folded in) as an integer of the target width.
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        unsigned rank = memrefTy.getRank();
        Value addr;
        Value offset;
        SmallVector<int64_t> intStrides;
        int64_t intOffsets;
        // Use the static offset when available, otherwise extract the runtime
        // strided metadata.
        if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
            ShapedType::isStatic(intOffsets)) {
          addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
                                                                input);
          offset = arith::ConstantOp::create(builder, loc,
                                             builder.getIndexAttr(intOffsets));
        } else {
          Type indexType = builder.getIndexType();
          SmallVector<Type> resultTypes{
              MemRefType::get({}, memrefTy.getElementType(),
                              MemRefLayoutAttrInterface(),
                              memrefTy.getMemorySpace()),
              indexType};
          resultTypes.append(2 * rank, indexType);
          auto meta = memref::ExtractStridedMetadataOp::create(
              builder, loc, resultTypes, input);
          addr = memref::ExtractAlignedPointerAsIndexOp::create(
              builder, loc, meta.getBaseBuffer());
          offset = meta.getOffset();
        }
        // Cast the index-typed address and offset to the target integer type
        // and fold the byte offset into the address.
        Value addrCasted =
            arith::IndexCastUIOp::create(builder, loc, type, addr);
        Value offsetCasted =
            arith::IndexCastUIOp::create(builder, loc, type, offset);
        auto byteSize = arith::ConstantOp::create(
            builder, loc,
            builder.getIntegerAttr(type,
                                   memrefTy.getElementTypeBitWidth() / 8));
        auto byteOffset =
            arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
        auto addrWithOffset =
            arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
        return addrWithOffset.getResult();
      }
      return {};
    };
    // Materialization to convert a ui64 value to the target integer type.
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };
    // Materialization to convert a ui32 value to the target integer type.
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };
    // Materialization for vector values: single-element vectors become
    // scalars, otherwise bitcast or shape-cast to the target vector type.
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
          // Extract the single element; index elements need a cast to the
          // target integer type.
          Value cast =
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
          if (isa<IndexType>(vecTy.getElementType()))
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
    // Materialization to turn a scalar back into a single-element vector.
    auto singleElementVectorMaterializationCast =
        [](OpBuilder &builder, Type type, ValueRange inputs,
           Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType().isIntOrIndexOrFloat()) {
        if (auto vecTy = dyn_cast<VectorType>(type)) {
          if (vecTy.getNumElements() == 1) {
            return vector::BroadcastOp::create(builder, loc, vecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(
        singleElementVectorMaterializationCast);
    typeConverter.addSourceMaterialization(vectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
               LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
               CreateMemDescOpPattern>(typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}