18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
38struct LLVMFuncAttributeOptions {
39 bool isConvergent =
false;
40 bool isNoUnwind =
false;
41 bool isWillReturn =
false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false,
true,
false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false,
true,
true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true,
true,
true, {}};
51std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
57 .Case([](Float16Type) -> std::string {
return "Dh"; })
58 .Case([](Float32Type) -> std::string {
return "f"; })
59 .Case([](Float64Type) -> std::string {
return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
63 return isUnsigned ?
"h" :
"c";
65 return isUnsigned ?
"t" :
"s";
67 return isUnsigned ?
"j" :
"i";
69 return isUnsigned ?
"m" :
"l";
71 llvm_unreachable(
"unhandled integer type");
74 .DefaultUnreachable(
"unhandled type for mangling");
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os <<
"_Z" << baseName.size() << baseName;
85 for (
auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
90 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
102static int32_t getL1CacheControl(LoadCacheControl cc) {
105 case LoadCacheControl::L1C_L2UC_L3UC:
106 case LoadCacheControl::L1C_L2UC_L3C:
107 case LoadCacheControl::L1C_L2C_L3UC:
108 case LoadCacheControl::L1C_L2C_L3C:
111 case LoadCacheControl::L1S_L2UC_L3UC:
112 case LoadCacheControl::L1S_L2UC_L3C:
113 case LoadCacheControl::L1S_L2C_L3UC:
114 case LoadCacheControl::L1S_L2C_L3C:
117 case LoadCacheControl::INVALIDATE_READ:
126static int32_t getL1CacheControl(StoreCacheControl cc) {
129 case StoreCacheControl::L1WT_L2UC_L3UC:
130 case StoreCacheControl::L1WT_L2UC_L3WB:
131 case StoreCacheControl::L1WT_L2WB_L3UC:
132 case StoreCacheControl::L1WT_L2WB_L3WB:
135 case StoreCacheControl::L1WB_L2UC_L3UC:
136 case StoreCacheControl::L1WB_L2WB_L3UC:
137 case StoreCacheControl::L1WB_L2UC_L3WB:
140 case StoreCacheControl::L1S_L2UC_L3UC:
141 case StoreCacheControl::L1S_L2UC_L3WB:
142 case StoreCacheControl::L1S_L2WB_L3UC:
143 case StoreCacheControl::L1S_L2WB_L3WB:
152static int32_t getL3CacheControl(LoadCacheControl cc) {
155 case LoadCacheControl::L1UC_L2UC_L3C:
156 case LoadCacheControl::L1UC_L2C_L3C:
157 case LoadCacheControl::L1C_L2UC_L3C:
158 case LoadCacheControl::L1C_L2C_L3C:
159 case LoadCacheControl::L1S_L2UC_L3C:
160 case LoadCacheControl::L1S_L2C_L3C:
163 case LoadCacheControl::INVALIDATE_READ:
172static int32_t getL3CacheControl(StoreCacheControl cc) {
175 case StoreCacheControl::L1UC_L2UC_L3WB:
176 case StoreCacheControl::L1UC_L2WB_L3WB:
177 case StoreCacheControl::L1WT_L2UC_L3WB:
178 case StoreCacheControl::L1WT_L2WB_L3WB:
179 case StoreCacheControl::L1S_L2UC_L3WB:
180 case StoreCacheControl::L1S_L2WB_L3WB:
181 case StoreCacheControl::L1WB_L2UC_L3WB:
190static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
191 return op.getCacheControl();
194static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
195 return op.getCacheControl();
198static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
199 return op.getCacheControl();
202static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
203 return op.getCacheControl();
206static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
207 return op.getCacheControl();
210static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
211 return op.getCacheControl();
214static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
215 if (op->hasAttr(
"cache_control")) {
216 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
219 return std::optional<LoadCacheControl>(attr.getValue());
224static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
225 if (op->hasAttr(
"cache_control")) {
226 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
229 return std::optional<StoreCacheControl>(attr.getValue());
234template <
typename OpType>
235int32_t getL1CacheControl(OpType op) {
236 return getL1CacheControl(*getCacheControl(op));
239template <
typename OpType>
240int32_t getL3CacheControl(OpType op) {
241 return getL3CacheControl(*getCacheControl(op));
244template <
typename OpType>
245static std::optional<ArrayAttr>
246getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
247 if (!getCacheControl(op))
249 constexpr int32_t decorationCacheControlArity{3};
250 constexpr int32_t loadCacheControlKey{6442};
251 constexpr int32_t storeCacheControlKey{6443};
252 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
253 std::is_same_v<OpType, BlockPrefetch2dOp> ||
254 std::is_same_v<OpType, LLVM::LoadOp> ||
255 std::is_same_v<OpType, BlockLoadOp> ||
256 std::is_same_v<OpType, PrefetchOp>;
257 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
259 controlKey, 0, getL1CacheControl<OpType>(op)};
261 controlKey, 1, getL3CacheControl<OpType>(op)};
262 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
263 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
266 return rewriter.getArrayAttr(combinedAttrs);
288 llvm::StringMap<bool> seen;
291 auto arr = dyn_cast<ArrayAttr>(a);
295 auto vals = arr.getValue();
296 assert(vals.size() == 3 &&
297 "Expected exactly 3 integer values (Token, CacheLevel, "
298 "ControlValue) in cache control attribute.");
300 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
301 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
302 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
304 if (!tokenAttr || !secondAttr || !thirdAttr)
310 llvm::formatv(
"{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
311 secondAttr.getValue().getZExtValue(),
312 thirdAttr.getValue().getZExtValue());
315 if (!seen.insert({entry, true}).second)
318 payloads.push_back(std::move(entry));
323static std::atomic<uint64_t> globalNameCounter{0};
328static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
330 StringRef value, StringRef nameHint) {
332 std::string strWithNull = value.str();
333 strWithNull.push_back(
'\0');
334 StringRef strRef(strWithNull.data(), strWithNull.size());
336 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
340 if (
auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
341 if (!existingGlobal.getSection() ||
342 *existingGlobal.getSection() !=
"llvm.metadata")
345 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
346 if (strAttr.getValue() == strRef) {
347 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
348 existingGlobal.getSymName());
355 auto i8Type = rewriter.getI8Type();
356 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
357 std::string globalName =
358 llvm::formatv(
"{0}.{1}", nameHint,
359 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
364 rewriter.setInsertionPointToStart(&moduleOp->
getRegion(0).
front());
367 LLVM::GlobalOp::create(rewriter, loc, arrayType,
368 true, LLVM::Linkage::Private,
369 globalName, rewriter.getStringAttr(strRef));
370 globalOp.setSection(StringRef(
"llvm.metadata"));
371 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
372 globalOp.setAlignment(1);
373 globalOp.setAddrSpace(1);
377 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
400static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
405 buildCacheControlPayloads(cacheControls.getValue());
406 if (payloads.empty())
409 auto ptrType = cast<LLVM::LLVMPointerType>(
ptr.getType());
410 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
411 auto i32Ty = rewriter.getI32Type();
415 createMetadataStringPtr(rewriter, moduleOp, loc,
"",
".str.file");
416 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
417 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
422 for (
const std::string &payload : payloads) {
423 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
424 ".str.cachecontrol");
425 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
426 annStr, fileStr, lineVal, nullAS1);
427 curPtr = annOp.getResult();
449template <
typename OpType>
451applyCacheControlAnnotation(ConversionPatternRewriter &rewriter,
Location loc,
453 Operation *moduleOp,
unsigned ptrIdx = 0) {
454 std::optional<ArrayAttr> optCacheControls =
455 getCacheControlMetadata(rewriter, op);
456 if (!optCacheControls)
459 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
460 *optCacheControls, moduleOp);
461 args[ptrIdx] = annotatedPtr;
468static LLVM::CallOp createDeviceFunctionCall(
469 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
472 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
474 assert(moduleOp &&
"Expecting module");
479 assert(!
failed(funcOpRes));
480 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
481 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
482 funcOp.setConvergent(funcAttributeOptions.isConvergent);
483 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
484 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
486 if (funcAttributeOptions.memEffectsAttr)
487 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
489 for (
auto [idx, attrName] : paramAttrs)
490 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
492 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
493 callOp->setAttrs(funcOp->getAttrs());
498class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
499 using OpConversionPattern::OpConversionPattern;
501 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
502 ConversionPatternRewriter &rewriter)
const override {
504 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
506 auto precisionA = op.getTypes().getA();
507 auto precisionB = op.getTypes().getB();
508 auto precisionC = op.getTypes().getC();
509 auto precisionD = op.getTypes().getD();
510 if (precisionC != precisionD) {
511 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
513 if (precisionC != xevm::ElemType::S32 &&
514 precisionC != xevm::ElemType::F32 &&
515 precisionC != xevm::ElemType::F16 &&
516 precisionC != xevm::ElemType::BF16) {
517 return rewriter.notifyMatchFailure(
518 op,
"type of C and D must be S32, F32, F16 or BF16");
520 if (precisionA == xevm::ElemType::S32 ||
521 precisionA == xevm::ElemType::F32) {
522 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
524 if (precisionB == xevm::ElemType::S32 ||
525 precisionB == xevm::ElemType::F32) {
526 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
528 constexpr uint32_t bitWidthPackedA{16};
529 constexpr uint32_t bitWidthPackedB{32};
530 auto loc = op.getLoc();
532 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
533 VectorType origTy = cast<VectorType>(val.
getType());
534 const uint32_t vecBitSize =
535 origTy.getNumElements() *
536 origTy.getElementType().getIntOrFloatBitWidth();
537 VectorType newTy = VectorType::get(
538 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
540 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
545 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
546 ? cast<Type>(rewriter.getF32Type())
547 : rewriter.getIntegerType(bitWidthPackedA);
548 a = castIfNeeded(a, packedAType);
551 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
552 ? cast<Type>(rewriter.getF32Type())
553 : rewriter.getIntegerType(bitWidthPackedB);
554 b = castIfNeeded(
b, packedBType);
557 VectorType cOrigTy = cast<VectorType>(c.
getType());
558 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
559 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
562 cOrigTy.getElementType().isBF16()
563 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
565 VectorType resTy = cTy;
567 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
569 constexpr int32_t systolicDepth{8};
571 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
572 stringifyElemType(op.getTypes().getA()).str(),
573 stringifyElemType(op.getTypes().getB()).str(),
575 getNumOperandsPerDword(op.getTypes().getA()))
577 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
578 fnName = mangle(fnName, argTypes);
579 SmallVector<Value> args{a,
b, c};
581 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
582 LLVM::ModRefInfo::NoModRef,
583 LLVM::ModRefInfo::NoModRef,
584 LLVM::ModRefInfo::NoModRef,
585 LLVM::ModRefInfo::NoModRef,
586 LLVM::ModRefInfo::NoModRef,
587 LLVM::ModRefInfo::NoModRef);
588 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
589 funcAttrs.memEffectsAttr = memAttr;
591 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
592 funcAttrs, op.getOperation())
595 if (resOrigTy != resTy)
596 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
598 rewriter.replaceOp(op,
result);
603 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
605 case xevm::ElemType::TF32:
607 case xevm::ElemType::BF16:
608 case xevm::ElemType::F16:
610 case xevm::ElemType::U8:
611 case xevm::ElemType::S8:
614 llvm_unreachable(
"unsupported xevm::ElemType");
619class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
620 using OpConversionPattern::OpConversionPattern;
622 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
623 ConversionPatternRewriter &rewriter)
const override {
624 auto loc = op.getLoc();
627 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
629 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
630 SmallVector<Value> args{op.getPtr(), one};
633 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
636 SmallVector<Type> argTypes;
637 for (
auto arg : args)
638 argTypes.push_back(arg.getType());
639 auto funcAttr = noUnwindAttrs;
640 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
641 LLVM::ModRefInfo::NoModRef,
642 LLVM::ModRefInfo::Ref,
643 LLVM::ModRefInfo::NoModRef,
644 LLVM::ModRefInfo::NoModRef,
645 LLVM::ModRefInfo::NoModRef,
646 LLVM::ModRefInfo::NoModRef);
647 funcAttr.memEffectsAttr = memAttr;
649 createDeviceFunctionCall(rewriter, fnName,
650 LLVM::LLVMVoidType::get(rewriter.getContext()),
651 argTypes, args, {}, funcAttr, op.getOperation());
652 rewriter.eraseOp(op);
657class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
658 using OpConversionPattern::OpConversionPattern;
660 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
661 ConversionPatternRewriter &rewriter)
const override {
662 auto loc = op.getLoc();
663 const std::string fnName{
"atomic_work_item_fence"};
664 int memScope, addrSpace;
665 switch (op.getAddrspace()) {
666 case xevm::AddrSpace::SHARED:
669 case xevm::AddrSpace::GLOBAL:
674 return rewriter.notifyMatchFailure(
675 op,
"Fence only supports global and shared address spaces.");
677 switch (op.getScope()) {
678 case xevm::MemScope::WORKGROUP:
681 case xevm::MemScope::DEVICE:
686 return rewriter.notifyMatchFailure(
687 op,
"Fence only supports workgroup and device memory scopes.");
689 Type i32Type = rewriter.getI32Type();
690 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
691 Value memScopeConst =
692 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
693 Value addrSpaceConst =
694 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
695 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
696 SmallVector<Type> argTypes{3, i32Type};
697 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
698 LLVM::LLVMVoidType::get(rewriter.getContext()),
699 argTypes, args, {}, noUnwindAttrs,
701 rewriter.eraseOp(op);
705template <
typename OpType>
706class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
707 using OpConversionPattern<OpType>::OpConversionPattern;
709 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
712 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
714 auto loc = op.getLoc();
715 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
717 bool packReg =
false;
718 bool transpose =
false;
719 if constexpr (isLoad) {
720 vecType = op.getRes().getType();
721 packReg = op.getPackRegister();
722 transpose = op.getTranspose();
723 }
else if constexpr (!isPrefetch) {
724 vecType = op.getStoredVal().getType();
727 auto i32Type = rewriter.getI32Type();
729 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
730 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
731 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
732 byteCoord = LLVM::InsertElementOp::create(
733 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
734 byteCoord = LLVM::InsertElementOp::create(
735 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
736 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
737 op.getBasePitch(), byteCoord};
740 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
743 SmallVector<Type> retTypes;
745 std::string funcName{
"intel_sub_group_2d_block_"};
746 std::string bitWidthId;
747 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
748 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
749 if constexpr (isPrefetch) {
750 funcName +=
"prefetch";
751 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
752 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
753 LLVM::ModRefInfo::NoModRef,
754 LLVM::ModRefInfo::Ref,
755 LLVM::ModRefInfo::NoModRef,
756 LLVM::ModRefInfo::NoModRef,
757 LLVM::ModRefInfo::NoModRef,
758 LLVM::ModRefInfo::NoModRef);
759 funcAttr = noUnwindAttrs;
760 funcAttr.memEffectsAttr = memAttr;
762 auto vecElemType = vecType.getElementType();
763 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
764 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
765 vecType.getNumElements());
766 auto dstOrSrcPtr = LLVM::AllocaOp::create(
767 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
768 vecElemType, numElems);
769 args.push_back(dstOrSrcPtr);
770 if constexpr (isLoad) {
772 bitWidthId = getTypeMangling(vecElemType,
true);
774 funcName +=
"_transform";
776 funcName +=
"_transpose";
777 spvLoadDstPtr = dstOrSrcPtr;
778 retTypes.push_back(vecType);
780 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
781 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
782 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
783 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
787 bitWidthId = (vecElemBitWidth == 32)
789 : ((vecElemBitWidth == 16) ?
"t" :
"h");
790 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
792 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
793 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
794 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
795 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
801 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
802 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
804 std::string prefetchCode(
"");
807 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
808 funcName, prefetchCode, bitWidthId)
810 SmallVector<Type> argTypes;
811 for (
auto arg : args) {
812 argTypes.push_back(arg.getType());
814 createDeviceFunctionCall(
815 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
816 argTypes, args, paramAttrs, funcAttr, op.getOperation());
818 if constexpr (isLoad)
820 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
822 rewriter.eraseOp(op);
827template <
typename OpType>
828class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
829 using OpConversionPattern<OpType>::OpConversionPattern;
831 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
832 ConversionPatternRewriter &rewriter)
const override {
833 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
834 auto loc = op.getLoc();
835 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
840 std::string funcName{
"intel_sub_group_block_"};
843 if constexpr (isStore) {
844 funcName +=
"write_u";
845 valOrResTy = op.getVal().getType();
847 funcName +=
"read_u";
848 valOrResTy = op.getType();
851 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
852 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
853 funcName += getTypeMangling(elemType);
855 funcName += std::to_string(vecTy.getNumElements());
856 SmallVector<Type, 2> argTypes{};
860 SmallVector<bool, 2> isUnsigned{};
864 SmallVector<Value, 2> args{};
865 args.push_back(op.getPtr());
866 argTypes.push_back(op.getPtr().getType());
867 isUnsigned.push_back(
true);
870 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
874 argTypes[0] = args[0].getType();
877 if constexpr (isStore) {
878 args.push_back(op.getVal());
879 argTypes.push_back(op.getVal().getType());
880 isUnsigned.push_back(
true);
881 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
883 retType = valOrResTy;
885 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
887 std::to_string(op.getPtr().getType().getAddressSpace());
888 funcName += getTypeMangling(elemType,
true);
889 if constexpr (isStore)
890 funcName += getTypeMangling(valOrResTy,
true);
891 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
894 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
895 {}, funcAttr, op.getOperation());
897 if constexpr (isStore)
898 rewriter.eraseOp(op);
900 rewriter.replaceOp(op, call->getResult(0));
905template <
typename OpType>
906class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
907 using OpConversionPattern<OpType>::OpConversionPattern;
909 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
910 ConversionPatternRewriter &rewriter)
const override {
911 if (!op->hasAttr(
"cache_control"))
914 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
915 std::optional<ArrayAttr> optCacheControls =
916 getCacheControlMetadata(rewriter, op);
917 if (!optCacheControls) {
918 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr(
"cache_control"); });
923 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
924 unsigned ptrIdx = isStore ? 1 : 0;
925 Value ptr = op->getOperand(ptrIdx);
928 Value annotatedPtr = annotatePtrWithCacheControl(
929 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
932 rewriter.modifyOpInPlace(op, [&]() {
933 op->setOperand(ptrIdx, annotatedPtr);
934 op->removeAttr(
"cache_control");
967static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
968 return {
"get_local_id", 0};
970static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
971 return {
"get_local_id", 1};
973static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
974 return {
"get_local_id", 2};
976static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
977 return {
"get_local_size", 0};
979static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
980 return {
"get_local_size", 1};
982static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
983 return {
"get_local_size", 2};
985static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
986 return {
"get_group_id", 0};
988static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
989 return {
"get_group_id", 1};
991static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
992 return {
"get_group_id", 2};
994static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
995 return {
"get_num_groups", 0};
997static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
998 return {
"get_num_groups", 1};
1000static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1001 return {
"get_num_groups", 2};
1005template <
typename OpType>
1006class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
1007 using OpConversionPattern<OpType>::OpConversionPattern;
1009 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1010 ConversionPatternRewriter &rewriter)
const override {
1011 Location loc = op->getLoc();
1012 auto [baseName, dim] = getConfig(op);
1013 Type dimTy = rewriter.getI32Type();
1014 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1015 static_cast<int64_t
>(dim));
1016 std::string func = mangle(baseName, {dimTy}, {
true});
1017 Type resTy = op.getType();
1019 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1020 noUnwindWillReturnAttrs, op.getOperation());
1021 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1022 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1028 call.setMemoryEffectsAttr(memAttr);
1029 rewriter.replaceOp(op, call);
1046static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
1047static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
1048static StringRef getConfig(xevm::SubgroupSizeOp) {
1049 return "get_sub_group_size";
1051template <
typename OpType>
1052class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
1053 using OpConversionPattern<OpType>::OpConversionPattern;
1055 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1056 ConversionPatternRewriter &rewriter)
const override {
1057 std::string func = mangle(getConfig(op).str(), {});
1058 Type resTy = op.getType();
1060 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1061 noUnwindWillReturnAttrs, op.getOperation());
1062 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1063 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1069 call.setMemoryEffectsAttr(memAttr);
1070 rewriter.replaceOp(op, call);
1075class AllocaToGlobalPattern :
public OpConversionPattern<LLVM::AllocaOp> {
1076 using OpConversionPattern::OpConversionPattern;
1078 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1079 ConversionPatternRewriter &rewriter)
const override {
1080 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1081 auto addrSpace = ptrType.getAddressSpace();
1084 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1088 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1089 moduleBody = mod.getBody();
1090 }
else if (gpu::GPUModuleOp gpuMod =
1091 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1092 moduleBody = gpuMod.getBody();
1096 auto val = op.getArraySize();
1100 auto loc = op.getLoc();
1101 auto globalType = LLVM::LLVMArrayType::get(
1102 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1103 LLVM::GlobalOp globalVar;
1105 OpBuilder::InsertionGuard guard(rewriter);
1106 rewriter.setInsertionPointToStart(moduleBody);
1107 auto alignment = op.getAlignment();
1108 globalVar = LLVM::GlobalOp::create(
1109 rewriter, loc, globalType,
false,
1110 LLVM::Linkage::Internal,
1111 std::string(
"__global_alloca_") +
1112 std::to_string(getNextGlobalIdx()),
1114 alignment ? *alignment : 0, addrSpace);
1116 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
1121 static unsigned getNextGlobalIdx() {
1122 static unsigned globalIdx = 0;
1133static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1134 if (op.getV1() != op.getV2())
1136 auto maskAttr = op.getMask();
1138 int64_t sourceSize = op.getV1().getType().getNumElements();
1139 if (maskSize > sourceSize)
1141 int64_t firstIndex = maskAttr[0];
1142 for (
int64_t i = 1; i < maskSize; ++i) {
1144 if (
index != firstIndex + i)
1146 if (
index >= sourceSize)
1160class HandleVectorExtractPattern
1162 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
1164 void initialize() { setHasBoundedRewriteRecursion(); }
1166 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1167 PatternRewriter &rewriter)
const override {
1169 if (!isExtractingContiguousSlice(op))
1172 auto mask = op.getMask();
1173 auto loc = op.getLoc();
1174 auto ty = op.getType();
1176 auto src = op.getV1();
1178 if (
auto srcOp = src.getDefiningOp()) {
1179 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1180 Value srcInput = srcOp->getOperand(0);
1182 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
1183 auto newShuffleVecTy =
1184 VectorType::get(mask.size(), srcVecTy.getElementType());
1185 auto newShuffle = LLVM::ShuffleVectorOp::create(
1186 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1189 if (isa<LLVM::FPExtOp>(srcOp)) {
1190 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1192 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1195 }
else if (isa<LLVM::BitcastOp>(srcOp)) {
1196 Value srcInput = srcOp->getOperand(0);
1198 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
1199 auto srcInputSize = srcInputVecTy.getNumElements();
1200 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1201 auto srcResSize = srcResVecTy.getNumElements();
1202 auto maskSize =
static_cast<int32_t
>(mask.size());
1203 if (srcInputSize > srcResSize) {
1206 if (srcResSize % srcInputSize != 0) {
1209 auto maskScale = srcResSize / srcInputSize;
1210 if (maskScale != 1) {
1211 if (mask[0] % maskScale != 0) {
1215 SmallVector<int32_t> newMask;
1216 int32_t newMaskSize = maskSize / maskScale;
1217 int32_t maskStart = mask[0] / maskScale;
1218 for (int32_t i = 0; i < newMaskSize; ++i) {
1219 newMask.push_back(maskStart + i);
1223 auto newShuffleVecTy =
1224 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1225 auto newShuffle = LLVM::ShuffleVectorOp::create(
1226 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1229 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1231 }
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1236 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1237 if (!isExtractingContiguousSlice(srcShuffle))
1239 auto srcMask = srcShuffle.getMask();
1240 SmallVector<int32_t> combinedMask;
1241 for (
auto index : mask) {
1242 combinedMask.push_back(srcMask[index]);
1244 auto newShuffle = LLVM::ShuffleVectorOp::create(
1245 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1248 }
else if (isa<LLVM::LoadOp>(srcOp)) {
1250 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1251 auto loadPtr = loadOp.getAddr();
1252 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1253 auto elemTy = loadTy.getElementType();
1254 auto firstIndex = mask[0];
1255 auto newVecTy = VectorType::get(mask.size(), elemTy);
1258 auto newPtr = LLVM::GEPOp::create(
1260 LLVM::LLVMPointerType::get(rewriter.
getContext(),
1261 loadPtr.getType().getAddressSpace()),
1262 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1263 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1266 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1281struct ConvertXeVMToLLVMPass
1282 :
public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1285 void getDependentDialects(DialectRegistry ®istry)
const override {
1286 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
1289 void runOnOperation()
override {
1293 if (
failed(applyPartialConversion(getOperation(),
target,
1294 std::move(patterns))))
1295 signalPassFailure();
1299 RewritePatternSet vectorPatterns(&
getContext());
1300 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
1301 GreedyRewriteConfig config{};
1306 config.enableFolding(
false);
1323 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](
Operation *op) {
1327 if (isa<LLVM::AllocaOp>(op)) {
1328 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1329 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1330 auto addrSpace = pTy.getAddressSpace();
1331 return addrSpace != 3;
1334 return !op->hasAttr(
"cache_control");
1336 target.addIllegalDialect<XeVMDialect>();
1337 patterns.
add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1338 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1339 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1340 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1341 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1342 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1343 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1344 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1345 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1346 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1347 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1348 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1349 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1350 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1351 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1352 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1353 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1354 LaunchConfigOpToOCLPattern<GridDimXOp>,
1355 LaunchConfigOpToOCLPattern<GridDimYOp>,
1356 LaunchConfigOpToOCLPattern<GridDimZOp>,
1357 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1358 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1359 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1360 AllocaToGlobalPattern>(patterns.
getContext());
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...