29#include "llvm/ADT/STLExtras.h"
30#include "llvm/Support/FormatVariadic.h"
35#include "llvm/ADT/TypeSwitch.h"
40#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
41#include "mlir/Conversion/Passes.h.inc"
// Hardware constants used by the DPAS / 2D block-IO lowerings below:
// systolic depth of the matrix pipeline and the subgroup SIMD lane width.
49static constexpr int32_t systolicDepth{8};
50static constexpr int32_t executionSize{16};
// Lane indices of the fields packed into the i32 payload vector that encodes
// an nd tensor descriptor (BasePtr/BaseShapeW/BaseShapeH/BasePitch are
// referenced below; the enumerator list itself is elided from this view).
53enum class NdTdescOffset : uint32_t {
// Map an XeGPU memory space to the numeric XeVM address-space value used to
// build LLVM pointer types (Global -> GLOBAL, SLM -> SHARED).
60static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
61 switch (xeGpuMemspace) {
62 case xegpu::MemorySpace::Global:
63 return static_cast<int>(xevm::AddrSpace::GLOBAL);
64 case xegpu::MemorySpace::SLM:
65 return static_cast<int>(xevm::AddrSpace::SHARED);
// Any other enumerator is a programming error on the caller's side.
67 llvm_unreachable(
"Unknown XeGPU memory space");
// Return true if the memref's memory-space attribute denotes shared (SLM /
// workgroup) memory. Three encodings are accepted: a plain integer address
// space, an xevm::AddrSpaceAttr, or the GPU dialect's workgroup-memory attr.
71static bool isSharedMemRef(
const MemRefType &memrefTy) {
72 Attribute attr = memrefTy.getMemorySpace();
// Integer attribute: compare against the numeric XeVM shared address space.
75 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>(attr))
76 return intAttr.getInt() ==
static_cast<int>(xevm::AddrSpace::SHARED);
77 if (
auto xevmSpace = llvm::dyn_cast<xevm::AddrSpaceAttr>(attr))
78 return xevmSpace.getValue() == xevm::AddrSpace::SHARED;
// Fall back to the GPU dialect's notion of workgroup memory.
79 return gpu::GPUDialect::isWorkgroupMemoryAddressSpace(attr);
// Re-type a flat vector to a new element type while preserving the total bit
// width: element count scales by oldBitWidth / newBitWidth. (The target
// element-type parameter `toElemType` is on a line elided from this view.)
83static VectorType encodeVectorTypeTo(VectorType currentVecType,
85 auto elemType = currentVecType.getElementType();
86 auto currentBitWidth = elemType.getIntOrFloatBitWidth();
89 currentVecType.getNumElements() * currentBitWidth / newBitWidth;
90 return VectorType::get(size, toElemType);
// Combine the optional XeGPU L1/L3 cache-policy hints of a load into a single
// xevm::LoadCacheControl value. Missing hints default to UNCACHED; every
// visible encoding keeps L2 uncached. Unsupported L1/L3 combinations hit
// llvm_unreachable. (The `switch` header on L1hintVal is elided in this view.)
93static xevm::LoadCacheControl
94translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
95 std::optional<xegpu::CachePolicy> L3hint) {
96 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
97 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
99 case xegpu::CachePolicy::CACHED:
100 if (L3hintVal == xegpu::CachePolicy::CACHED)
101 return xevm::LoadCacheControl::L1C_L2UC_L3C;
102 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
103 return xevm::LoadCacheControl::L1C_L2UC_L3UC;
105 llvm_unreachable(
"Unsupported cache control.");
106 case xegpu::CachePolicy::UNCACHED:
107 if (L3hintVal == xegpu::CachePolicy::CACHED)
108 return xevm::LoadCacheControl::L1UC_L2UC_L3C;
109 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
110 return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
112 llvm_unreachable(
"Unsupported cache control.");
113 case xegpu::CachePolicy::STREAMING:
114 if (L3hintVal == xegpu::CachePolicy::CACHED)
115 return xevm::LoadCacheControl::L1S_L2UC_L3C;
116 else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
117 return xevm::LoadCacheControl::L1S_L2UC_L3UC;
119 llvm_unreachable(
"Unsupported cache control.");
// READ_INVALIDATE ignores the L3 hint entirely.
120 case xegpu::CachePolicy::READ_INVALIDATE:
121 return xevm::LoadCacheControl::INVALIDATE_READ;
123 llvm_unreachable(
"Unsupported cache control.");
// Store-side analogue of translateLoadXeGPUCacheHint: combine optional L1/L3
// XeGPU cache-policy hints into one xevm::StoreCacheControl. Missing hints
// default to UNCACHED; L2 stays uncached in every visible encoding.
// (The `switch` header on L1hintVal is elided in this view.)
127static xevm::StoreCacheControl
128translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
129 std::optional<xegpu::CachePolicy> L3hint) {
130 auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
131 auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
133 case xegpu::CachePolicy::UNCACHED:
134 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
135 return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
136 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
137 return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
139 llvm_unreachable(
"Unsupported cache control.");
140 case xegpu::CachePolicy::STREAMING:
141 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
142 return xevm::StoreCacheControl::L1S_L2UC_L3UC;
143 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
144 return xevm::StoreCacheControl::L1S_L2UC_L3WB;
146 llvm_unreachable(
"Unsupported cache control.");
147 case xegpu::CachePolicy::WRITE_BACK:
148 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
149 return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
150 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
151 return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
153 llvm_unreachable(
"Unsupported cache control.");
154 case xegpu::CachePolicy::WRITE_THROUGH:
155 if (L3hintVal == xegpu::CachePolicy::UNCACHED)
156 return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
157 else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
158 return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
160 llvm_unreachable(
"Unsupported cache control.");
162 llvm_unreachable(
"Unsupported cache control.");
// Lower xegpu.create_nd_tdesc: materialize the descriptor as a vector<8xi32>
// payload holding the base pointer (stored as i64 through a vector<4xi64>
// bitcast view), base shape W/H, and base pitch at the NdTdescOffset lanes.
// Explicit offsets on the op are rejected.
174class CreateNdDescToXeVMPattern
175 :
public OpConversionPattern<xegpu::CreateNdDescOp> {
176 using OpConversionPattern::OpConversionPattern;
178 matchAndRewrite(xegpu::CreateNdDescOp op,
179 xegpu::CreateNdDescOp::Adaptor adaptor,
180 ConversionPatternRewriter &rewriter)
const override {
// This lowering only supports descriptors created without offsets.
181 SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
182 if (mixedOffsets.size() != 0)
183 return rewriter.notifyMatchFailure(op,
"Offsets not supported.");
184 auto loc = op.getLoc();
185 auto source = op.getSource();
// Payload layout: 8 x i32, also viewed as 4 x i64 to hold the base address.
189 Type payloadElemTy = rewriter.getI32Type();
190 VectorType payloadTy = VectorType::get(8, payloadElemTy);
191 Type i64Ty = rewriter.getI64Type();
193 VectorType payloadI64Ty = VectorType::get(4, i64Ty);
195 Value payload = arith::ConstantOp::create(
204 SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
205 SmallVector<OpFoldResult> mixedStrides = op.getMixedStrides();
207 int64_t rank = mixedSizes.size();
208 auto sourceTy = source.getType();
209 auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
212 if (sourceMemrefTy) {
213 if (!sourceMemrefTy.hasRank()) {
214 return rewriter.notifyMatchFailure(op,
"Expected ranked Memref.");
// Memref sources arrive already converted to an address by the type
// converter; take it from the adaptor.
218 baseAddr = adaptor.getSource();
220 baseAddr = adaptor.getSource();
// Widen narrower (e.g. i32/SLM) addresses to i64 for the payload slot.
221 if (baseAddr.
getType() != i64Ty) {
223 baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
// NOTE(review): under an elided condition the op is replaced by the bare
// base address instead of a payload vector — confirm against full source.
228 rewriter.replaceOp(op, baseAddr);
// Helper turning a size/stride OpFoldResult at `idx` into an i32 Value.
232 auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
233 unsigned idx) -> Value {
// 2D convention: dim 1 is W (innermost), dim 0 is H.
239 baseShapeW = createOffset(mixedSizes, 1);
240 baseShapeH = createOffset(mixedSizes, 0);
242 Value basePitch = createOffset(mixedStrides, 0);
// Write the i64 base address via the 4xi64 view, then the i32 fields.
245 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
247 vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
248 static_cast<int>(NdTdescOffset::BasePtr));
249 payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
251 vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
252 static_cast<int>(NdTdescOffset::BaseShapeW));
254 vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
255 static_cast<int>(NdTdescOffset::BaseShapeH));
257 vector::InsertOp::create(rewriter, loc, basePitch, payload,
258 static_cast<int>(NdTdescOffset::BasePitch));
259 rewriter.replaceOp(op, payload);
// Shared pattern for xegpu.load_nd / store_nd / prefetch_nd. 2D descriptors
// lower to xevm 2D block ops (BlockLoad2d/BlockStore2d/BlockPrefetch2d) using
// the payload fields written by CreateNdDescToXeVMPattern; rank-1 descriptors
// lower to flat xevm::BlockLoadOp/BlockStoreOp on a computed address
// (prefetch_nd is rejected for rank 1). Sub-byte (4-bit) element types are
// repacked into i8/i16 tiles with a width scale factor.
// (The `template <typename OpType,` header line is elided from this view.)
266 typename = std::enable_if_t<llvm::is_one_of<
267 OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
268class LoadStorePrefetchNdToXeVMPattern :
public OpConversionPattern<OpType> {
269 using OpConversionPattern<OpType>::OpConversionPattern;
271 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
272 ConversionPatternRewriter &rewriter)
const override {
273 auto mixedOffsets = op.getMixedOffsets();
274 int64_t opOffsetsSize = mixedOffsets.size();
275 auto loc = op.getLoc();
276 auto ctxt = rewriter.getContext();
278 auto tdesc = adaptor.getTensorDesc();
279 auto tdescTy = op.getTensorDescType();
280 auto tileRank = tdescTy.getRank();
// One offset per descriptor dimension is required.
281 if (opOffsetsSize != tileRank)
282 return rewriter.notifyMatchFailure(
283 op,
"Expected offset rank to match descriptor rank.");
284 auto elemType = tdescTy.getElementType();
285 auto elemBitSize = elemType.getIntOrFloatBitWidth();
286 bool isSubByte = elemBitSize < 8;
287 uint64_t wScaleFactor = 1;
289 if (!isSubByte && (elemBitSize % 8 != 0))
290 return rewriter.notifyMatchFailure(
291 op,
"Expected element type bit width to be multiple of 8.")
292 auto tileW = tdescTy.getDimSize(tileRank - 1);
// Sub-byte handling: only 4-bit elements, and only on 2D descriptors.
295 if (elemBitSize != 4)
296 return rewriter.notifyMatchFailure(
297 op,
"Only sub byte types of 4bits are supported.");
299 return rewriter.notifyMatchFailure(
300 op,
"Sub byte types are only supported for 2D tensor descriptors.");
301 auto subByteFactor = 8 / elemBitSize;
302 auto tileH = tdescTy.getDimSize(0);
// Packed (VNNI-style) loads of a DPAS-B-shaped 4-bit tile are retyped to
// i8 tiles of execution width.
304 if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
305 if (op.getPacked().value_or(
false)) {
307 if (tileH == systolicDepth * 4 &&
308 tileW == executionSize * subByteFactor) {
313 elemType = rewriter.getIntegerType(8);
314 tileW = executionSize;
315 wScaleFactor = subByteFactor;
// Otherwise try repacking pairs of bytes into i16 lanes.
320 if (wScaleFactor == 1) {
321 auto sub16BitFactor = subByteFactor * 2;
322 if (tileW == executionSize * sub16BitFactor) {
326 elemType = rewriter.getIntegerType(16);
327 tileW = executionSize;
328 wScaleFactor = sub16BitFactor;
330 return rewriter.notifyMatchFailure(
331 op,
"Unsupported tile shape for sub byte types.");
// elemBitSize may have changed due to sub-byte repacking above.
335 elemBitSize = elemType.getIntOrFloatBitWidth();
339 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
340 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
344 rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
// Unpack the descriptor payload: base address via the 4xi64 view, then
// the i32 shape/pitch lanes.
345 VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
347 vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
349 vector::ExtractOp::create(rewriter, loc, payLoadAsI64,
350 static_cast<int>(NdTdescOffset::BasePtr));
351 Value baseShapeW = vector::ExtractOp::create(
352 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeW));
353 Value baseShapeH = vector::ExtractOp::create(
354 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BaseShapeH));
355 Value basePitch = vector::ExtractOp::create(
356 rewriter, loc, tdesc,
static_cast<int>(NdTdescOffset::BasePitch));
362 rewriter.getI32Type(), offsetW);
366 rewriter.getI32Type(), offsetH);
369 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
// The xevm 2D block ops take width/pitch in bytes, not elements.
373 Value baseShapeWInBytes =
374 arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
376 Value basePitchBytes =
377 arith::MulIOp::create(rewriter, loc, basePitch, elemByteSize);
// Sub-byte repacking shrinks byte width/pitch/offset by log2(scale).
379 if (wScaleFactor > 1) {
383 rewriter, loc, rewriter.getI32Type(), llvm::Log2_64(wScaleFactor));
384 baseShapeWInBytes = arith::ShRSIOp::create(
385 rewriter, loc, baseShapeWInBytes, wScaleFactorValLog2);
386 basePitchBytes = arith::ShRSIOp::create(rewriter, loc, basePitchBytes,
387 wScaleFactorValLog2);
389 arith::ShRSIOp::create(rewriter, loc, offsetW, wScaleFactorValLog2);
392 auto tileH = tdescTy.getDimSize(0);
394 int32_t vblocks = tdescTy.getArrayLength();
// --- 2D store path: bitcast the value to an int vector and emit
// BlockStore2d with the translated cache controls. ---
395 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
396 Value src = adaptor.getValue();
402 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
404 return rewriter.notifyMatchFailure(
405 op,
"Expected store value to be a vector type.");
407 VectorType newSrcVecTy =
408 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
409 if (srcVecTy != newSrcVecTy)
410 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
411 auto storeCacheControl =
412 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
413 xevm::BlockStore2dOp::create(
414 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
415 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH, src,
416 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
417 rewriter.eraseOp(op);
419 auto loadCacheControl =
420 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
// --- 2D prefetch path. ---
421 if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
422 xevm::BlockPrefetch2dOp::create(
423 rewriter, loc, basePtrLLVM, baseShapeWInBytes, baseShapeH,
424 basePitchBytes, offsetW, offsetH, elemBitSize, tileW, tileH,
425 vblocks, xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
426 rewriter.eraseOp(op);
// --- 2D load path: vnni/transpose flags feed BlockLoad2d; result is
// bitcast back to the destination element type. ---
428 VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
429 const bool vnni = op.getPacked().value_or(
false);
430 auto transposeValue = op.getTranspose();
432 transposeValue.has_value() && transposeValue.value()[0] == 1;
433 VectorType loadedTy = encodeVectorTypeTo(
434 dstVecTy, vnni ? rewriter.getI32Type()
435 : rewriter.getIntegerType(elemBitSize));
437 Value resultFlatVec = xevm::BlockLoad2dOp::create(
438 rewriter, loc, loadedTy, basePtrLLVM, baseShapeWInBytes,
439 baseShapeH, basePitchBytes, offsetW, offsetH, elemBitSize, tileW,
440 tileH, vblocks, transpose, vnni,
441 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
442 resultFlatVec = vector::BitCastOp::create(
444 encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
446 rewriter.replaceOp(op, resultFlatVec);
// --- rank-1 path: compute base + offset*elemBytes and use flat block IO. ---
458 rewriter.getI64Type(), offset);
461 rewriter, loc, rewriter.getI64Type(), elemBitSize / 8);
463 rewriter.createOrFold<arith::MulIOp>(loc, offset, elemByteSize);
465 Value finalAddrI64 = rewriter.createOrFold<arith::AddIOp>(
471 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, finalAddrI64);
472 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
473 Value src = adaptor.getValue();
479 VectorType srcVecTy = dyn_cast<VectorType>(src.
getType());
481 return rewriter.notifyMatchFailure(
482 op,
"Expected store value to be a vector type.");
484 VectorType newSrcVecTy =
485 encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
486 if (srcVecTy != newSrcVecTy)
487 src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
488 auto storeCacheControl =
489 translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
490 rewriter.replaceOpWithNewOp<xevm::BlockStoreOp>(
491 op, finalPtrLLVM, src,
492 xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
493 }
else if constexpr (std::is_same_v<OpType, xegpu::LoadNdOp>) {
494 auto loadCacheControl =
495 translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
496 VectorType resTy = cast<VectorType>(op.getValue().getType());
497 VectorType loadedTy =
498 encodeVectorTypeTo(resTy, rewriter.getIntegerType(elemBitSize));
499 Value
load = xevm::BlockLoadOp::create(
500 rewriter, loc, loadedTy, finalPtrLLVM,
501 xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
502 if (loadedTy != resTy)
503 load = vector::BitCastOp::create(rewriter, loc, resTy,
load);
504 rewriter.replaceOp(op,
load);
// Flat (rank-1) prefetch has no xevm 2D-block equivalent here.
506 return rewriter.notifyMatchFailure(
507 op,
"Unsupported operation: xegpu.prefetch_nd with tensor "
508 "descriptor rank == 1");
// Return baseAddr + offset * elemByteSize, computed in baseAddr's integer
// type. (Parameter list and byte-size constant creation partially elided.)
517static Value addOffsetToBaseAddr(ConversionPatternRewriter &rewriter,
521 rewriter, loc, baseAddr.
getType(), elemByteSize);
522 Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
523 Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
// Shared pattern for xegpu.load (gather) / xegpu.store (scatter) with a
// scalar per-lane offset and mask. The memory access is wrapped in an scf.if
// on the lane's mask bit: masked-off gathers yield a zero constant; masked-off
// scatters simply skip the store. Cache hints become a "cache_control"
// attribute on the generated LLVM load/store.
527template <
typename OpType,
528 typename = std::enable_if_t<llvm::is_one_of<
529 OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
530class LoadStoreToXeVMPattern :
public OpConversionPattern<OpType> {
531 using OpConversionPattern<OpType>::OpConversionPattern;
533 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
534 ConversionPatternRewriter &rewriter)
const override {
535 Value offset = adaptor.getOffsets();
537 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
538 auto loc = op.getLoc();
539 auto ctxt = rewriter.getContext();
540 auto tdescTy = op.getTensorDescType();
// Determine the value/result type: converted result type for gathers,
// adaptor value type for scatters.
544 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
546 this->getTypeConverter()->convertType(op.getResult().getType());
548 valOrResTy = adaptor.getValue().getType();
549 VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
550 bool hasScalarVal = !valOrResVecTy;
551 int64_t elemBitWidth =
553 : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
555 if (elemBitWidth % 8 != 0)
556 return rewriter.notifyMatchFailure(
557 op,
"Expected element type bit width to be multiple of 8.");
558 int64_t elemByteSize = elemBitWidth / 8;
// Default to the global address space; refine from the descriptor or the
// source/dest memref below when available.
560 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
561 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
564 ptrTypeLLVM = LLVM::LLVMPointerType::get(
565 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
568 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
569 basePtrI64 = adaptor.getSource();
570 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
571 auto addrSpace = memRefTy.getMemorySpaceAsInt();
573 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
576 basePtrI64 = adaptor.getDest();
577 if (
auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
578 auto addrSpace = memRefTy.getMemorySpaceAsInt();
580 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
// Widen narrower base addresses (e.g. i32 SLM) to i64.
584 if (basePtrI64.
getType() != rewriter.getI64Type()) {
585 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
588 Value mask = adaptor.getMask();
// Vector offsets are not supported — per-lane scalar offsets only.
589 if (dyn_cast<VectorType>(offset.
getType())) {
592 return rewriter.notifyMatchFailure(op,
"Expected offset to be a scalar.");
598 addOffsetToBaseAddr(rewriter, loc, basePtrI64, offset, elemByteSize);
602 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
605 VectorType maskVecTy = dyn_cast<VectorType>(mask.
getType());
609 return rewriter.notifyMatchFailure(op,
"Expected mask to be a scalar.");
// Gather: scf.if with then = LLVM load, else = zero constant of the
// result type; the if-result replaces the op.
612 if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
613 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
614 maskForLane,
true,
true);
616 rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
618 valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
619 valOrResVecTy.getElementType());
621 LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
624 "cache_control", xevm::LoadCacheControlAttr::get(
625 ctxt, translateLoadXeGPUCacheHint(
626 op.getL1Hint(), op.getL3Hint())));
627 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
628 rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
630 auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
633 eVal = FloatAttr::get(eTy, 0.0);
635 eVal = IntegerAttr::get(eTy, 0);
637 loaded = arith::ConstantOp::create(rewriter, loc, eVal);
639 loaded = arith::ConstantOp::create(
641 scf::YieldOp::create(rewriter, loc,
ValueRange{loaded});
642 rewriter.replaceOp(op, ifOp.getResult(0));
// Scatter: scf.if with no else — store only when the lane is active.
645 scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane,
false);
646 auto body = ifOp.getBody();
647 rewriter.setInsertionPointToStart(body);
649 LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
651 storeOp.getOperation()->setAttr(
652 "cache_control", xevm::StoreCacheControlAttr::get(
653 ctxt, translateStoreXeGPUCacheHint(
654 op.getL1Hint(), op.getL3Hint())));
655 rewriter.eraseOp(op);
// xegpu.create_mem_desc is a no-op after conversion: the converted source
// value (the SLM base address) stands in for the descriptor.
661class CreateMemDescOpPattern final
662 :
public OpConversionPattern<xegpu::CreateMemDescOp> {
664 using OpConversionPattern<xegpu::CreateMemDescOp>::OpConversionPattern;
666 matchAndRewrite(xegpu::CreateMemDescOp op, OpAdaptor adaptor,
667 ConversionPatternRewriter &rewriter)
const override {
669 rewriter.replaceOp(op, adaptor.getSource());
// Shared pattern for xegpu.load_matrix / store_matrix on an SLM mem-desc.
// The linear element offset is computed from the mem-desc layout, scaled to
// bytes, and added to the 32-bit SLM base address. With the subgroup_block_io
// attribute the access lowers to xevm Block{Load,Store}Op on an int-retyped
// vector; otherwise to a plain LLVM load/store (gated to pvc/bmg chips).
674template <
typename OpType,
675 typename = std::enable_if_t<llvm::is_one_of<
676 OpType, xegpu::LoadMatrixOp, xegpu::StoreMatrixOp>::value>>
677class LoadStoreMatrixToXeVMPattern :
public OpConversionPattern<OpType> {
678 using OpConversionPattern<OpType>::OpConversionPattern;
680 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
681 ConversionPatternRewriter &rewriter)
const override {
683 SmallVector<OpFoldResult> offsets = op.getMixedOffsets();
685 return rewriter.notifyMatchFailure(op,
"Expected offset to be provided.");
687 auto loc = op.getLoc();
688 auto ctxt = rewriter.getContext();
// The mem-desc converts to a 32-bit SLM base address.
689 Value baseAddr32 = adaptor.getMemDesc();
690 Value mdescVal = op.getMemDesc();
693 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
694 Type resType = op.getResult().getType();
// nD results with unit dims are flattened to a 1D vector.
697 if (
auto vecType = dyn_cast<VectorType>(resType)) {
698 assert(llvm::count_if(vecType.getShape(),
699 [](int64_t d) { return d != 1; }) <= 1 &&
700 "Expected either 1D vector or nD with unit dimensions");
701 resType = VectorType::get({vecType.getNumElements()},
702 vecType.getElementType());
706 dataTy = adaptor.getData().getType();
707 VectorType valOrResVecTy = dyn_cast<VectorType>(dataTy);
// Treat a scalar value as a 1-element vector for uniform handling.
709 valOrResVecTy = VectorType::get(1, dataTy);
711 int64_t elemBitWidth =
712 valOrResVecTy.getElementType().getIntOrFloatBitWidth();
714 if (elemBitWidth % 8 != 0)
715 return rewriter.notifyMatchFailure(
716 op,
"Expected element type bit width to be multiple of 8.");
717 int64_t elemByteSize = elemBitWidth / 8;
720 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
721 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::SLM));
723 auto mdescTy = cast<xegpu::MemDescType>(mdescVal.
getType());
// Linearize the mem-desc offsets, then scale into a byte address.
725 Value linearOffset = mdescTy.getLinearOffsets(rewriter, loc, offsets);
726 linearOffset = arith::IndexCastUIOp::create(
727 rewriter, loc, rewriter.getI32Type(), linearOffset);
728 Value basePtrI32 = addOffsetToBaseAddr(rewriter, loc, baseAddr32,
729 linearOffset, elemByteSize);
733 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI32);
// Subgroup block IO path: bitcast through an integer vector of the same
// shape so the xevm block ops see an int element type.
735 if (op.getSubgroupBlockIoAttr()) {
739 Type intElemTy = rewriter.getIntegerType(elemBitWidth);
740 VectorType intVecTy =
741 VectorType::get(valOrResVecTy.getShape(), intElemTy);
743 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
745 xevm::BlockLoadOp::create(rewriter, loc, intVecTy, basePtrLLVM);
746 if (intVecTy != valOrResVecTy) {
748 vector::BitCastOp::create(rewriter, loc, valOrResVecTy, loadOp);
750 rewriter.replaceOp(op, loadOp);
752 Value dataToStore = adaptor.getData();
753 if (valOrResVecTy != intVecTy) {
755 vector::BitCastOp::create(rewriter, loc, intVecTy, dataToStore);
757 xevm::BlockStoreOp::create(rewriter, loc, basePtrLLVM, dataToStore,
759 rewriter.eraseOp(op);
// Plain scalar/vector path: this lowering is validated only for the
// pvc and bmg chips.
764 if (valOrResVecTy.getNumElements() >= 1) {
766 if (!chipOpt || (*chipOpt !=
"pvc" && *chipOpt !=
"bmg")) {
768 return rewriter.notifyMatchFailure(
769 op,
"The lowering is specific to pvc or bmg.");
773 if constexpr (std::is_same_v<OpType, xegpu::LoadMatrixOp>) {
777 auto scalarTy = valOrResVecTy.getElementType();
// Single-element results load as a scalar instead of a 1-vector.
779 if (valOrResVecTy.getNumElements() == 1)
780 loadOp = LLVM::LoadOp::create(rewriter, loc, scalarTy, basePtrLLVM);
783 LLVM::LoadOp::create(rewriter, loc, valOrResVecTy, basePtrLLVM);
784 rewriter.replaceOp(op, loadOp);
786 LLVM::StoreOp::create(rewriter, loc, adaptor.getData(), basePtrLLVM);
787 rewriter.eraseOp(op);
// Lower xegpu.prefetch to xevm.prefetch: compute base + offsets*elemBytes,
// cast to a pointer in the right address space, and attach translated load
// cache controls.
793class PrefetchToXeVMPattern :
public OpConversionPattern<xegpu::PrefetchOp> {
794 using OpConversionPattern::OpConversionPattern;
796 matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
797 ConversionPatternRewriter &rewriter)
const override {
798 auto loc = op.getLoc();
799 auto ctxt = rewriter.getContext();
800 auto tdescTy = op.getTensorDescType();
801 Value basePtrI64 = adaptor.getSource();
// Widen narrower base addresses to i64.
803 if (basePtrI64.
getType() != rewriter.getI64Type())
804 basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
806 Value offsets = adaptor.getOffsets();
// Only scalar offsets are supported here.
808 VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.
getType());
811 return rewriter.notifyMatchFailure(op,
812 "Expected offsets to be a scalar.");
// Element size comes from the descriptor or source memref when typed,
// else from the explicit offset_align_byte attribute.
814 int64_t elemBitWidth{0};
815 int64_t elemByteSize;
820 elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
821 }
else if (
auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
824 elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
827 elemByteSize = *op.getOffsetAlignByte();
829 if (elemBitWidth != 0) {
830 if (elemBitWidth % 8 != 0)
831 return rewriter.notifyMatchFailure(
832 op,
"Expected element type bit width to be multiple of 8.");
833 elemByteSize = elemBitWidth / 8;
835 basePtrI64 = addOffsetToBaseAddr(rewriter, loc, basePtrI64, offsets,
// Pick the pointer address space: descriptor memory space wins, else the
// source memref's, else global.
840 LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
841 ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
844 ptrTypeLLVM = LLVM::LLVMPointerType::get(
845 ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
847 if (
auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
848 auto addrSpace = memRefTy.getMemorySpaceAsInt();
850 ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
854 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
856 xevm::PrefetchOp::create(
857 rewriter, loc, ptrLLVM,
858 xevm::LoadCacheControlAttr::get(
859 ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
860 rewriter.eraseOp(op);
// Lower xegpu.fence to xevm.memfence, mapping the fence scope
// (Workgroup/GPU -> WORKGROUP/DEVICE) and memory kind (Global/SLM ->
// GLOBAL/SHARED address space).
865class FenceToXeVMPattern :
public OpConversionPattern<xegpu::FenceOp> {
866 using OpConversionPattern::OpConversionPattern;
868 matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
869 ConversionPatternRewriter &rewriter)
const override {
870 auto loc = op.getLoc();
871 xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
872 switch (op.getFenceScope()) {
873 case xegpu::FenceScope::Workgroup:
874 memScope = xevm::MemScope::WORKGROUP;
876 case xegpu::FenceScope::GPU:
877 memScope = xevm::MemScope::DEVICE;
880 xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
881 switch (op.getMemoryKind()) {
882 case xegpu::MemorySpace::Global:
883 addrSpace = xevm::AddrSpace::GLOBAL;
885 case xegpu::MemorySpace::SLM:
886 addrSpace = xevm::AddrSpace::SHARED;
889 xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
890 rewriter.eraseOp(op);
// Lower xegpu.dpas to xevm.mma. Validates operand/accumulator element types
// against the target uArch's DPAS instruction description, encodes the
// precision enums, shape-casts the accumulator to 1D, and shape-casts the
// MMA result back to the original result type.
895class DpasToXeVMPattern :
public OpConversionPattern<xegpu::DpasOp> {
896 using OpConversionPattern::OpConversionPattern;
898 matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
899 ConversionPatternRewriter &rewriter)
const override {
900 auto loc = op.getLoc();
901 auto ctxt = rewriter.getContext();
902 auto aTy = cast<VectorType>(op.getLhs().getType());
903 auto bTy = cast<VectorType>(op.getRhs().getType());
904 auto resultType = cast<VectorType>(op.getResultType());
909 return rewriter.notifyMatchFailure(op,
"cannot determine target chip");
913 return rewriter.notifyMatchFailure(op,
"unsupported target uArch");
// Query the uArch's matrix-multiply-accumulate instruction description.
916 llvm::dyn_cast_or_null<xegpu::uArch::SubgroupMatrixMultiplyAcc>(
917 uArch->getInstruction(
918 xegpu::uArch::InstructionKind::SubgroupMatrixMultiplyAcc)));
920 return rewriter.notifyMatchFailure(op,
921 "DPAS not supported by target uArch");
// True if the vector's element type is in the uArch's supported set for
// the given operand kind (A/B/D).
923 auto checkSupportedTypes = [&](VectorType vecTy,
925 auto supported = dpasInst->getSupportedTypes(*ctxt, kind);
926 return llvm::find(supported, vecTy.getElementType()) != supported.end();
929 if (!checkSupportedTypes(aTy, xegpu::uArch::MMAOpndKind::MatrixA))
930 return rewriter.notifyMatchFailure(
931 op,
"A-matrix element type not supported by target uArch");
932 if (!checkSupportedTypes(bTy, xegpu::uArch::MMAOpndKind::MatrixB))
933 return rewriter.notifyMatchFailure(
934 op,
"B-matrix element type not supported by target uArch");
936 if (!checkSupportedTypes(resultType, xegpu::uArch::MMAOpndKind::MatrixD))
937 return rewriter.notifyMatchFailure(
938 op,
"result/accumulator element type not supported by target uArch");
// Map an MLIR element type onto the xevm::ElemType precision enum.
940 auto encodePrecision = [&](Type type) -> xevm::ElemType {
941 if (type == rewriter.getBF16Type())
942 return xevm::ElemType::BF16;
943 else if (type == rewriter.getF16Type())
944 return xevm::ElemType::F16;
945 else if (type == rewriter.getTF32Type())
946 return xevm::ElemType::TF32;
947 else if (type.isInteger(8)) {
948 if (type.isUnsignedInteger())
949 return xevm::ElemType::U8;
950 return xevm::ElemType::S8;
951 }
else if (type == rewriter.getF32Type())
952 return xevm::ElemType::F32;
953 else if (type.isInteger(32))
954 return xevm::ElemType::S32;
955 llvm_unreachable(
"add more support for ElemType");
957 xevm::ElemType precATy = encodePrecision(aTy.getElementType());
958 xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
// A missing accumulator is replaced by a zero constant of the result type.
959 Value c = op.getAcc();
961 auto elementTy = resultType.getElementType();
962 Attribute initValueAttr;
963 if (isa<FloatType>(elementTy))
964 initValueAttr = FloatAttr::get(elementTy, 0.0);
966 initValueAttr = IntegerAttr::get(elementTy, 0);
967 c = arith::ConstantOp::create(
971 Value aVec = op.getLhs();
972 Value bVec = op.getRhs();
973 auto cvecty = cast<VectorType>(c.
getType());
974 xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
975 xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
// xevm.mma takes a flat (1D) accumulator/result vector.
977 VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
979 c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
980 Value dpasRes = xevm::MMAOp::create(
981 rewriter, loc, cNty, aVec, bVec, c,
982 xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
984 getNumOperandsPerDword(precATy)),
985 xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
987 dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
988 rewriter.replaceOp(op, dpasRes);
// Number of A/B operand elements packed per 32-bit dword for the given
// precision (return values for each case elided in this view).
993 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
995 case xevm::ElemType::TF32:
997 case xevm::ElemType::BF16:
998 case xevm::ElemType::F16:
1000 case xevm::ElemType::U8:
1001 case xevm::ElemType::S8:
1004 llvm_unreachable(
"unsupported xevm::ElemType");
// Map an arith.atomicrmw kind to the corresponding LLVM::AtomicBinOp, or
// std::nullopt for kinds that have no direct LLVM atomic equivalent.
1009static std::optional<LLVM::AtomicBinOp>
1010matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
1011 switch (arithKind) {
1012 case arith::AtomicRMWKind::addf:
1013 return LLVM::AtomicBinOp::fadd;
1014 case arith::AtomicRMWKind::addi:
1015 return LLVM::AtomicBinOp::add;
1016 case arith::AtomicRMWKind::assign:
1017 return LLVM::AtomicBinOp::xchg;
1018 case arith::AtomicRMWKind::maximumf:
1019 return LLVM::AtomicBinOp::fmax;
1020 case arith::AtomicRMWKind::maxs:
1021 return LLVM::AtomicBinOp::max;
1022 case arith::AtomicRMWKind::maxu:
1023 return LLVM::AtomicBinOp::umax;
1024 case arith::AtomicRMWKind::minimumf:
1025 return LLVM::AtomicBinOp::fmin;
1026 case arith::AtomicRMWKind::mins:
1027 return LLVM::AtomicBinOp::min;
1028 case arith::AtomicRMWKind::minu:
1029 return LLVM::AtomicBinOp::umin;
1030 case arith::AtomicRMWKind::ori:
1031 return LLVM::AtomicBinOp::_or;
1032 case arith::AtomicRMWKind::andi:
1033 return LLVM::AtomicBinOp::_and;
1035 return std::nullopt;
// Lower xegpu.atomic_rmw to a sequence of scalar LLVM::AtomicRMWOp: the value
// vector is flattened, and one seq_cst atomic is issued per element at
// base + i via GEP, with results reassembled into a vector.
1039class AtomicRMWToXeVMPattern :
public OpConversionPattern<xegpu::AtomicRMWOp> {
1040 using OpConversionPattern::OpConversionPattern;
1042 matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
1043 ConversionPatternRewriter &rewriter)
const override {
1044 auto loc = op.getLoc();
1045 auto ctxt = rewriter.getContext();
1046 auto tdesc = op.getTensorDesc().getType();
1047 auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
1048 ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
1049 Value basePtrI64 = arith::IndexCastOp::create(
1050 rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
1052 LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
1053 VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
1054 VectorType srcOrDstFlatVecTy = VectorType::get(
1055 srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
1056 Value srcFlatVec = vector::ShapeCastOp::create(
1057 rewriter, loc, srcOrDstFlatVecTy, op.getValue());
1058 auto atomicKind = matchSimpleAtomicOp(op.getKind());
// All xegpu.atomic_rmw kinds reaching here must map to an LLVM atomic op.
1059 assert(atomicKind.has_value());
1060 Value resVec = srcFlatVec;
// Element-wise loop: atomic op on each lane's address, insert the old
// value returned by the atomic into the result vector.
1061 for (
int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
1062 auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
1063 Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
1064 rewriter.getIndexAttr(i));
1066 LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
1067 srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
1069 LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
1070 val, LLVM::AtomicOrdering::seq_cst);
1071 resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
1073 rewriter.replaceOp(op, resVec);
// Conversion pass driver: lowers the XeGPU dialect to XeVM by installing
// type conversions and materialization casts on an LLVMTypeConverter, then
// running a partial dialect conversion over the operation.
// NOTE(review): this listing is an elided extraction of the original file --
// source lines are missing between the numbered fragments below (lambda
// results, closing braces, some declarations), so control flow as shown is
// incomplete; comments hedge where behavior cannot be read off the fragment.
1082struct ConvertXeGPUToXeVMPass
1086 void runOnOperation()
override {
1087 LLVMTypeConverter typeConverter(&
getContext());
// Vectors: index elements are widened to i64; n-D shapes are flattened to a
// 1-D vector whose length is the product of the original dimensions.
1088 typeConverter.addConversion([&](VectorType type) -> Type {
1089 unsigned rank = type.getRank();
1090 auto elemType = type.getElementType();
1092 if (llvm::isa<IndexType>(elemType))
1093 elemType = IntegerType::get(&
getContext(), 64);
// Rank-0 / single-element vectors take an early-return path whose result
// line is elided in this extraction.
1095 if (rank == 0 || type.getNumElements() == 1)
1098 int64_t sum = llvm::product_of(type.getShape());
1099 return VectorType::get(sum, elemType);
// Tensor descriptors: the scattered and rank-1 branches are elided here;
// the remaining case lowers to a vector<8xi32> descriptor payload.
1101 typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
1103 if (type.isScattered())
1105 if (type.getRank() == 1)
1107 auto i32Type = IntegerType::get(&
getContext(), 32);
1108 return VectorType::get(8, i32Type);
// Memory descriptors: result type elided in this extraction -- TODO confirm
// against the full source.
1111 typeConverter.addConversion([&](xegpu::MemDescType type) -> Type {
// Memrefs become plain address integers: i32 for shared (SLM) memrefs,
// i64 otherwise (classification via isSharedMemRef above).
1115 typeConverter.addConversion([&](MemRefType type) -> Type {
1116 return IntegerType::get(&
getContext(), (isSharedMemRef(type) ? 32 : 64));
// Target materialization: memref value -> integer address. Extracts the
// aligned base pointer and the offset -- statically when strides/offset are
// known, otherwise via memref.extract_strided_metadata -- casts both to the
// requested integer type, and adds the offset scaled to bytes.
1126 auto memrefToIntMaterializationCast = [](OpBuilder &builder, Type type,
1128 Location loc) -> Value {
1129 if (inputs.size() != 1)
1131 auto input = inputs.front();
1132 if (
auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
1133 unsigned rank = memrefTy.getRank();
1137 SmallVector<int64_t> intStrides;
// Fast path: fully static layout, so the offset is a compile-time constant.
1140 if (succeeded(memrefTy.getStridesAndOffset(intStrides, intOffsets)) &&
1141 ShapedType::isStatic(intOffsets)) {
1142 addr = memref::ExtractAlignedPointerAsIndexOp::create(builder, loc,
1144 offset = arith::ConstantOp::create(builder, loc,
// Dynamic path: query base buffer / offset / sizes / strides at runtime.
1150 SmallVector<Type> resultTypes{
1151 MemRefType::get({}, memrefTy.getElementType(),
1152 MemRefLayoutAttrInterface(),
1153 memrefTy.getMemorySpace()),
1156 resultTypes.append(2 * rank, indexType);
1158 auto meta = memref::ExtractStridedMetadataOp::create(
1159 builder, loc, resultTypes, input);
1161 addr = memref::ExtractAlignedPointerAsIndexOp::create(
1162 builder, loc, meta.getBaseBuffer());
1163 offset = meta.getOffset();
1167 arith::IndexCastUIOp::create(builder, loc, type, addr);
1169 arith::IndexCastUIOp::create(builder, loc, type, offset);
// Scale the element offset to bytes before adding it to the base address.
1172 auto byteSize = arith::ConstantOp::create(
1175 memrefTy.getElementTypeBitWidth() / 8));
1177 arith::MulIOp::create(builder, loc, offsetCasted, byteSize);
1178 auto addrWithOffset =
1179 arith::AddIOp::create(builder, loc, addrCasted, byteOffset);
1181 return addrWithOffset.getResult();
// Target materialization (per the lambda's name, ui64 -> signless i64):
// round-trips through index via index.castu then arith.index_castui.
1190 auto ui64ToI64MaterializationCast = [](OpBuilder &builder, Type type,
1192 Location loc) -> Value {
1193 if (inputs.size() != 1)
1195 auto input = inputs.front();
1198 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1200 return arith::IndexCastUIOp::create(builder, loc, type, cast)
// Target materialization (ui32 -> signless i32): same index round-trip.
1210 auto ui32ToI32MaterializationCast = [](OpBuilder &builder, Type type,
1212 Location loc) -> Value {
1213 if (inputs.size() != 1)
1215 auto input = inputs.front();
1218 index::CastUOp::create(builder, loc, builder.
getIndexType(), input)
1220 return arith::IndexCastUIOp::create(builder, loc, type, cast)
// Vector -> vector materialization: same-rank vectors are bitcast;
// same-element-type vectors of different rank are shape_cast.
1230 auto vectorToVectorMaterializationCast = [](OpBuilder &builder, Type type,
1232 Location loc) -> Value {
1233 if (inputs.size() != 1)
1235 auto input = inputs.front();
1236 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1237 if (
auto targetVecTy = dyn_cast<VectorType>(type)) {
1240 if (targetVecTy.getRank() == vecTy.getRank())
1241 return vector::BitCastOp::create(builder, loc, targetVecTy, input)
1243 else if (targetVecTy.getElementType() == vecTy.getElementType()) {
1246 return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
// Single-element vector -> scalar: extract element 0 ({} for rank-0, an
// all-zero index list otherwise), then index-cast when the scalar target
// differs from the vector element type.
1257 auto vectorToSingleElementMaterializationCast =
1258 [](OpBuilder &builder, Type type,
ValueRange inputs,
1259 Location loc) -> Value {
1260 if (inputs.size() != 1)
1262 auto input = inputs.front();
1263 if (
auto vecTy = dyn_cast<VectorType>(input.getType())) {
1264 if (type == vecTy.getElementType() ||
1266 type.isInteger())) {
1269 auto rank = vecTy.getRank();
1273 vector::ExtractOp::create(builder, loc, input, {}).getResult();
1275 cast = vector::ExtractOp::create(builder, loc, input,
1276 SmallVector<int64_t>(rank, 0))
1279 if (type != vecTy.getElementType())
1280 cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
// Scalar -> single-element vector: broadcast directly when the element type
// matches, or index-cast first when the target element type is index.
1294 auto singleElementToVectorMaterializationCast =
1295 [](OpBuilder &builder, Type type,
ValueRange inputs,
1296 Location loc) -> Value {
1297 if (inputs.size() != 1)
1299 auto input = inputs.front();
1302 if (
auto vecTy = dyn_cast<VectorType>(type)) {
1303 if (vecTy.getRank() == 0 || vecTy.getNumElements() == 1) {
1304 if (input.getType() == vecTy.getElementType()) {
1305 return vector::BroadcastOp::create(builder, loc, vecTy, input)
1307 }
else if (vecTy.getElementType() == builder.
getIndexType()) {
1308 Value cast = arith::IndexCastUIOp::create(
1311 return vector::BroadcastOp::create(builder, loc, vecTy, cast)
// Register the materialization casts defined above.
1318 typeConverter.addSourceMaterialization(
1319 singleElementToVectorMaterializationCast);
1320 typeConverter.addSourceMaterialization(vectorToVectorMaterializationCast);
1321 typeConverter.addTargetMaterialization(memrefToIntMaterializationCast);
1322 typeConverter.addTargetMaterialization(ui32ToI32MaterializationCast);
1323 typeConverter.addTargetMaterialization(ui64ToI64MaterializationCast);
1324 typeConverter.addTargetMaterialization(
1325 vectorToSingleElementMaterializationCast);
1326 typeConverter.addTargetMaterialization(vectorToVectorMaterializationCast);
// Legality: dialects the lowered IR may contain are legal; any remaining
// XeGPU op is illegal and must be converted.
1328 target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
1329 vector::VectorDialect, arith::ArithDialect,
1330 memref::MemRefDialect, gpu::GPUDialect,
1331 index::IndexDialect>();
1332 target.addIllegalDialect<xegpu::XeGPUDialect>();
// Run the conversion; failure to legalize an illegal op fails the pass.
1338 if (
failed(applyPartialConversion(getOperation(),
target,
1339 std::move(patterns))))
1340 signalPassFailure();
// Pattern registration for the XeGPU -> XeVM lowering (fragment: the
// enclosing populateXeGPUToXeVMConversionPatterns signature and the
// trailing arguments of each add<> call are elided in this extraction).
// Groups registered: nd-descriptor creation plus nd load/store/prefetch;
// scattered load/store, prefetch, and atomic RMW; matrix load/store plus
// memory-descriptor creation; fence and dpas lowering.
1350 patterns.
add<CreateNdDescToXeVMPattern,
1351 LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1352 LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1353 LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1355 patterns.
add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1356 LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1357 LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1359 patterns.
add<LoadStoreMatrixToXeVMPattern<xegpu::LoadMatrixOp>,
1360 LoadStoreMatrixToXeVMPattern<xegpu::StoreMatrixOp>,
1361 CreateMemDescOpPattern>(typeConverter, patterns.
getContext());
1362 patterns.
add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
Attributes are known-constant values of operations.
IntegerAttr getIndexAttr(int64_t value)
IntegerAttr getIntegerAttr(Type type, int64_t value)
IntegerType getIntegerType(unsigned width)
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrFloat() const
Return true if this is an integer (of any signedness) or a float type.
unsigned getIntOrFloatBitWidth() const
Return the bit width of an integer or a float type, assert failure on other types.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static ConstantIntOp create(OpBuilder &builder, Location location, int64_t value, unsigned width)
void populateSCFStructuralTypeConversionsAndLegality(const TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target, PatternBenefit benefit=1)
Populates patterns for SCF structural type conversions and sets up the provided ConversionTarget with...
@ SubgroupMatrixMultiplyAcc
const uArch * getUArch(llvm::StringRef archName)
std::optional< std::string > getChipStr(Operation *op)
Retrieves the chip string from the XeVM target attribute of the parent GPU module operation.
Include the generated interface declarations.
Value getValueOrCreateConstantIntOp(OpBuilder &b, Location loc, OpFoldResult ofr)
Converts an OpFoldResult to a Value.
Value getValueOrCreateCastToIndexLike(OpBuilder &b, Location loc, Type targetType, Value value)
Create a cast from an index-like value (index or integer) to another index-like value.
void populateXeGPUToXeVMConversionPatterns(const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns)