18 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/ADT/TypeSwitch.h"
26 #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27 #include "mlir/Conversion/Passes.h.inc"
// Attribute flags that createDeviceFunctionCall applies to a generated
// device-function declaration; the declaration's attributes are then copied
// onto the emitted call.
35 struct LLVMFuncAttributeOptions {
36 bool isConvergent =
false;
37 bool isNoUnwind =
false;
38 bool isWillReturn =
false;
// Memory effects for the callee. Left null unless a pattern fills it in; only
// a non-null value is attached to the declaration.
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
// NOTE(review): original line 40 (presumably the struct's closing `};`) is
// missing from this listing.
// Preset: {isConvergent=false, isNoUnwind=true, isWillReturn=false}.
41 static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false,
true,
false, {}};
// Preset: {isConvergent=false, isNoUnwind=true, isWillReturn=true}.
43 static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false,
true,
true, {}};
// Preset: {isConvergent=true, isNoUnwind=true, isWillReturn=true}.
45 static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true,
true,
true, {}};
// Returns the mangled-name fragment for a scalar or vector type, using the
// Itanium C++ ABI builtin-type codes:
//   - vectors: `Dv<N>_<element>`
//   - f16 `Dh`, f32 `f`, f64 `d`
//   - integers: c/s/i/l (signed) or h/t/j/m (when `isUnsigned` is true)
// Any other type hits DefaultUnreachable.
// NOTE(review): the TypeSwitch head (line 49) and the integer `case` width
// labels are elided from this listing. The widths are presumably 8/16/32/64;
// verify.
48 std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
// Vectors: element count, then the element mangling; signedness propagates
// to the element type.
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
// Floating-point scalars.
54 .Case([](Float16Type) -> std::string {
return "Dh"; })
55 .Case([](Float32Type) -> std::string {
return "f"; })
56 .Case([](Float64Type) -> std::string {
return "d"; })
// Integers: one unsigned/signed code pair per supported width.
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
60 return isUnsigned ?
"h" :
"c";
62 return isUnsigned ?
"t" :
"s";
64 return isUnsigned ?
"j" :
"i";
66 return isUnsigned ?
"m" :
"l";
// Any other integer width is unsupported (the default label is elided).
68 llvm_unreachable(
"unhandled integer type");
71 .DefaultUnreachable(
"unhandled type for mangling");
// Body fragment of the mangling helper. Its signature (original lines 73-75)
// is elided here.
// From the visible code it takes:
//   - `baseName`
//   - a list of parameter `types`
//   - an optional per-parameter `isUnsigned` list (empty, or one entry per type)
// It emits `_Z<len(baseName)><baseName>`, then one mangling per parameter.
// The loop header over `types` with index `idx` is presumably on elided
// line 82.
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
79 llvm::raw_string_ostream os(s);
// Records types that are not int/float so later repeats can be emitted as
// back-references. The emitted substitution text (original lines 85-90) is
// elided here.
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os <<
"_Z" << baseName.size() << baseName;
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
87 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
// Only types that are not int/float become substitution candidates.
// The recorded index is the map size before insertion. This relies on C++17
// sequencing: the right operand of `=` is evaluated before the left, so
// size() is read before operator[] inserts the new key.
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
// When no signedness list is given, every parameter is mangled as signed.
93 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
// Maps a LoadCacheControl value to its L1 cache-control code.
// Cases are grouped by the L1 component of the enumerator name (L1UC, L1C,
// L1S), plus INVALIDATE_READ. The switch head and the returned constants are
// elided in this listing.
99 static int32_t getL1CacheControl(LoadCacheControl cc) {
// L1UC_* group.
102 case LoadCacheControl::L1UC_L2UC_L3UC:
103 case LoadCacheControl::L1UC_L2UC_L3C:
104 case LoadCacheControl::L1UC_L2C_L3UC:
105 case LoadCacheControl::L1UC_L2C_L3C:
// L1C_* group.
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
// L1S_* group.
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
120 case LoadCacheControl::INVALIDATE_READ:
// Maps a StoreCacheControl value to its L1 cache-control code.
// Cases are grouped by the L1 component of the enumerator name (L1UC, L1WT,
// L1S, L1WB). The switch head and the returned constants are elided in this
// listing.
127 static int32_t getL1CacheControl(StoreCacheControl cc) {
// L1UC_* group.
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
// L1WT_* group.
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
// L1S_* group.
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
// L1WB_* group (three enumerators are listed here).
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
// Maps a LoadCacheControl value to its L3 cache-control code.
// Cases are grouped by the L3 component of the enumerator name (L3UC, L3C),
// plus INVALIDATE_READ. The switch head and the returned constants are
// elided in this listing.
157 static int32_t getL3CacheControl(LoadCacheControl cc) {
// *_L3UC group.
160 case LoadCacheControl::L1UC_L2UC_L3UC:
161 case LoadCacheControl::L1UC_L2C_L3UC:
162 case LoadCacheControl::L1C_L2UC_L3UC:
163 case LoadCacheControl::L1C_L2C_L3UC:
164 case LoadCacheControl::L1S_L2UC_L3UC:
165 case LoadCacheControl::L1S_L2C_L3UC:
// *_L3C group.
168 case LoadCacheControl::L1UC_L2UC_L3C:
169 case LoadCacheControl::L1UC_L2C_L3C:
170 case LoadCacheControl::L1C_L2UC_L3C:
171 case LoadCacheControl::L1C_L2C_L3C:
172 case LoadCacheControl::L1S_L2UC_L3C:
173 case LoadCacheControl::L1S_L2C_L3C:
176 case LoadCacheControl::INVALIDATE_READ:
// Maps a StoreCacheControl value to its L3 cache-control code.
// Cases are grouped by the L3 component of the enumerator name (L3UC, L3WB).
// The switch head and the returned constants are elided in this listing.
183 static int32_t getL3CacheControl(StoreCacheControl cc) {
// *_L3UC group.
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
// *_L3WB group.
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
// Overloads returning the optional cache-control setting carried by each
// supported op. The xevm ops expose it directly through getCacheControl().
209 static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210 return op.getCacheControl();
213 static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214 return op.getCacheControl();
217 static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218 return op.getCacheControl();
221 static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222 return op.getCacheControl();
225 static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226 return op.getCacheControl();
229 static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230 return op.getCacheControl();
// Plain LLVM loads/stores carry the setting in a discardable "cache_control"
// attribute, typed xevm::LoadCacheControlAttr / xevm::StoreCacheControlAttr.
// NOTE(review): three parts are elided:
//   - the lines between the attribute fetch and the return (236-237, 246-247),
//     presumably a null check for a wrongly-typed attribute;
//   - the fallback return for ops without the attribute (presumably
//     std::nullopt);
//   - the closing braces.
// Verify before relying on these paths.
233 static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234 if (op->hasAttr(
"cache_control")) {
235 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
238 return std::optional<LoadCacheControl>(attr.getValue());
243 static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244 if (op->hasAttr(
"cache_control")) {
245 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
248 return std::optional<StoreCacheControl>(attr.getValue());
// Generic helpers mapping an op's cache-control setting to its L1 / L3 code.
// Precondition: getCacheControl(op) must hold a value, because the optional is
// dereferenced without a check. getCacheControlMetadata tests
// `!getCacheControl(op)` before calling these.
253 template <
typename OpType>
254 int32_t getL1CacheControl(OpType op) {
255 return getL1CacheControl(*getCacheControl(op));
258 template <
typename OpType>
259 int32_t getL3CacheControl(OpType op) {
260 return getL3CacheControl(*getCacheControl(op));
// Builds cache-control metadata for `op` as two i32 decoration tuples:
//   {controlKey, 0, <L1 code>, 0} and {controlKey, 1, <L3 code>, 0}
// These are presumably wrapped into the returned ArrayAttr; the construction
// lines are elided.
// The function-name line (265) is elided. Call sites use it as
// getCacheControlMetadata(rewriter, op).
// When the op carries no cache control, it returns early (the return
// statement, presumably std::nullopt, is elided).
263 template <
typename OpType>
264 static std::optional<ArrayAttr>
266 if (!getCacheControl(op))
268 constexpr int32_t decorationCacheControlArity{4};
// NOTE(review): 6442 / 6443 appear to be the SPIR-V CacheControlLoadINTEL /
// CacheControlStoreINTEL decoration IDs; confirm against SPV_INTEL_cache_controls.
269 constexpr int32_t loadCacheControlKey{6442};
270 constexpr int32_t storeCacheControlKey{6443};
// Load-like ops (block loads, prefetches, LLVM loads) use the load key;
// all other ops use the store key.
271 constexpr
bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272 std::is_same_v<OpType, BlockPrefetch2dOp> ||
273 std::is_same_v<OpType, LLVM::LoadOp> ||
274 std::is_same_v<OpType, BlockLoadOp> ||
275 std::is_same_v<OpType, PrefetchOp>;
276 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278 controlKey, 0, getL1CacheControl<OpType>(op), 0};
280 controlKey, 1, getL3CacheControl<OpType>(op), 0};
// Declares (or reuses) a device function in the enclosing module and emits a
// call to it.
// The declaration gets:
//   - the SPIR_FUNC calling convention;
//   - the convergent / nounwind / willreturn flags from funcAttributeOptions;
//   - its memory-effects attribute, when one is set;
//   - a unit attribute for each (argument index, attribute name) pair in
//     `paramAttrs`.
// All of the declaration's attributes are then copied onto the call.
// NOTE(review): the leading parameters and the lookup (original lines 289-298)
// are elided. The hover text in this listing suggests LLVM::lookupOrCreateFn
// is used.
288 static LLVM::CallOp createDeviceFunctionCall(
292 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
294 assert(moduleOp &&
"Expecting module");
// NOTE(review): failure is checked only by assert. In NDEBUG builds a failed
// lookup reaches funcOpRes.value() unchecked.
299 assert(!
failed(funcOpRes));
300 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302 funcOp.setConvergent(funcAttributeOptions.isConvergent);
303 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
// Memory effects are attached only when the caller supplied them.
306 if (funcAttributeOptions.memEffectsAttr)
307 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
309 for (
auto [idx, attrName] : paramAttrs)
310 funcOp.setArgAttr(idx, attrName, rewriter.
getUnitAttr());
// Mirror the declaration's attributes on the call site.
312 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
313 callOp->setAttrs(funcOp->getAttrs());
// Lowers xevm::MMAOp to an OpenCL sub-group matrix-multiply-add builtin
// (fragment; the enclosing pattern declaration is elided).
// It validates the element types and bitcasts A/B/C to packed vector types.
// It then calls intel_sub_group_<A>_<B>_matrix_mad_k<K> and bitcasts the
// result back to the op's result type when the two differ.
321 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
326 auto precisionA = op.getTypes().getA();
327 auto precisionB = op.getTypes().getB();
328 auto precisionC = op.getTypes().getC();
329 auto precisionD = op.getTypes().getD();
// C and D must share a precision (the failure body is elided).
330 if (precisionC != precisionD) {
333 if (precisionC != xevm::ElemType::S32 &&
334 precisionC != xevm::ElemType::F32 &&
335 precisionC != xevm::ElemType::F16 &&
336 precisionC != xevm::ElemType::BF16) {
338 op,
"type of C and D must be S32, F32, F16 or BF16");
// A and B must not be S32 or F32 (the failure bodies are elided).
340 if (precisionA == xevm::ElemType::S32 ||
341 precisionA == xevm::ElemType::F32) {
344 if (precisionB == xevm::ElemType::S32 ||
345 precisionB == xevm::ElemType::F32) {
348 constexpr uint32_t bitWidthPackedA{16};
349 constexpr uint32_t bitWidthPackedB{32};
350 auto loc = op.getLoc();
// castIfNeeded (its header is elided): reinterpret `val` as a vector of
// `packedType` with the same total bit count.
353 VectorType origTy = cast<VectorType>(val.
getType());
354 const uint32_t vecBitSize =
355 origTy.getNumElements() *
356 origTy.getElementType().getIntOrFloatBitWidth();
358 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
360 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
// TF32 operands select a different packed element type (the alternatives
// are elided).
365 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
368 a = castIfNeeded(a, packedAType);
371 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
374 b = castIfNeeded(b, packedBType);
// The accumulator and result share a type. The call's type cTy depends on
// whether the accumulator element is bf16; the chosen types are elided.
377 VectorType cOrigTy = cast<VectorType>(c.
getType());
378 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
382 cOrigTy.getElementType().isBF16()
385 VectorType resTy = cTy;
387 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
// The k suffix is derived from systolicDepth and getNumOperandsPerDword(A).
// Presumably it is their product; the combining line 394 is elided.
389 constexpr int32_t systolicDepth{8};
391 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392 stringifyElemType(op.getTypes().getA()).str(),
393 stringifyElemType(op.getTypes().getB()).str(),
395 getNumOperandsPerDword(op.getTypes().getA()))
398 fnName = mangle(fnName, argTypes);
// The builtin is declared with no memory effects (all three ModRefInfo
// parameters are NoModRef) and as convergent, nounwind and willreturn.
401 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
402 LLVM::ModRefInfo::NoModRef,
403 LLVM::ModRefInfo::NoModRef,
404 LLVM::ModRefInfo::NoModRef);
405 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
406 funcAttrs.memEffectsAttr = memAttr;
408 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
409 funcAttrs, op.getOperation())
// Undo the accumulator repacking on the result.
412 if (resOrigTy != resTy)
413 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
// Returns how many elements of type `pTy` the MMA lowering counts per 32-bit
// dword. Only TF32, BF16, F16, U8 and S8 are handled; anything else is
// unreachable.
// NOTE(review): the switch head and return values are elided. They are
// presumably 1 (TF32), 2 (BF16/F16) and 4 (U8/S8); verify.
420 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
422 case xevm::ElemType::TF32:
424 case xevm::ElemType::BF16:
425 case xevm::ElemType::F16:
427 case xevm::ElemType::U8:
428 case xevm::ElemType::S8:
431 llvm_unreachable(
"unsupported xevm::ElemType");
// Lowers xevm::PrefetchOp to a call of the pre-mangled builtin
// `_Z8prefetchPU3AS1Kcm`. That name demangles to prefetch(const char
// addrspace(1)*, unsigned long).
// An i64 constant 1 is created for the argument list; the list assembly is
// elided.
// Any cache-control setting on the op is attached to the call.
439 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
441 auto loc = op.getLoc();
442 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
444 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI64Type(), 1);
447 for (
auto arg : args)
448 argTypes.push_back(arg.getType());
// nounwind only. Assuming MemoryEffectsAttr's parameter order is
// (other, argMem, inaccessibleMem), the callee only reads argument memory.
449 auto funcAttr = noUnwindAttrs;
450 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
451 LLVM::ModRefInfo::NoModRef,
452 LLVM::ModRefInfo::Ref,
453 LLVM::ModRefInfo::NoModRef);
454 funcAttr.memEffectsAttr = memAttr;
456 LLVM::CallOp call = createDeviceFunctionCall(
458 argTypes, args, {}, funcAttr, op.getOperation());
459 if (std::optional<ArrayAttr> optCacheControls =
460 getCacheControlMetadata(rewriter, op))
461 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
// Lowers xevm::MemfenceOp to a call of the mangled `atomic_work_item_fence`
// builtin. It passes i32 constants for:
//   - the address space,
//   - a memory order of 4 (named acqRel),
//   - the memory scope.
// The argument order is elided.
// Only SHARED/GLOBAL address spaces and WORKGROUP/DEVICE scopes are accepted;
// others fail the match. The integer values each case assigns are elided.
470 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
472 auto loc = op.getLoc();
473 const std::string fnName{
"atomic_work_item_fence"};
// Both are assigned in the accepted switch cases; other cases return a
// match failure.
474 int memScope, addrSpace;
475 switch (op.getAddrspace()) {
476 case xevm::AddrSpace::SHARED:
479 case xevm::AddrSpace::GLOBAL:
485 op,
"Fence only supports global and shared address spaces.");
487 switch (op.getScope()) {
488 case xevm::MemScope::WORKGROUP:
491 case xevm::MemScope::DEVICE:
497 op,
"Fence only supports workgroup and device memory scopes.");
500 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
501 Value memScopeConst =
502 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
503 Value addrSpaceConst =
504 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
507 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
509 argTypes, args, {}, noUnwindAttrs,
// Lowers 2D block load / store / prefetch ops to
// intel_sub_group_2d_block_* builtins.
// Base-pointer args: the elided initializer of `args` ends with the base
// pitch and a 2-element i32 (x, y) coordinate vector.
// Loads and stores also need a scratch buffer: they alloca one vector's worth
// of elements and append the pointer as the last argument.
//   - Stores write the stored value into that buffer before the call.
//   - Loads read the result back from it after the call.
// Cache-control metadata, if any, is attached to the call.
515 template <
typename OpType>
519 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
521 constexpr
bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
522 constexpr
bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
524 auto loc = op.getLoc();
526 bool packReg =
false;
527 bool transpose =
false;
// Loads take the vector type from the result plus the pack/transpose flags.
// Stores take it from the stored value. Prefetches have no vector type.
528 if constexpr (isLoad) {
529 vecType = op.getRes().getType();
530 packReg = op.getPackRegister();
531 transpose = op.getTranspose();
532 }
else if constexpr (!isPrefetch) {
533 vecType = op.getStoredVal().getType();
// Pack (x, y) into lanes 0 and 1 of a <2 x i32> vector.
539 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
540 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
541 byteCoord = LLVM::InsertElementOp::create(
542 rewriter, loc,
VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
543 byteCoord = LLVM::InsertElementOp::create(
544 rewriter, loc,
VectorType::get(2, i32Type), byteCoord, op.getY(), one);
546 op.getBasePitch(), byteCoord};
549 std::string funcName{
"intel_sub_group_2d_block_"};
550 std::string bitWidthId;
551 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
// Prefetch: a non-null base pointer and nounwind only. With the same
// (other, argMem, inaccessibleMem) parameter-order assumption as the 1D
// prefetch, only argument memory is read.
553 if constexpr (isPrefetch) {
554 funcName +=
"prefetch";
555 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
556 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
557 LLVM::ModRefInfo::NoModRef,
558 LLVM::ModRefInfo::Ref,
559 LLVM::ModRefInfo::NoModRef);
560 funcAttr = noUnwindAttrs;
561 funcAttr.memEffectsAttr = memAttr;
// Load/store: a scratch buffer sized to hold one value of vecType.
563 auto vecElemType = vecType.getElementType();
564 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
565 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
566 vecType.getNumElements());
567 auto dstOrSrcPtr = LLVM::AllocaOp::create(
569 vecElemType, numElems);
570 args.push_back(dstOrSrcPtr);
// Load: unsigned element mangling, plus optional _transform / _transpose
// suffixes (their conditions are elided; presumably packReg / transpose).
// The base pointer (arg 0) is read-only. The buffer (arg 5, presumably the
// pointer just appended) is write-only.
571 if constexpr (isLoad) {
573 bitWidthId = getTypeMangling(vecElemType,
true);
575 funcName +=
"_transform";
577 funcName +=
"_transpose";
578 spvLoadDstPtr = dstOrSrcPtr;
579 retTypes.push_back(vecType);
581 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
582 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
583 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
584 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
// Store: element code by bit width — 16 gives "t", and anything other than
// 32 or 16 gives "h" (the 32-bit code is elided).
// The value is spilled to the buffer first. The base pointer (arg 0) is
// write-only and the buffer (arg 5) is read-only.
588 bitWidthId = (vecElemBitWidth == 32)
590 : ((vecElemBitWidth == 16) ?
"t" :
"h");
591 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
593 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
594 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
595 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
596 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
// Name shape: <base>_<elemBits>b_<tileHeight>r<tileWidth>x<vBlocks>c.
602 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
603 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
// Mangled name: a void pointer in address space 1, three ints, and an int2
// vector. Then prefetchCode and bitWidthId follow (prefetchCode's assignment
// is elided).
605 std::string prefetchCode(
"");
608 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
609 funcName, prefetchCode, bitWidthId)
612 for (
auto arg : args) {
613 argTypes.push_back(arg.getType());
615 LLVM::CallOp call = createDeviceFunctionCall(
617 argTypes, args, paramAttrs, funcAttr, op.getOperation());
618 if (std::optional<ArrayAttr> optCacheControls =
619 getCacheControlMetadata(rewriter, op)) {
620 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
// Load result: read the filled buffer back as vecType.
622 if constexpr (isLoad)
624 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
// Lowers 1D sub-group block read/write ops to
// intel_sub_group_block_{read,write}_u<elem>[<N>] builtins.
// <elem> is the signed mangling letter of the element type, and <N> is the
// vector length (its condition is elided). The name is then mangled with the
// pointer's address space and unsigned element types.
// Cache-control metadata, if any, is attached to the call.
631 template <
typename OpType>
635 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
637 constexpr
bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
641 std::string funcName{
"intel_sub_group_block_"};
// The value type comes from the stored value (write) or the result (read).
644 if constexpr (isStore) {
645 funcName +=
"write_u";
646 valOrResTy = op.getVal().getType();
648 funcName +=
"read_u";
649 valOrResTy = op.getType();
652 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
653 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
654 funcName += getTypeMangling(elemType);
656 funcName += std::to_string(vecTy.getNumElements());
// Arguments: the pointer, plus the value for stores.
666 args.push_back(op.getPtr());
667 argTypes.push_back(op.getPtr().getType());
668 isUnsigned.push_back(
true);
670 if constexpr (isStore) {
671 args.push_back(op.getVal());
672 argTypes.push_back(op.getVal().getType());
673 isUnsigned.push_back(
true);
676 retType = valOrResTy;
// Mangle as _Z<len><name>, then a pointer-prefix fragment (line 679 elided),
// then the pointer's numeric address space and the unsigned element type.
// Stores also append the unsigned value type.
678 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
680 std::to_string(op.getPtr().getType().getAddressSpace());
681 funcName += getTypeMangling(elemType,
true);
682 if constexpr (isStore)
683 funcName += getTypeMangling(valOrResTy,
true);
684 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
687 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
688 {}, funcAttr, op.getOperation());
689 if (std::optional<ArrayAttr> optCacheControls =
690 getCacheControlMetadata(rewriter, op)) {
691 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
// Stores have no result (their handling is elided). Reads are replaced by the
// call's result.
693 if constexpr (isStore)
696 rewriter.
replaceOp(op, call->getResult(0));
// Rewrites an LLVM load/store that carries a "cache_control" attribute.
// It replaces that attribute with the XeVM cache-controls metadata attribute.
// Ops without the attribute are not matched (the early return is elided).
701 template <
typename OpType>
705 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
707 if (!op->hasAttr(
"cache_control"))
709 std::optional<ArrayAttr> optCacheControls =
710 getCacheControlMetadata(rewriter, op);
// NOTE(review): dereferenced unconditionally. If "cache_control" is present
// but not of the expected xevm attribute type, getCacheControl (whose null
// handling is elided) may yield no value, making this optional empty. Verify.
711 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
712 op->removeAttr(
"cache_control");
// Launch-configuration ops map to an OpenCL builtin name plus a dimension
// index (X=0, Y=1, Z=2):
//   workitem id     -> get_local_id
//   workgroup dim   -> get_local_size
//   workgroup id    -> get_group_id
//   grid dim        -> get_num_groups
744 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
745 return {
"get_local_id", 0};
747 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
748 return {
"get_local_id", 1};
750 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
751 return {
"get_local_id", 2};
753 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
754 return {
"get_local_size", 0};
756 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
757 return {
"get_local_size", 1};
759 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
760 return {
"get_local_size", 2};
762 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
763 return {
"get_group_id", 0};
765 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
766 return {
"get_group_id", 1};
768 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
769 return {
"get_group_id", 2};
771 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
772 return {
"get_num_groups", 0};
774 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
775 return {
"get_num_groups", 1};
777 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
778 return {
"get_num_groups", 2};
// Lowers a launch-configuration op to a call of its OpenCL builtin (see
// getConfig). The dimension index is passed as an unsigned `dimTy` constant;
// dimTy's definition is elided.
782 template <
typename OpType>
786 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
789 auto [baseName, dim] = getConfig(op);
791 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
792 static_cast<int64_t
>(dim));
793 std::string func = mangle(baseName, {dimTy}, {
true});
794 Type resTy = op.getType();
796 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
797 noUnwindWillReturnAttrs, op.getOperation());
// Memory effects are set on the call only (presumably all NoModRef; the
// arguments are elided). The declaration gets none, because
// noUnwindWillReturnAttrs carries no memEffectsAttr.
798 constexpr
auto noModRef = LLVM::ModRefInfo::NoModRef;
799 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
802 call.setMemoryEffectsAttr(memAttr);
// Sub-group query ops map to argument-less OpenCL builtins.
820 static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
821 static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
822 static StringRef getConfig(xevm::SubgroupSizeOp) {
823 return "get_sub_group_size";
// Lowers a sub-group query op to a call of its argument-less builtin. The name
// is mangled with an empty parameter list, and the result type is the op's
// type.
825 template <
typename OpType>
829 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
831 std::string func = mangle(getConfig(op).str(), {});
832 Type resTy = op.getType();
834 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
835 noUnwindWillReturnAttrs, op.getOperation());
// Memory effects are set on the call only (presumably all NoModRef; the
// arguments are elided).
836 constexpr
auto noModRef = LLVM::ModRefInfo::NoModRef;
837 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
840 call.setMemoryEffectsAttr(memAttr);
// Pass driver for the XeVM-to-LLVM conversion (fragment).
// Its dependent-dialect hook inserts the LLVM and XeVM dialects (the method
// header is elided).
// NOTE(review): the runOnOperation body is elided. The hover text in this
// listing suggests it builds a ConversionTarget and calls
// applyPartialConversion; verify.
850 struct ConvertXeVMToLLVMPass
851 :
public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
855 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
858 void runOnOperation()
override {
// Convert-to-LLVM dialect interface hooks, pattern population and interface
// registration (fragments).
// The dependent-dialect hook loads the LLVM dialect.
877 void loadDependentDialects(
MLIRContext *context)
const final {
878 context->loadDialect<LLVM::LLVMDialect>();
883 void populateConvertToLLVMConversionPatterns(
// Legality predicate: an op is legal only once it no longer carries a
// "cache_control" attribute. The registering call is elided; presumably it
// covers LLVM dialect ops — verify.
898 [](
Operation *op) {
return !op->hasAttr(
"cache_control"); });
// Registers one pattern per supported xevm op, plus LLVM load/store
// cache-control rewriting.
900 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
901 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
902 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
903 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
904 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
905 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
906 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
907 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917 LaunchConfigOpToOCLPattern<GridDimXOp>,
918 LaunchConfigOpToOCLPattern<GridDimYOp>,
919 LaunchConfigOpToOCLPattern<GridDimZOp>,
920 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
// Attaches the convert-to-LLVM interface to the dialect (the surrounding
// extension registration is elided).
928 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry)