29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
35#include "llvm/ADT/TypeSwitch.h"
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
// Systolic depth of the DPAS (matrix-multiply) instruction; used below to
// validate packed sub-byte tile shapes (tileH == systolicDepth * 4).
static constexpr int32_t systolicDepth{8};
// Subgroup execution size (lane count) assumed by this lowering; used both
// for tile-shape checks and for the MMA shape attribute.
static constexpr int32_t executionSize{16};

// Field indices inside the nd-tensor-descriptor payload vector
// (8 x i32, viewed as 4 x i64 for the base pointer).
// NOTE(review): the enumerator list is elided in this chunk — the uses below
// reference BasePtr, BaseShapeW, BaseShapeH and BasePitch; confirm against
// the full source.
enum class NdTdescOffset : uint32_t {
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
67 llvm_unreachable(
"Unknown XeGPU memory space");
71static bool isSharedMemRef(
const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
75 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() ==
static_cast<int>(xevm::AddrSpace::SHARED);
77 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
97 if (!L1hint && !L3hint)
98 return xevm::LoadCacheControl::USE_DEFAULT;
100 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::CACHED);
101 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::CACHED);
103 case xegpu::CachePolicy::CACHED:
104 if (L3hintVal == xegpu::CachePolicy::CACHED)
105 return xevm::LoadCacheControl::L1C_L2UC_L3C;
106 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
107 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
109 llvm_unreachable(
"Unsupported cache control.");
110 case xegpu::CachePolicy::UNCACHED:
111 if (L3hintVal == xegpu::CachePolicy::CACHED)
112 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
113 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
114 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
116 llvm_unreachable(
"Unsupported cache control.");
117 case xegpu::CachePolicy::STREAMING:
118 if (L3hintVal == xegpu::CachePolicy::CACHED)
119 return xevm::LoadCacheControl::L1S_L2UC_L3C;
120 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
121 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
123 llvm_unreachable(
"Unsupported cache control.");
124 case xegpu::CachePolicy::READ_INVALIDATE:
125 return xevm::LoadCacheControl::INVALIDATE_READ;
127 llvm_unreachable(
"Unsupported cache control.");
131static xevm::StoreCacheControl
132translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
133 std::optional<xegpu::CachePolicy> L3hint) {
135 if (!L1hint && !L3hint)
136 return xevm::StoreCacheControl::USE_DEFAULT;
138 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
139 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::WRITE_BACK);
141 case xegpu::CachePolicy::UNCACHED:
142 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
143 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
144 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
145 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
147 llvm_unreachable(
"Unsupported cache control.");
148 case xegpu::CachePolicy::STREAMING:
149 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
150 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
151 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
152 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
154 llvm_unreachable(
"Unsupported cache control.");
155 case xegpu::CachePolicy::WRITE_BACK:
156 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
157 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
158 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
159 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
161 llvm_unreachable(
"Unsupported cache control.");
162 case xegpu::CachePolicy::WRITE_THROUGH:
163 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
164 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
165 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
166 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
168 llvm_unreachable(
"Unsupported cache control.");
170 llvm_unreachable(
"Unsupported cache control.");
// Lowers xegpu.create_nd_tdesc. For 2-D descriptors the result is an
// 8 x i32 payload vector holding base pointer (as 2 lanes via a 4 x i64
// view), base shape W/H and base pitch; for the 1-D path the op is replaced
// by the bare 64-bit base address.
// NOTE(review): several lines are elided in this chunk (return type of
// matchAndRewrite, closing braces, the constant-payload initializer, the
// createOffset lambda body) — comments describe only the visible code.
class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Descriptors with folded-in offsets are not handled by this lowering.
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // Payload layout: 8 x i32, re-viewed as 4 x i64 when inserting the
    // 64-bit base pointer.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    Value payload = arith::ConstantOp::create(
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
    int64_t rank = mixedSizes.size();
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasRank()) {
        return rewriter.notifyMatchFailure(op, "Expected ranked Memref.");
      baseAddr = adaptor.getSource();
      baseAddr = adaptor.getSource();
      if (baseAddr.getType() != i64Ty) {
        // Narrower pointer representations are zero-extended to i64.
        baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
      // 1-D case: the descriptor lowers to just the base address.
      rewriter.replaceOp(op, baseAddr);
    // Turns the size/stride OpFoldResult at `idx` into an i32 Value
    // (body elided in this chunk).
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    Value basePitch = createOffset(mixedStrides, 0);
    // Insert the 64-bit base pointer through the 4 x i64 view, then cast
    // back and insert the i32 shape/pitch fields at their enum lanes.
    vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                             static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                             static_cast<int>(NdTdescOffset::BaseShapeW));
    vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                             static_cast<int>(NdTdescOffset::BaseShapeH));
    vector::InsertOp::create(rewriter, loc, basePitch, payload,
                             static_cast<int>(NdTdescOffset::BasePitch));
    rewriter.replaceOp(op, payload);
// Lowers xegpu.load_nd / store_nd / prefetch_nd. The 2-D path unpacks the
// descriptor payload and emits xevm 2-D block ops; the 1-D path computes a
// flat byte address and emits xevm.blockload/blockstore. Sub-byte (4-bit)
// element types are handled by re-encoding the tile to i8/i16 and scaling
// widths/offsets by a power-of-two factor.
// NOTE(review): the leading `template <typename OpType,` line and many
// interior lines (condition headers, `Value x =` declarations, braces) are
// elided in this chunk — comments describe only the visible code.
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    auto tileRank = tdescTy.getRank();
    // One offset per descriptor dimension is required.
    if (opOffsetsSize != tileRank)
      return rewriter.notifyMatchFailure(
          op, "Expected offset rank to match descriptor rank.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    bool isSubByte = elemBitSize < 8;
    // wScaleFactor > 1 marks a sub-byte re-encoding; widths and the W offset
    // are divided by it later.
    uint64_t wScaleFactor = 1;
    if (!isSubByte && (elemBitSize % 8 != 0))
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    auto tileW = tdescTy.getDimSize(tileRank - 1);
    // Sub-byte handling (guard headers elided): only 4-bit elements and only
    // 2-D descriptors are accepted.
    if (elemBitSize != 4)
      return rewriter.notifyMatchFailure(
          op, "Only sub byte types of 4bits are supported.");
    return rewriter.notifyMatchFailure(
        op, "Sub byte types are only supported for 2D tensor descriptors.");
    auto subByteFactor = 8 / elemBitSize;
    auto tileH = tdescTy.getDimSize(0);
    if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
      // Packed (VNNI) sub-byte load: re-encode the tile as i8.
      if (op.getPacked().value_or(false)) {
        if (tileH == systolicDepth * 4 &&
            tileW == executionSize * subByteFactor) {
          elemType = rewriter.getIntegerType(8);
          tileW = executionSize;
          wScaleFactor = subByteFactor;
    // Otherwise try re-encoding two bytes per lane as i16.
    if (wScaleFactor == 1) {
      auto sub16BitFactor = subByteFactor * 2;
      if (tileW == executionSize * sub16BitFactor) {
        elemType = rewriter.getIntegerType(16);
        tileW = executionSize;
        wScaleFactor = sub16BitFactor;
      return rewriter.notifyMatchFailure(
          op, "Unsupported tile shape for sub byte types.");
    // elemBitSize reflects the (possibly re-encoded) element type from here.
    elemBitSize = elemType.getIntOrFloatBitWidth();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
    // Unpack descriptor payload: base pointer via the 4 x i64 view, then the
    // i32 shape/pitch lanes.
    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
    vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
    vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
                              static_cast<int>(NdTdescOffset::BasePtr));
    Value baseShapeW = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
    Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
    Value basePitch = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BasePitch));
        rewriter.getI32Type(), offsetW);
        rewriter.getI32Type(), offsetH);
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
    // XeVM 2-D block ops take width/pitch in bytes.
    Value baseShapeWInBytes =
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
    Value basePitchBytes =
        arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
    if (wScaleFactor > 1) {
      // Sub-byte re-encoding: divide byte widths and the W offset by the
      // scale factor (power of two, hence arithmetic shifts by log2).
        rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
      baseShapeWInBytes = arith::ShRSIOp::create(
          rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
      basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
                                              wScaleFactorValLog2);
      arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
    auto tileH = tdescTy.getDimSize(0);
    int32_t vblocks = tdescTy.getArrayLength();
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
      return rewriter.notifyMatchFailure(
          op, "Expected store value to be a vector type.");
      // Bitcast the stored value to an integer vector of matching bit width.
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      xevm::BlockStore2dOp::create(
          rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
          basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      rewriter.eraseOp(op);
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
        xevm::BlockPrefetch2dOp::create(
            rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
            basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
            vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        rewriter.eraseOp(op);
        // Load path: pick i32 lanes when VNNI-packed, else same-width ints.
        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
        const bool vnni = op.getPacked().value_or(false);
        auto transposeValue = op.getTranspose();
            transposeValue.has_value() && transposeValue.value()[0] == 1;
        VectorType loadedTy = encodeVectorTypeTo(
            dstVecTy, vnni ? rewriter.getI32Type()
                           : rewriter.getIntegerType(elemBitSize));
        Value resultFlatVec = xevm::BlockLoad2dOp::create(
            rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
            baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
            tileH, vblocks, transpose, vnni,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        // Cast the integer load result back to the requested element type.
        resultFlatVec = vector::BitCastOp::create(
            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
        rewriter.replaceOp(op, resultFlatVec);
    // 1-D path: flat byte address = base + offset * elemByteSize.
        rewriter.getI64Type(), offset);
        rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
        rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
    Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
      return rewriter.notifyMatchFailure(
          op, "Expected store value to be a vector type.");
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
          op, finalPtrLLVM, src,
          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
    } else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      VectorType resTy = cast<VectorType>(op.getValue().getType());
      VectorType loadedTy =
          encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
      Value load = xevm::BlockLoadOp::create(
          rewriter, loc, loadedTy, finalPtrLLVM,
          xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
      if (loadedTy != resTy)
        load = vector::BitCastOp::create(rewriter, loc, resTy, load);
      rewriter.replaceOp(op, load);
      // 1-D prefetch is not representable here.
      return rewriter.notifyMatchFailure(
          op, "Unsupported operation: xegpu.prefetch_nd with tensor "
              "descriptor rank == 1");
525static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
529 rewriter, loc, baseAddr.
getType(), elemByteSize);
530 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
531 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
// Lowers xegpu.load (gather) / xegpu.store (scatter) for a single lane:
// computes a per-lane byte address, then wraps an LLVM load/store in an
// scf.if guarded by the lane's mask bit. Loads yield a zero of the result
// type on the masked-off branch.
// NOTE(review): many interior lines (guard headers, `Value x =`
// declarations, braces, `return success();`) are elided in this chunk —
// comments describe only the visible code.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    // Result (load) or stored-value (store) type drives the element width.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      this->getTypeConverter()->convertType(op.getResult().getType());
    valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default to the global address space; refined below from the
    // descriptor or the source/dest memref.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    // Base addresses narrower than i64 are zero-extended.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value mask = adaptor.getMask();
    // Only scalar (per-lane) offsets are supported here.
    if (dyn_cast<VectorType>(offset.getType())) {
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
        addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Likewise the mask must be a per-lane scalar.
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      // then-branch: masked-on lane performs the load and attaches the
      // translated cache-control attribute.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                   valOrResVecTy.getElementType());
      LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      // else-branch: masked-off lane yields a zero constant of matching type.
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      eVal = FloatAttr::get(eTy, 0.0);
      eVal = IntegerAttr::get(eTy, 0);
      loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      loaded = arith::ConstantOp::create(
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
      // Store path: the scf.if has no else region; the store carries the
      // translated cache-control attribute.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
669class CreateMemDescOpPattern final
670 :
public OpConversionPattern<xegpu::CreateMemDescOp> {
672 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
674 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
675 ConversionPatternRewriter &rewriter)
const override {
677 rewriter.replaceOp(op, adaptor.getSource());
// Lowers xegpu.load_matrix / store_matrix against an SLM memory descriptor:
// linearizes the mixed offsets, adds the byte offset to the 32-bit base, and
// emits either subgroup block IO (xevm.blockload/blockstore via an integer
// vector view) or a plain per-lane LLVM load/store.
// NOTE(review): several interior lines (guard headers, `Value`/`Type`
// declarations, braces, chip lookup) are elided in this chunk — comments
// describe only the visible code.
template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
class LoadStoreMatrixToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
    return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    // The converted mem-desc is a 32-bit SLM base address.
    Value baseAddr32 = adaptor.getMemDesc();
    Value mdescVal = op.getMemDesc();
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      Type resType = op.getResult().getType();
      // Flatten nD results with unit dims to a 1-D vector.
      if (auto vecType = dyn_cast<VectorType>(resType)) {
        assert(llvm::count_if(vecType.getShape(),
                              [](int64_t d) { return d != 1; }) <= 1 &&
               "Expected either 1D vector or nD with unit dimensions");
        resType = VectorType::get({vecType.getNumElements()},
                                  vecType.getElementType());
      dataTy = adaptor.getData().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
      // Scalars are handled as single-element vectors.
      valOrResVecTy = VectorType::get(1, dataTy);
    int64_t elemBitWidth =
        valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Matrix descriptors always live in shared local memory.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
    auto mdescTy = cast<xegpu::MemDescType>(mdescVal.getType());
    // Linearize the multi-dim offsets, then convert to an i32 byte address.
    Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
    linearOffset = arith::IndexCastUIOp::create(
        rewriter, loc, rewriter.getI32Type(), linearOffset);
    Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
                                           linearOffset, elemByteSize);
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
    if (op.getSubgroupBlockIoAttr()) {
      // Subgroup block IO operates on same-width integer vectors; bitcast
      // to/from the requested element type around the block op.
      Type intElemTy = rewriter.getIntegerType(elemBitWidth);
      VectorType intVecTy =
          VectorType::get(valOrResVecTy.getShape(), intElemTy);
      if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
        xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
        if (intVecTy != valOrResVecTy) {
          vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
        rewriter.replaceOp(op, loadOp);
        Value dataToStore = adaptor.getData();
        if (valOrResVecTy != intVecTy) {
          vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
        xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
        rewriter.eraseOp(op);
    // Non-block path is only validated for the pvc / bmg chips.
    if (valOrResVecTy.getNumElements() >= 1) {
      if (!chipOpt || (*chipOpt != "pvc" && *chipOpt != "bmg")) {
        return rewriter.notifyMatchFailure(
            op, "The lowering is specific to pvc or bmg.");
    if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
      auto scalarTy = valOrResVecTy.getElementType();
      // Single-element results load as a scalar, otherwise as a vector.
      if (valOrResVecTy.getNumElements() == 1)
        loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
        LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
      rewriter.replaceOp(op, loadOp);
      LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
      rewriter.eraseOp(op);
// Lowers scattered xegpu.prefetch: computes a byte address from the source
// plus an optional scalar offset and emits xevm.prefetch with the translated
// load cache-control attribute.
// NOTE(review): several interior lines (guard headers, branch conditions,
// `Value` declarations, braces, `return success();`) are elided in this
// chunk — comments describe only the visible code.
class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // Zero-extend sub-64-bit base addresses to i64.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
    Value offsets = adaptor.getOffsets();
    // Vector offsets are rejected — only a per-lane scalar is supported.
    VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      return rewriter.notifyMatchFailure(op,
                                         "Expected offsets to be a scalar.");
    // Element size comes from the descriptor, the source memref element
    // type, or the explicit offset_align_byte attribute (branch headers
    // elided here).
    int64_t elemBitWidth{0};
    int64_t elemByteSize;
      elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
    } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
      elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
      elemByteSize = *op.getOffsetAlignByte();
    if (elemBitWidth != 0) {
      if (elemBitWidth % 8 != 0)
        return rewriter.notifyMatchFailure(
            op, "Expected element type bit width to be multiple of 8.");
      elemByteSize = elemBitWidth / 8;
    basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
    // Default to global address space, refined from the descriptor or the
    // source memref when available.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
873class FenceToXeVMPattern :
public OpConversionPattern<xegpu::FenceOp> {
874 using OpConversionPattern::OpConversionPattern;
876 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
877 ConversionPatternRewriter &rewriter)
const override {
878 auto loc = op.getLoc();
879 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
880 switch (op.getFenceScope()) {
881 case xegpu::FenceScope::Workgroup:
882 memScope = xevm::MemScope::WORKGROUP;
884 case xegpu::FenceScope::GPU:
885 memScope = xevm::MemScope::DEVICE;
888 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
889 switch (op.getMemoryKind()) {
890 case xegpu::MemorySpace::Global:
891 addrSpace = xevm::AddrSpace::GLOBAL;
893 case xegpu::MemorySpace::SLM:
894 addrSpace = xevm::AddrSpace::SHARED;
897 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
898 rewriter.eraseOp(op);
// Lowers xegpu.dpas to xevm.mma after validating the target uArch supports
// the operand element types; the accumulator defaults to zero when absent,
// operands are flattened to 1-D vectors, and the result is shape-cast back.
// NOTE(review): several interior lines (uArch/chip lookup, lambda closures,
// `if`/`else` headers, braces, `return success();`, switch header and return
// values of getNumOperandsPerDword) are elided in this chunk — comments
// describe only the visible code.
class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());
    // Target validation (lookup code elided): chip, uArch, and the DPAS
    // (SubgroupMatrixMultiplyAcc) instruction description must be available.
      return rewriter.notifyMatchFailure(op, "cannot determine target chip");
      return rewriter.notifyMatchFailure(op, "unsupported target uArch");
        llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
            uArch->getInstruction(
                xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
      return rewriter.notifyMatchFailure(op,
                                         "DPAS not supported by target uArch");
    // True when the uArch supports vecTy's element type for operand `kind`.
    auto checkSupportedTypes = [&](VectorType vecTy,
      auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
      return llvm::find(supported, vecTy.getElementType()) != supported.end();
    if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
      return rewriter.notifyMatchFailure(
          op, "A-matrix element type not supported by target uArch");
    if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
      return rewriter.notifyMatchFailure(
          op, "B-matrix element type not supported by target uArch");
    if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
      return rewriter.notifyMatchFailure(
          op, "result/accumulator element type not supported by target uArch");
    // Map an MLIR element type to the XeVM MMA precision enum.
    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    // Missing accumulator: build a zero constant of the result element type.
    auto elementTy = resultType.getElementType();
    Attribute initValueAttr;
    if (isa<FloatType>(elementTy))
      initValueAttr = FloatAttr::get(elementTy, 0.0);
      initValueAttr = IntegerAttr::get(elementTy, 0);
    c = arith::ConstantOp::create(
    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    // Flatten the accumulator to 1-D for xevm.mma, then cast the result back
    // to the original (possibly 2-D) shape.
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
  // Number of packed operand elements per 32-bit dword for each precision
  // (return values elided in this chunk).
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    case xevm::ElemType::TF32:
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
    llvm_unreachable("unsupported xevm::ElemType");
1017static std::optional<LLVM::AtomicBinOp>
1018matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1019 switch (arithKind) {
1020 case arith::AtomicRMWKind::addf:
1021 return LLVM::AtomicBinOp::fadd;
1022 case arith::AtomicRMWKind::addi:
1023 return LLVM::AtomicBinOp::add;
1024 case arith::AtomicRMWKind::assign:
1025 return LLVM::AtomicBinOp::xchg;
1026 case arith::AtomicRMWKind::maximumf:
1027 return LLVM::AtomicBinOp::fmax;
1028 case arith::AtomicRMWKind::maxs:
1029 return LLVM::AtomicBinOp::max;
1030 case arith::AtomicRMWKind::maxu:
1031 return LLVM::AtomicBinOp::umax;
1032 case arith::AtomicRMWKind::minimumf:
1033 return LLVM::AtomicBinOp::fmin;
1034 case arith::AtomicRMWKind::mins:
1035 return LLVM::AtomicBinOp::min;
1036 case arith::AtomicRMWKind::minu:
1037 return LLVM::AtomicBinOp::umin;
1038 case arith::AtomicRMWKind::ori:
1039 return LLVM::AtomicBinOp::_or;
1040 case arith::AtomicRMWKind::andi:
1041 return LLVM::AtomicBinOp::_and;
1043 return std::nullopt;
// Lowers xegpu.atomic_rmw by flattening the value vector and emitting one
// llvm.atomicrmw (seq_cst) per element against a GEP'd pointer; the per-lane
// results are reassembled into the result vector.
// NOTE(review): a few `Value x =` declarations (basePtrLLVM, newVal) and the
// trailing `return success();`/braces are elided in this chunk — comments
// describe only the visible code.
class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    // The converted tensor-desc is an index-typed address; cast to i64.
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    // Only kinds with a direct LLVM atomic counterpart reach this pattern.
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      // Address of element i relative to the descriptor base.
      LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                          srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    rewriter.replaceOp(op, resVec);
1090struct ConvertXeGPUToXeVMPass
1094 void runOnOperation()
override {
1104 LowerToLLVMOptions
options(context);
1105 options.overrideIndexBitwidth(this->use64bitIndex ? 64 : 32);
1106 LLVMTypeConverter typeConverter(context,
options);
1108 Type xevmIndexType = typeConverter.convertType(IndexType::get(context));
1109 Type i32Type = IntegerType::get(context, 32);
1110 typeConverter.addConversion([&](VectorType type) -> Type {
1111 auto elemType = typeConverter.convertType(type.getElementType());
1113 unsigned rank = type.getRank();
1114 if (rank == 0 || type.getNumElements() == 1)
1117 int64_t sum = llvm::product_of(type.getShape());
1118 return VectorType::get(sum, elemType);
1120 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1122 if (type.isScattered())
1124 if (type.getRank() == 1)
1125 return xevmIndexType;
1126 return VectorType::get(8, i32Type);
1135 typeConverter.addConversion(
1136 [&](xegpu::MemDescType type) -> Type {
return i32Type; });
1138 typeConverter.addConversion([&](MemRefType type) -> Type {
1139 return isSharedMemRef(type) ? i32Type : xevmIndexType;
1149 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1151 Location loc) -> Value {
1152 if (inputs.size() != 1)
1154 auto input = inputs.front();
1155 if (
auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1156 unsigned rank = memrefTy.getRank();
1160 SmallVector<int64_t> intStrides;
1163 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1164 ShapedType::isStatic(intOffsets)) {
1165 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1167 offset = arith::ConstantOp::create(builder, loc,
1173 SmallVector<Type> resultTypes{
1174 MemRefType::get({}, memrefTy.getElementType(),
1175 MemRefLayoutAttrInterface(),
1176 memrefTy.getMemorySpace()),
1179 resultTypes.append(2 * rank, indexType);
1181 auto meta = memref::ExtractStridedMetadataOp::create(
1182 builder, loc, resultTypes, input);
1184 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1185 builder, loc, meta.getBaseBuffer());
1186 offset = meta.getOffset();
1190 arith::IndexCastUIOp::create(builder, loc, type, addr);
1192 arith::IndexCastUIOp::create(builder, loc, type, offset);
1195 auto byteSize = arith::ConstantOp::create(
1198 memrefTy.getElementTypeBitWidth() / 8));
1200 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1201 auto addrWithOffset =
1202 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1204 return addrWithOffset.getResult();
1213 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1215 Location loc) -> Value {
1216 if (inputs.size() != 1)
1218 auto input = inputs.front();
1221 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1223 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1233 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1235 Location loc) -> Value {
1236 if (inputs.size() != 1)
1238 auto input = inputs.front();
1241 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1243 return arith::IndexCastUIOp::create(builder, loc, type, cast)
1253 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1255 Location loc) -> Value {
1256 if (inputs.size() != 1)
1258 auto input = inputs.front();
1259 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1260 if (
auto targetVecTy = dyn_cast<VectorType>(type)) {
1264 if (targetVecTy.getShape() != vecTy.getShape()) {
1265 cast = vector::ShapeCastOp::create(
1267 VectorType::get(targetVecTy.getShape(),
1268 vecTy.getElementType()),
1272 if (targetVecTy.getElementType() != vecTy.getElementType()) {
1273 cast = vector::BitCastOp::create(builder, loc, targetVecTy, cast)
1285 auto vectorToSingleElementMaterializationCast =
1286 [](OpBuilder &builder, Type type,
ValueRange inputs,
1287 Location loc) -> Value {
1288 if (inputs.size() != 1)
1290 auto input = inputs.front();
1291 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1293 auto rank = vecTy.getRank();
1294 if (rank != 0 && vecTy.getNumElements() != 1)
1296 auto inElemTy = vecTy.getElementType();
1300 cast = vector::ExtractOp::create(builder, loc, cast, {}).getResult();
1302 cast = vector::ExtractOp::create(builder, loc, cast,
1303 SmallVector<int64_t>(rank, 0))
1310 if (inElemTy.isIndex()) {
1311 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
1313 }
else if (inElemTy != type) {
1314 cast = arith::BitcastOp::create(builder, loc, type, cast).getResult();
1328 auto singleElementToVectorMaterializationCast =
1329 [](OpBuilder &builder, Type type,
ValueRange inputs,
1330 Location loc) -> Value {
1331 if (inputs.size() != 1)
1333 auto input = inputs.front();
1334 auto inTy = input.getType();
1335 if (!inTy.isIntOrFloat())
1339 if (
auto vecTy = dyn_cast<VectorType>(type)) {
1340 if (vecTy.getRank() != 0 && vecTy.getNumElements() != 1)
1342 auto outElemTy = vecTy.getElementType();
1344 if (outElemTy.isIndex()) {
1345 cast = arith::IndexCastUIOp::create(builder, loc,
1348 }
else if (inTy != outElemTy) {
1349 cast = arith::BitcastOp::create(builder, loc, outElemTy, cast)
1352 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
1357 typeConverter.addSourceMaterialization(
1358 singleElementToVectorMaterializationCast);
1359 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1360 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1361 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1362 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1363 typeConverter.addTargetMaterialization(
1364 vectorToSingleElementMaterializationCast);
1365 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
1366 ConversionTarget
target(*context);
1367 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1368 vector::VectorDialect, arith::ArithDialect,
1369 memref::MemRefDialect, gpu::GPUDialect,
1370 index::IndexDialect>();
1371 target.addIllegalDialect<xegpu::XeGPUDialect>();
1373 RewritePatternSet patterns(context);
1377 if (
failed(applyPartialConversion(getOperation(),
target,
1378 std::move(patterns))))
1379 signalPassFailure();
1389 patterns.
add<CreateNdDescToXeVMPattern,
1390 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1391 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1392 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1394 patterns.
add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1395 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1396 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1398 patterns.
add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1399 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1400 CreateMemDescOpPattern>(typeConverter, patterns.
getContext());
1401 patterns.
add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
static llvm::ManagedStatic< PassManagerOptions > options
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)