18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
// Function attributes to attach to a generated OCL/SPIR device-function
// declaration. createDeviceFunctionCall forwards each flag to the matching
// LLVMFuncOp setter and applies memEffectsAttr only when it is non-null.
38struct LLVMFuncAttributeOptions {
// Marks the callee `convergent`.
39 bool isConvergent =
false;
// Marks the callee `nounwind`.
40 bool isNoUnwind =
false;
// Marks the callee `willreturn`.
41 bool isWillReturn =
false;
// Optional memory-effects attribute; the default (null) value leaves the
// callee's memory effects unset.
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
// Common attribute presets. Positional aggregate initialization follows the
// field order {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.
// Patterns copy a preset and may fill in memEffectsAttr afterwards.
// Preset: nounwind only.
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false,
true,
false, {}};
// Preset: nounwind + willreturn.
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false,
true,
true, {}};
// Preset: convergent + nounwind + willreturn (used by the subgroup MMA calls).
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true,
true,
true, {}};
// Returns the Itanium C++ ABI spelling of `ty` as used in hand-mangled
// OpenCL builtin names:
//   vector<N x T>   -> "Dv" N "_" <spelling of T>
//   f16 / f32 / f64 -> "Dh" / "f" / "d"
//   integers        -> one letter chosen by bit width and signedness:
//                      c/h, s/t, i/j, l/m (signed/unsigned)
// `isUnsigned` only changes integer spellings (including integer vector
// elements, since it is propagated to the recursive call).
// Any other type, or an unhandled integer width, is unreachable.
// NOTE(review): the switch case labels are not visible here; presumably the
// four integer width groups are 8/16/32/64 bits — confirm.
51std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
// Vectors: element count, separator, then the element spelling.
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
// Scalar floating-point types.
57 .Case([](Float16Type) -> std::string {
return "Dh"; })
58 .Case([](Float32Type) -> std::string {
return "f"; })
59 .Case([](Float64Type) -> std::string {
return "d"; })
// Integers: the letter pair is selected by width, then by signedness.
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
63 return isUnsigned ?
"h" :
"c";
65 return isUnsigned ?
"t" :
"s";
67 return isUnsigned ?
"j" :
"i";
69 return isUnsigned ?
"m" :
"l";
71 llvm_unreachable(
"unhandled integer type");
74 .DefaultUnreachable(
"unhandled type for mangling");
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os <<
"_Z" << baseName.size() << baseName;
85 for (
auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
90 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
// Maps a load cache-control setting to the integer L1 control value used in
// the cache-control decoration. Cases are grouped by their L1 component:
// USE_DEFAULT, the L1C_* variants, the L1S_* variants and INVALIDATE_READ.
// The returned values are not visible in this view, nor are any cases on
// lines that are not shown.
// NOTE(review): the assertion in getCacheControlMetadata implies USE_DEFAULT
// yields -1 — confirm.
102static int32_t getL1CacheControl(LoadCacheControl cc) {
105 case LoadCacheControl::USE_DEFAULT:
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
120 case LoadCacheControl::INVALIDATE_READ:
// Maps a store cache-control setting to the integer L1 control value used in
// the cache-control decoration. Cases are grouped by their L1 component:
// USE_DEFAULT, the L1WT_* variants, the L1WB_* variants and the L1S_*
// variants. The returned values (and any cases on lines not shown) are not
// visible in this view.
// NOTE(review): the assertion in getCacheControlMetadata implies USE_DEFAULT
// yields -1 — confirm.
129static int32_t getL1CacheControl(StoreCacheControl cc) {
132 case StoreCacheControl::USE_DEFAULT:
135 case StoreCacheControl::L1WT_L2UC_L3UC:
136 case StoreCacheControl::L1WT_L2UC_L3WB:
137 case StoreCacheControl::L1WT_L2WB_L3UC:
138 case StoreCacheControl::L1WT_L2WB_L3WB:
141 case StoreCacheControl::L1WB_L2UC_L3UC:
142 case StoreCacheControl::L1WB_L2WB_L3UC:
143 case StoreCacheControl::L1WB_L2UC_L3WB:
146 case StoreCacheControl::L1S_L2UC_L3UC:
147 case StoreCacheControl::L1S_L2UC_L3WB:
148 case StoreCacheControl::L1S_L2WB_L3UC:
149 case StoreCacheControl::L1S_L2WB_L3WB:
// Maps a load cache-control setting to the integer L3 control value used in
// the cache-control decoration. USE_DEFAULT, every *_L3C variant (grouped
// together) and INVALIDATE_READ are handled explicitly. The returned values
// and any remaining cases are on lines not visible in this view.
158static int32_t getL3CacheControl(LoadCacheControl cc) {
161 case LoadCacheControl::USE_DEFAULT:
164 case LoadCacheControl::L1UC_L2UC_L3C:
165 case LoadCacheControl::L1UC_L2C_L3C:
166 case LoadCacheControl::L1C_L2UC_L3C:
167 case LoadCacheControl::L1C_L2C_L3C:
168 case LoadCacheControl::L1S_L2UC_L3C:
169 case LoadCacheControl::L1S_L2C_L3C:
172 case LoadCacheControl::INVALIDATE_READ:
// Maps a store cache-control setting to the integer L3 control value used in
// the cache-control decoration. USE_DEFAULT and every *_L3WB variant (grouped
// together) are handled explicitly. The returned values and any remaining
// cases are on lines not visible in this view.
181static int32_t getL3CacheControl(StoreCacheControl cc) {
184 case StoreCacheControl::USE_DEFAULT:
187 case StoreCacheControl::L1UC_L2UC_L3WB:
188 case StoreCacheControl::L1UC_L2WB_L3WB:
189 case StoreCacheControl::L1WT_L2UC_L3WB:
190 case StoreCacheControl::L1WT_L2WB_L3WB:
191 case StoreCacheControl::L1S_L2UC_L3WB:
192 case StoreCacheControl::L1S_L2WB_L3WB:
193 case StoreCacheControl::L1WB_L2UC_L3WB:
// Per-op accessors for the optional cache-control setting of each XeVM
// memory op. Loads and prefetches carry a LoadCacheControl; stores carry a
// StoreCacheControl. The overload set lets the OpType-templated helpers
// (getL1CacheControl, getL3CacheControl, getCacheControlMetadata) dispatch
// on the concrete op type.
202static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
203 return op.getCacheControl();
206static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
207 return op.getCacheControl();
210static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
211 return op.getCacheControl();
214static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
215 return op.getCacheControl();
218static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
219 return op.getCacheControl();
222static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
223 return op.getCacheControl();
// Cache control for a plain llvm.load: read from the op's discardable
// "cache_control" attribute (an xevm::LoadCacheControlAttr) when present.
// NOTE(review): the null check between the lookup and the return, and the
// fall-through return, are not visible here; presumably std::nullopt is
// returned when the attribute is absent or of the wrong type — confirm.
226static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
227 if (op->hasAttr(
"cache_control")) {
228 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
231 return std::optional<LoadCacheControl>(attr.getValue());
// Cache control for a plain llvm.store: read from the op's discardable
// "cache_control" attribute (an xevm::StoreCacheControlAttr) when present.
// NOTE(review): the null check between the lookup and the return, and the
// fall-through return, are not visible here; presumably std::nullopt is
// returned when the attribute is absent or of the wrong type — confirm.
236static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
237 if (op->hasAttr(
"cache_control")) {
238 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
241 return std::optional<StoreCacheControl>(attr.getValue());
// Convenience wrappers that map an op's cache control to its integer L1 / L3
// value. They dereference the optional returned by getCacheControl
// unconditionally, so callers must first check that the op carries a cache
// control. getCacheControlMetadata performs that check before calling them.
246template <
typename OpType>
247int32_t getL1CacheControl(OpType op) {
248 return getL1CacheControl(*getCacheControl(op));
251template <
typename OpType>
252int32_t getL3CacheControl(OpType op) {
253 return getL3CacheControl(*getCacheControl(op));
// Builds the cache-control decoration metadata for `op` as an ArrayAttr
// holding two i32 arrays, each a (key, cache level, control value) triple:
//   {controlKey, 0, <L1 value>} and {controlKey, 1, <L3 value>}.
// controlKey is 6442 for loads/prefetches and 6443 for stores (presumably the
// SPIR-V CacheControlLoadINTEL / CacheControlStoreINTEL decoration IDs from
// SPV_INTEL_cache_controls — confirm).
// Returns early when the op carries no cache control, or when both levels
// are USE_DEFAULT (-1). The early-return values are on lines not shown,
// presumably std::nullopt.
256template <
typename OpType>
257static std::optional<ArrayAttr>
258getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
259 if (!getCacheControl(op))
262 constexpr int32_t decorationCacheControlArity{3};
263 constexpr int32_t loadCacheControlKey{6442};
264 constexpr int32_t storeCacheControlKey{6443};
// Ops that only read memory (loads and prefetches) use the load key.
265 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
266 std::is_same_v<OpType, BlockPrefetch2dOp> ||
267 std::is_same_v<OpType, LLVM::LoadOp> ||
268 std::is_same_v<OpType, BlockLoadOp> ||
269 std::is_same_v<OpType, PrefetchOp>;
// USE_DEFAULT (-1) must apply to both levels or to neither. When both are
// default, no metadata is produced.
275 assert(((getL1CacheControl<OpType>(op) == -1) ==
276 (getL3CacheControl<OpType>(op) == -1)) &&
277 "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
280 if (getL1CacheControl<OpType>(op) == -1 &&
281 getL3CacheControl<OpType>(op) == -1)
283 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
285 controlKey, 0, getL1CacheControl<OpType>(op)};
287 controlKey, 1, getL3CacheControl<OpType>(op)};
// Wrap each triple as an i32 ArrayAttr, then combine both into the result.
288 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
289 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
292 return rewriter.getArrayAttr(combinedAttrs);
314 llvm::StringMap<bool> seen;
317 auto arr = dyn_cast<ArrayAttr>(a);
321 auto vals = arr.getValue();
322 assert(vals.size() == 3 &&
323 "Expected exactly 3 integer values (Token, CacheLevel, "
324 "ControlValue) in cache control attribute.");
326 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
327 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
328 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
330 if (!tokenAttr || !secondAttr || !thirdAttr)
336 llvm::formatv(
"{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
337 secondAttr.getValue().getZExtValue(),
338 thirdAttr.getValue().getZExtValue());
341 if (!seen.insert({entry, true}).second)
344 payloads.push_back(std::move(entry));
// Monotonic counter that gives each generated metadata-string global a unique
// symbol name. Relaxed ordering is sufficient because only the uniqueness of
// the fetched values matters.
349static std::atomic<uint64_t> globalNameCounter{0};
// Returns an address-space-1 pointer to a NUL-terminated copy of `value`. The
// copy lives in a private, constant i8-array global in the "llvm.metadata"
// section, inserted at the start of `moduleOp`'s first block. If an existing
// "llvm.metadata" global already holds an identical string (including the
// trailing NUL), its address is returned instead of creating a duplicate.
// New globals are named "<nameHint>.<counter>".
354static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
356 StringRef value, StringRef nameHint) {
// The stored bytes include a terminating NUL. Both the array size and the
// reuse comparison below count it.
358 std::string strWithNull = value.str();
359 strWithNull.push_back(
'\0');
360 StringRef strRef(strWithNull.data(), strWithNull.size());
362 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
// Reuse path: only globals in the "llvm.metadata" section whose string value
// matches exactly are reused. The loop over the module's ops is on lines not
// shown here.
366 if (
auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
367 if (!existingGlobal.getSection() ||
368 *existingGlobal.getSection() !=
"llvm.metadata")
371 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
372 if (strAttr.getValue() == strRef) {
373 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
374 existingGlobal.getSymName());
// Otherwise create a new global: an [N x i8] array with a unique name.
381 auto i8Type = rewriter.getI8Type();
382 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
383 std::string globalName =
384 llvm::formatv(
"{0}.{1}", nameHint,
385 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
390 rewriter.setInsertionPointToStart(&moduleOp->
getRegion(0).
front());
393 LLVM::GlobalOp::create(rewriter, loc, arrayType,
394 true, LLVM::Linkage::Private,
395 globalName, rewriter.getStringAttr(strRef));
// Place it in llvm.metadata as an unnamed_addr, byte-aligned global in
// address space 1.
396 globalOp.setSection(StringRef(
"llvm.metadata"));
397 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
398 globalOp.setAlignment(1);
399 globalOp.setAddrSpace(1);
403 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
// Wraps `ptr` in a chain of llvm.ptr.annotation ops, one per payload string
// that buildCacheControlPayloads produces from `cacheControls`. Each
// annotation receives:
//   - the payload as its annotation string,
//   - an empty file name,
//   - line 0,
//   - a null address-space-1 pointer as the extra argument.
// All strings are emitted through createMetadataStringPtr, so identical
// strings share one global. When there are no payloads the function returns
// early; that return is on a line not shown, presumably returning `ptr`
// unchanged.
426static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
431 buildCacheControlPayloads(cacheControls.getValue());
432 if (payloads.empty())
435 auto ptrType = cast<LLVM::LLVMPointerType>(
ptr.getType());
// Operands shared by every annotation in the chain.
436 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
437 auto i32Ty = rewriter.getI32Type();
441 createMetadataStringPtr(rewriter, moduleOp, loc,
"",
".str.file");
442 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
443 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
// Thread the pointer through the annotations so each one wraps the previous
// result. The result keeps the original pointer type.
448 for (
const std::string &payload : payloads) {
449 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
450 ".str.cachecontrol");
451 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
452 annStr, fileStr, lineVal, nullAS1);
453 curPtr = annOp.getResult();
// If `op` carries cache-control metadata (see getCacheControlMetadata),
// replaces args[ptrIdx] with the pointer produced by
// annotatePtrWithCacheControl. ptrIdx defaults to the first argument. When
// there is no metadata the function returns early, presumably leaving `args`
// untouched; that return is on a line not shown.
// NOTE(review): the return type and the `op` / `args` parameters are on lines
// not visible in this view.
475template <
typename OpType>
477applyCacheControlAnnotation(ConversionPatternRewriter &rewriter,
Location loc,
479 Operation *moduleOp,
unsigned ptrIdx = 0) {
480 std::optional<ArrayAttr> optCacheControls =
481 getCacheControlMetadata(rewriter, op);
482 if (!optCacheControls)
485 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
486 *optCacheControls, moduleOp);
487 args[ptrIdx] = annotatedPtr;
// Emits a call to the external device function `funcName`. The function is
// presumably looked up in, or declared in, the enclosing module on lines not
// shown; the assert only checks that this lookup did not fail.
// The declaration gets:
//   - the SPIR_FUNC calling convention,
//   - the convergent / nounwind / willreturn flags from funcAttributeOptions,
//   - its memory effects, when funcAttributeOptions provides them,
//   - a unit attribute for each (argument index, attribute name) pair in
//     `paramAttrs`.
// Callers pass (rewriter, name, result type, arg types, args, paramAttrs,
// attribute options, op).
// NOTE(review): the call op then receives a copy of *all* of the function's
// attributes. That includes items such as its symbol name and function type,
// which become discardable attributes on the call. Presumably the intent is
// to mirror the function-level flags on the call site — confirm no unwanted
// attributes leak through.
494static LLVM::CallOp createDeviceFunctionCall(
495 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
498 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
500 assert(moduleOp &&
"Expecting module");
505 assert(!
failed(funcOpRes));
506 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
507 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
508 funcOp.setConvergent(funcAttributeOptions.isConvergent);
509 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
510 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
512 if (funcAttributeOptions.memEffectsAttr)
513 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
515 for (
auto [idx, attrName] : paramAttrs)
516 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
518 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
519 callOp->setAttrs(funcOp->getAttrs());
// Groups XeVM element types by storage width:
//   32-bit: F32, TF32
//   16-bit: BF16, F16
//   8-bit:  U8, S8, BF8, F8
//   4-bit:  E2M1, U4, S4
// The return values are not visible here. Judging by the name, they are
// presumably the number of elements packed into one 32-bit dword (1, 2, 4, 8)
// — confirm. Any other element type is unreachable.
524static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
526 case xevm::ElemType::F32:
527 case xevm::ElemType::TF32:
529 case xevm::ElemType::BF16:
530 case xevm::ElemType::F16:
532 case xevm::ElemType::U8:
533 case xevm::ElemType::S8:
534 case xevm::ElemType::BF8:
535 case xevm::ElemType::F8:
537 case xevm::ElemType::E2M1:
538 case xevm::ElemType::U4:
539 case xevm::ElemType::S4:
542 llvm_unreachable(
"unsupported xevm::ElemType");
// Lowers xevm::MMAOp to a mangled call of the OpenCL Intel subgroup builtin
// "intel_sub_group_<A>_<B>_matrix_mad_k<K>". K is derived from
// getNumOperandsPerDword(A); the full expression, presumably also using
// systolicDepth, is on a line not shown.
// Requirements checked here:
//   - a C operand is present,
//   - C and D have the same type, which is S32, F32, F16 or BF16,
//   - neither A nor B is S32 or F32.
// Operand packing (via bitcast when needed):
//   - A becomes a vector of i16 (f32 for TF32),
//   - B becomes a vector of i32 (f32 for TF32),
//   - a BF16 accumulator becomes an i16 vector for the call, and the result
//     is bitcast back to the original accumulator type.
// The call is convergent, nounwind and willreturn, and its memory effects are
// all NoModRef.
546class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
547 using OpConversionPattern::OpConversionPattern;
549 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
550 ConversionPatternRewriter &rewriter)
const override {
552 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
// Validate the element-type combination supported by the OCL builtins.
554 auto precisionA = op.getTypes().getA();
555 auto precisionB = op.getTypes().getB();
556 auto precisionC = op.getTypes().getC();
557 auto precisionD = op.getTypes().getD();
558 if (precisionC != precisionD) {
559 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
561 if (precisionC != xevm::ElemType::S32 &&
562 precisionC != xevm::ElemType::F32 &&
563 precisionC != xevm::ElemType::F16 &&
564 precisionC != xevm::ElemType::BF16) {
565 return rewriter.notifyMatchFailure(
566 op,
"type of C and D must be S32, F32, F16 or BF16");
568 if (precisionA == xevm::ElemType::S32 ||
569 precisionA == xevm::ElemType::F32) {
570 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
572 if (precisionB == xevm::ElemType::S32 ||
573 precisionB == xevm::ElemType::F32) {
574 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
// Repack A and B into the element widths expected by the builtin. The vector
// length is recomputed so the total bit size is preserved.
576 constexpr uint32_t bitWidthPackedA{16};
577 constexpr uint32_t bitWidthPackedB{32};
578 auto loc = op.getLoc();
580 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
581 VectorType origTy = cast<VectorType>(val.
getType());
582 const uint32_t vecBitSize =
583 origTy.getNumElements() *
584 origTy.getElementType().getIntOrFloatBitWidth();
585 VectorType newTy = VectorType::get(
586 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
588 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
593 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
594 ? cast<Type>(rewriter.getF32Type())
595 : rewriter.getIntegerType(bitWidthPackedA);
596 a = castIfNeeded(a, packedAType);
599 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
600 ? cast<Type>(rewriter.getF32Type())
601 : rewriter.getIntegerType(bitWidthPackedB);
602 b = castIfNeeded(
b, packedBType);
// The accumulator and the result share one type. BF16 is passed as i16 of
// the same shape.
605 VectorType cOrigTy = cast<VectorType>(c.
getType());
606 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
607 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
610 cOrigTy.getElementType().isBF16()
611 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
613 VectorType resTy = cTy;
615 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
// Build and mangle the builtin name from the A/B element types.
617 constexpr int32_t systolicDepth{8};
619 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
620 stringifyElemType(op.getTypes().getA()).str(),
621 stringifyElemType(op.getTypes().getB()).str(),
623 getNumOperandsPerDword(op.getTypes().getA()))
625 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
626 fnName = mangle(fnName, argTypes);
627 SmallVector<Value> args{a,
b, c};
// The builtin neither reads nor writes memory. It is convergent because it
// operates across the subgroup.
629 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
630 LLVM::ModRefInfo::NoModRef,
631 LLVM::ModRefInfo::NoModRef,
632 LLVM::ModRefInfo::NoModRef,
633 LLVM::ModRefInfo::NoModRef,
634 LLVM::ModRefInfo::NoModRef,
635 LLVM::ModRefInfo::NoModRef);
636 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
637 funcAttrs.memEffectsAttr = memAttr;
639 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
640 funcAttrs, op.getOperation())
643 if (resOrigTy != resTy)
644 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
646 rewriter.replaceOp(op,
result);
// Lowers xevm PrefetchOp to a call of the OpenCL builtin
// "_Z8prefetchPU3AS1Kcm". In Itanium mangling that name is
// prefetch(const char addrspace(1) *, unsigned long), and the call passes an
// element count of 1. The pointer is annotated with the op's cache control
// first, if it has one. The void call is nounwind and only reads memory; see
// the memory-effects note below. The original op is erased.
651class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
652 using OpConversionPattern::OpConversionPattern;
654 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
655 ConversionPatternRewriter &rewriter)
const override {
656 auto loc = op.getLoc();
659 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
661 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
662 SmallVector<Value> args{op.getPtr(), one};
665 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
668 SmallVector<Type> argTypes;
669 for (
auto arg : args)
670 argTypes.push_back(arg.getType());
// Only the second MemoryEffectsAttr field is Ref (presumably argMem): the
// prefetch reads through its pointer argument and touches no other memory.
671 auto funcAttr = noUnwindAttrs;
672 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
673 LLVM::ModRefInfo::NoModRef,
674 LLVM::ModRefInfo::Ref,
675 LLVM::ModRefInfo::NoModRef,
676 LLVM::ModRefInfo::NoModRef,
677 LLVM::ModRefInfo::NoModRef,
678 LLVM::ModRefInfo::NoModRef);
679 funcAttr.memEffectsAttr = memAttr;
681 createDeviceFunctionCall(rewriter, fnName,
682 LLVM::LLVMVoidType::get(rewriter.getContext()),
683 argTypes, args, {}, funcAttr, op.getOperation());
684 rewriter.eraseOp(op);
// Lowers xevm MemfenceOp to a call of the OpenCL builtin
// "atomic_work_item_fence", mangled for three i32 parameters. It is called
// as (address-space flags, acqRel = 4, memory scope). Supported inputs:
//   - address spaces SHARED and GLOBAL,
//   - scopes WORKGROUP and DEVICE.
// The constants chosen for each case are on lines not shown. The void call
// is nounwind, and the op is erased.
689class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
690 using OpConversionPattern::OpConversionPattern;
692 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
693 ConversionPatternRewriter &rewriter)
const override {
694 auto loc = op.getLoc();
695 const std::string fnName{
"atomic_work_item_fence"};
// Both values are assigned in the switches below. Unsupported address spaces
// and scopes bail out with a match failure.
696 int memScope, addrSpace;
697 switch (op.getAddrspace()) {
698 case xevm::AddrSpace::SHARED:
701 case xevm::AddrSpace::GLOBAL:
706 return rewriter.notifyMatchFailure(
707 op,
"Fence only supports global and shared address spaces.");
709 switch (op.getScope()) {
710 case xevm::MemScope::WORKGROUP:
713 case xevm::MemScope::DEVICE:
718 return rewriter.notifyMatchFailure(
719 op,
"Fence only supports workgroup and device memory scopes.");
721 Type i32Type = rewriter.getI32Type();
722 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
723 Value memScopeConst =
724 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
725 Value addrSpaceConst =
726 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
// Argument order: address-space flags, memory order, memory scope.
727 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
// {3, i32Type} selects SmallVector's (count, value) constructor, giving
// three i32 parameter types.
728 SmallVector<Type> argTypes{3, i32Type};
729 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
730 LLVM::LLVMVoidType::get(rewriter.getContext()),
731 argTypes, args, {}, noUnwindAttrs,
733 rewriter.eraseOp(op);
// Lowers the 2-D block ops to a void call of an OpenCL
// "intel_sub_group_2d_block_*" builtin. OpType is BlockLoad2dOp,
// BlockPrefetch2dOp, or the store variant, which is whichever type is
// neither of those.
// Builtin arguments:
//   - the base pointer (cache-control annotated when requested),
//   - base width, base height and base pitch,
//   - the (x, y) coordinate packed into a <2 x i32>,
//   - loads and stores only: a pointer to a temporary alloca of the vector's
//     elements. Stores spill the value into it before the call; loads read
//     the result back from it afterwards.
// The base name is extended with "_<elemBits>b_<tileH>r<tileW>x<vBlocks>c"
// and then Itanium-mangled by hand.
737template <
typename OpType>
738class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
739 using OpConversionPattern<OpType>::OpConversionPattern;
741 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
742 ConversionPatternRewriter &rewriter)
const override {
// Op-kind flags. The store path is taken when neither flag is set.
743 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
744 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
746 auto loc = op.getLoc();
747 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
// Only loads carry pack-register / transpose options. Prefetches have no
// vector value.
749 bool packReg =
false;
750 bool transpose =
false;
751 if constexpr (isLoad) {
752 vecType = op.getRes().getType();
753 packReg = op.getPackRegister();
754 transpose = op.getTranspose();
755 }
else if constexpr (!isPrefetch) {
756 vecType = op.getStoredVal().getType();
// Pack the x / y coordinates into lanes 0 and 1 of a <2 x i32>.
759 auto i32Type = rewriter.getI32Type();
761 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
762 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
763 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
764 byteCoord = LLVM::InsertElementOp::create(
765 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
766 byteCoord = LLVM::InsertElementOp::create(
767 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
// Leading arguments shared by all variants: indices 0-4.
768 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
769 op.getBasePitch(), byteCoord};
772 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
775 SmallVector<Type> retTypes;
// Default configuration for loads and stores: nounwind + willreturn.
777 std::string funcName{
"intel_sub_group_2d_block_"};
778 std::string bitWidthId;
779 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
780 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
// Prefetch configuration:
//   - argument 0 is nonnull,
//   - the call only reads argument memory (Ref in the presumed argMem slot),
//   - willreturn is dropped.
781 if constexpr (isPrefetch) {
782 funcName +=
"prefetch";
783 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
784 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
785 LLVM::ModRefInfo::NoModRef,
786 LLVM::ModRefInfo::Ref,
787 LLVM::ModRefInfo::NoModRef,
788 LLVM::ModRefInfo::NoModRef,
789 LLVM::ModRefInfo::NoModRef,
790 LLVM::ModRefInfo::NoModRef);
791 funcAttr = noUnwindAttrs;
792 funcAttr.memEffectsAttr = memAttr;
// Loads and stores: allocate a temporary holding getNumElements() elements
// of the vector's element type, in the default address space. Its pointer is
// passed as the last argument (index 5).
794 auto vecElemType = vecType.getElementType();
795 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
796 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
797 vecType.getNumElements());
798 auto dstOrSrcPtr = LLVM::AllocaOp::create(
799 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
800 vecElemType, numElems);
801 args.push_back(dstOrSrcPtr);
// Load path:
//   - the element is mangled as unsigned,
//   - optional "_transform" / "_transpose" suffixes are added; their
//     conditions are on lines not shown, presumably packReg / transpose,
//   - argument 0 is nonnull + readonly and argument 5 is nonnull + writeonly.
802 if constexpr (isLoad) {
804 bitWidthId = getTypeMangling(vecElemType,
true);
806 funcName +=
"_transform";
808 funcName +=
"_transpose";
809 spvLoadDstPtr = dstOrSrcPtr;
810 retTypes.push_back(vecType);
812 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
813 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
814 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
815 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
819 bitWidthId = (vecElemBitWidth == 32)
821 : ((vecElemBitWidth == 16) ?
"t" :
"h");
// Store path:
//   - the element code above depends on bit width (16 -> "t", otherwise "h";
//     the 32-bit spelling is on a line not shown),
//   - the value is spilled into the temporary before the call,
//   - argument 0 is nonnull + writeonly and argument 5 is nonnull + readonly.
822 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
824 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
825 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
826 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
827 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
833 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
834 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
836 std::string prefetchCode(
"");
// Hand-written Itanium mangling. The name is followed by:
//   - global pointer, three ints and <2 x i32> ("PU3AS1viiiDv2_i"),
//   - the prefetch code, whose value for prefetches is on lines not shown,
//   - the element code for the temporary's pointer.
839 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
840 funcName, prefetchCode, bitWidthId)
842 SmallVector<Type> argTypes;
843 for (
auto arg : args) {
844 argTypes.push_back(arg.getType());
846 createDeviceFunctionCall(
847 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
848 argTypes, args, paramAttrs, funcAttr, op.getOperation());
// Loads: replace the op with a load of the temporary filled by the call.
// Prefetches and stores are erased.
850 if constexpr (isLoad)
852 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
854 rewriter.eraseOp(op);
// Lowers the 1-D xevm BlockLoadOp / BlockStoreOp to the OpenCL builtins
// "intel_sub_group_block_read_u<T>[N]" / "intel_sub_group_block_write_u<T>[N]".
// <T> is getTypeMangling of the element type with the default signed letter
// codes, so a 32-bit element presumably yields "..._ui". [N] is the vector
// length for vector types.
// The pointer is annotated with the op's cache control first, if any. The
// name is then mangled by hand as "_Z<len><name>", followed by a prefix
// ending in the pointer's address space (its first part is on a line not
// shown, presumably "PU3AS"). After that come the unsigned element spelling
// and, for stores, the unsigned spelling of the stored value's type.
// The call is nounwind + willreturn. Loads are replaced by the call result
// and stores are erased.
859template <
typename OpType>
860class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
861 using OpConversionPattern<OpType>::OpConversionPattern;
863 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
864 ConversionPatternRewriter &rewriter)
const override {
// BlockStoreOp selects the write path; anything else is treated as a read.
865 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
866 auto loc = op.getLoc();
867 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
872 std::string funcName{
"intel_sub_group_block_"};
875 if constexpr (isStore) {
876 funcName +=
"write_u";
877 valOrResTy = op.getVal().getType();
879 funcName +=
"read_u";
880 valOrResTy = op.getType();
// Element spelling, then the vector length for vector types. The guard on
// vecTy for the length is presumably on a line not shown.
883 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
884 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
885 funcName += getTypeMangling(elemType);
887 funcName += std::to_string(vecTy.getNumElements());
// Parallel lists of argument types, per-argument signedness (for mangling)
// and argument values.
888 SmallVector<Type, 2> argTypes{};
892 SmallVector<bool, 2> isUnsigned{};
896 SmallVector<Value, 2> args{};
897 args.push_back(op.getPtr());
898 argTypes.push_back(op.getPtr().getType());
899 isUnsigned.push_back(
true);
// Annotate the pointer (args[0]) and refresh its recorded type.
902 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
906 argTypes[0] = args[0].getType();
// Stores pass the value as a second argument and return void. Loads return
// the loaded type.
909 if constexpr (isStore) {
910 args.push_back(op.getVal());
911 argTypes.push_back(op.getVal().getType());
912 isUnsigned.push_back(
true);
913 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
915 retType = valOrResTy;
// Hand-written Itanium mangling of the builtin name.
917 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
919 std::to_string(op.getPtr().getType().getAddressSpace());
920 funcName += getTypeMangling(elemType,
true);
921 if constexpr (isStore)
922 funcName += getTypeMangling(valOrResTy,
true);
923 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
926 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
927 {}, funcAttr, op.getOperation());
// Stores have no result to forward; loads forward the call result.
929 if constexpr (isStore)
930 rewriter.eraseOp(op);
932 rewriter.replaceOp(op, call->getResult(0));
// Rewrites a plain llvm.load / llvm.store that carries an xevm
// "cache_control" attribute, in place:
//   - the address operand (operand 0 for loads, 1 for stores) is replaced by
//     a pointer annotated via annotatePtrWithCacheControl,
//   - the attribute is removed.
// If the attribute yields no metadata (e.g. both levels USE_DEFAULT), only
// the attribute is removed. Ops without the attribute return early on a line
// not shown, presumably as a match failure.
937template <
typename OpType>
938class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
939 using OpConversionPattern<OpType>::OpConversionPattern;
941 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
942 ConversionPatternRewriter &rewriter)
const override {
943 if (!op->hasAttr(
"cache_control"))
946 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
947 std::optional<ArrayAttr> optCacheControls =
948 getCacheControlMetadata(rewriter, op);
949 if (!optCacheControls) {
950 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr(
"cache_control"); });
// llvm.store's operands are (value, address); llvm.load's only operand is
// the address.
955 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
956 unsigned ptrIdx = isStore ? 1 : 0;
957 Value ptr = op->getOperand(ptrIdx);
960 Value annotatedPtr = annotatePtrWithCacheControl(
961 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
// Swap in the annotated pointer and drop the now-consumed attribute.
964 rewriter.modifyOpInPlace(op, [&]() {
965 op->setOperand(ptrIdx, annotatedPtr);
966 op->removeAttr(
"cache_control");
// Maps each XeVM launch-configuration op to the OpenCL work-item builtin that
// provides its value, paired with the dimension index to pass
// (0 = x, 1 = y, 2 = z).
// Work-item id within the work-group -> get_local_id.
999static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
1000 return {
"get_local_id", 0};
1002static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
1003 return {
"get_local_id", 1};
1005static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
1006 return {
"get_local_id", 2};
// Work-group dimensions -> get_local_size.
1008static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
1009 return {
"get_local_size", 0};
1011static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
1012 return {
"get_local_size", 1};
1014static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
1015 return {
"get_local_size", 2};
// Work-group id -> get_group_id.
1017static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
1018 return {
"get_group_id", 0};
1020static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
1021 return {
"get_group_id", 1};
1023static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
1024 return {
"get_group_id", 2};
// Grid dimensions (number of work-groups) -> get_num_groups.
1026static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
1027 return {
"get_num_groups", 0};
1029static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
1030 return {
"get_num_groups", 1};
1032static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1033 return {
"get_num_groups", 2};
// Lowers an XeVM launch-configuration op to a call of the OpenCL builtin
// named by getConfig. The dimension is passed as an i32 constant and mangled
// as unsigned. The declaration is nounwind + willreturn. The memory-effects
// attribute is built from NoModRef values (some arguments are on lines not
// shown) and is set on the call op only, not on the declaration. The op is
// replaced by the call.
// NOTE(review): the static_cast<int64_t> is a no-op since `dim` is already
// int64_t.
1037template <
typename OpType>
1038class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
1039 using OpConversionPattern<OpType>::OpConversionPattern;
1041 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1042 ConversionPatternRewriter &rewriter)
const override {
1043 Location loc = op->getLoc();
1044 auto [baseName, dim] = getConfig(op);
1045 Type dimTy = rewriter.getI32Type();
1046 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1047 static_cast<int64_t
>(dim));
1048 std::string func = mangle(baseName, {dimTy}, {
true});
1049 Type resTy = op.getType();
1051 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1052 noUnwindWillReturnAttrs, op.getOperation());
1053 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1054 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1060 call.setMemoryEffectsAttr(memAttr);
1061 rewriter.replaceOp(op, call);
// Maps each XeVM subgroup query op to the argument-less OpenCL builtin that
// implements it.
1078static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
1079static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
1080static StringRef getConfig(xevm::SubgroupSizeOp) {
1081 return "get_sub_group_size";
// Lowers XeVM subgroup queries (lane id, subgroup id, subgroup size) to a
// call of the argument-less OpenCL builtin named by getConfig, mangled with
// an empty parameter list. The declaration is nounwind + willreturn. The
// memory-effects attribute is built from NoModRef values (arguments partly on
// lines not shown) and is set on the call op after creation. The op is
// replaced by the call.
1083template <
typename OpType>
1084class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
1085 using OpConversionPattern<OpType>::OpConversionPattern;
1087 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1088 ConversionPatternRewriter &rewriter)
const override {
1089 std::string func = mangle(getConfig(op).str(), {});
1090 Type resTy = op.getType();
1092 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1093 noUnwindWillReturnAttrs, op.getOperation());
1094 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1095 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1101 call.setMemoryEffectsAttr(memAttr);
1102 rewriter.replaceOp(op, call);
// Lowers xevm TruncfOp for its one supported combination: a 16-element F16
// vector to a 16-element BF8 vector. Both operands must be 16-element
// vectors; scalars and other lengths are rejected.
// Steps:
//   1. Split the source into two 8 x f16 halves.
//   2. Bitcast each half to 16 x i8.
//   3. Keep the odd bytes (1, 3, ..., 15) of each half.
//   4. Concatenate the two results.
// NOTE(review): on a little-endian target the odd bytes are the high bytes of
// each f16 (sign, 5 exponent bits, 2 mantissa bits). This therefore drops the
// low mantissa byte without rounding — confirm that truncation is the
// intended semantics.
1107class TruncfToOCLPattern :
public OpConversionPattern<TruncfOp> {
1108 using OpConversionPattern::OpConversionPattern;
1110 matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
1111 ConversionPatternRewriter &rewriter)
const override {
// Element kinds come from the op's src / dst element-type attributes.
1113 auto srcEtype = op.getSrcEtype().getEtype();
1114 auto dstEtype = op.getDstEtype().getEtype();
// Shape checks: only 16-element vectors are accepted on both sides.
1115 if (
auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
1116 if (vecSrcTy.getNumElements() != 16)
1117 return rewriter.notifyMatchFailure(
1118 op,
"Only vector src of 16 elements is supported");
1120 return rewriter.notifyMatchFailure(op,
"Scalar src is not supported.");
1122 if (
auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
1123 if (vecDstTy.getNumElements() != 16)
1124 return rewriter.notifyMatchFailure(
1125 op,
"Only vector dst of 16 elements is supported");
1127 return rewriter.notifyMatchFailure(op,
"Scalar dst is not supported.");
// F16 -> BF8: keep one byte of every f16 lane, as described above.
1129 if (srcEtype == TruncfSrcElemTypes::F16 &&
1130 dstEtype == TruncfDstElemTypes::BF8) {
1138 LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
1139 op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
1140 auto secondHalf = LLVM::ShuffleVectorOp::create(
1141 rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
1142 {8, 9, 10, 11, 12, 13, 14, 15});
// Reinterpret each 128-bit half as 16 bytes.
1143 auto firstHalfCasted = LLVM::BitcastOp::create(
1144 rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
1146 auto secondHalfCasted = LLVM::BitcastOp::create(
1147 rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
// Select the odd byte of every 2-byte lane.
1150 auto resFirstHalf = LLVM::ShuffleVectorOp::create(
1151 rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
1152 {1, 3, 5, 7, 9, 11, 13, 15});
1153 auto resSecondHalf = LLVM::ShuffleVectorOp::create(
1154 rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
1155 {1, 3, 5, 7, 9, 11, 13, 15});
// Concatenate the two 8-byte results into the 16-element output.
1156 auto res = LLVM::ShuffleVectorOp::create(
1157 rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
1158 {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
1159 rewriter.replaceOp(op, res);
// Any other element-type pair is unsupported.
1161 return rewriter.notifyMatchFailure(
1162 op,
"Unsupported src, dst element type pair.");
// Lowers xevm MMAMxOp to a mangled call of the OpenCL builtin
// "intel_sub_group_<A>_<B>_scaled_matrix_mad_k<K>_<C>". The call passes A, B,
// C and the two scale operands.
// Operand packing mirrors MMAToOCLPattern:
//   - A becomes an i16 vector (f32 for TF32),
//   - B becomes an i32 vector (f32 for TF32),
//   - a BF16 accumulator becomes an i16 vector, and the result is bitcast
//     back afterwards.
// The call is convergent, nounwind and willreturn, and its memory effects are
// all NoModRef.
// NOTE(review): unlike MMAToOCLPattern, the visible checks only require a C
// operand and matching C / D types. The C / D and A / B element-type
// restrictions are not re-validated here — confirm the op verifier covers
// them.
1168class MMAMxToOCLPattern :
public OpConversionPattern<MMAMxOp> {
1169 using OpConversionPattern::OpConversionPattern;
1171 matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
1172 ConversionPatternRewriter &rewriter)
const override {
1174 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
1176 auto precisionC = op.getTypes().getC();
1177 auto precisionD = op.getTypes().getD();
1178 if (precisionC != precisionD) {
1179 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
// Repack A and B to the builtin's element widths. A bitcast is emitted only
// when the packed vector type differs from the original.
1182 constexpr uint32_t bitWidthPackedA{16};
1183 constexpr uint32_t bitWidthPackedB{32};
1184 auto loc = op.getLoc();
1186 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
1187 VectorType origTy = cast<VectorType>(val.
getType());
1188 const uint32_t vecBitSize =
1189 origTy.getNumElements() *
1190 origTy.getElementType().getIntOrFloatBitWidth();
1191 VectorType newTy = VectorType::get(
1192 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
1193 if (origTy != newTy)
1194 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
1198 Value a = op.getA();
1199 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
1200 ? cast<Type>(rewriter.getF32Type())
1201 : rewriter.getIntegerType(bitWidthPackedA);
1202 a = castIfNeeded(a, packedAType);
1204 Value
b = op.getB();
1205 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
1206 ? cast<Type>(rewriter.getF32Type())
1207 : rewriter.getIntegerType(bitWidthPackedB);
1208 b = castIfNeeded(
b, packedBType);
// The accumulator and the result share one type. BF16 is passed as i16 of
// the same shape.
1210 Value c = op.getC();
1211 VectorType cOrigTy = cast<VectorType>(c.
getType());
1212 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
1213 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
1216 cOrigTy.getElementType().isBF16()
1217 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
1219 VectorType resTy = cTy;
1221 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
// Builtin name: A / B element types, K (from getNumOperandsPerDword(A) and,
// on a line not shown, presumably systolicDepth), and the C element type.
1223 constexpr int32_t systolicDepth{8};
1224 std::string fnName =
1225 llvm::formatv(
"intel_sub_group_{0}_{1}_scaled_matrix_mad_k{2}_{3}",
1226 stringifyElemType(op.getTypes().getA()).str(),
1227 stringifyElemType(op.getTypes().getB()).str(),
1229 getNumOperandsPerDword(op.getTypes().getA()),
1230 stringifyElemType(op.getTypes().getC()).str())
1232 auto scaleA = op.getScaleA();
1233 auto scaleB = op.getScaleB();
1234 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy, scaleA.getType(),
1236 fnName = mangle(fnName, argTypes);
1237 SmallVector<Value> args{a,
b, c, scaleA, scaleB};
// The builtin touches no memory and is convergent across the subgroup.
1239 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1240 LLVM::ModRefInfo::NoModRef,
1241 LLVM::ModRefInfo::NoModRef,
1242 LLVM::ModRefInfo::NoModRef,
1243 LLVM::ModRefInfo::NoModRef,
1244 LLVM::ModRefInfo::NoModRef,
1245 LLVM::ModRefInfo::NoModRef);
1246 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1247 funcAttrs.memEffectsAttr = memAttr;
1249 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
1250 funcAttrs, op.getOperation())
1253 if (resOrigTy != resTy)
1254 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
1256 rewriter.replaceOp(op,
result);
// Replaces an llvm.alloca with the address of a new module-level global of
// type [<constant array size> x <element type>], in the alloca's address
// space. The global:
//   - is placed at the start of the body of the enclosing ModuleOp or
//     gpu.module,
//   - is non-constant with internal linkage,
//   - is named "__global_alloca_<n>",
//   - keeps the alloca's alignment (0 when unset).
// Which address spaces are accepted, and how the constant size `cst` is
// extracted from the array-size operand, are decided on lines not shown.
1261class AllocaToGlobalPattern :
public OpConversionPattern<LLVM::AllocaOp> {
1262 using OpConversionPattern::OpConversionPattern;
1264 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1265 ConversionPatternRewriter &rewriter)
const override {
1266 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1267 auto addrSpace = ptrType.getAddressSpace();
// Locate the body of the enclosing symbol table: a builtin module or a
// gpu.module.
1270 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1274 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1275 moduleBody = mod.getBody();
1276 }
else if (gpu::GPUModuleOp gpuMod =
1277 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1278 moduleBody = gpuMod.getBody();
// The element count must be a compile-time constant, presumably matched into
// `cst` on lines not shown.
1282 auto val = op.getArraySize();
1286 auto loc = op.getLoc();
1287 auto globalType = LLVM::LLVMArrayType::get(
1288 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1289 LLVM::GlobalOp globalVar;
// Create the global at module scope; the guard restores the insertion point.
1291 OpBuilder::InsertionGuard guard(rewriter);
1292 rewriter.setInsertionPointToStart(moduleBody);
1293 auto alignment = op.getAlignment();
1294 globalVar = LLVM::GlobalOp::create(
1295 rewriter, loc, globalType,
false,
1296 LLVM::Linkage::Internal,
1297 std::string(
"__global_alloca_") +
1298 std::to_string(getNextGlobalIdx()),
1300 alignment ? *alignment : 0, addrSpace);
1302 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
// Returns a fresh index used to make global names unique.
// NOTE(review): unlike globalNameCounter, this is a plain function-local
// static unsigned. If this pattern can run concurrently (e.g. multithreaded
// pass execution over several modules), its increment (on a line not shown)
// would race — consider std::atomic.
1307 static unsigned getNextGlobalIdx() {
1308 static unsigned globalIdx = 0;
1319static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1320 if (op.getV1() != op.getV2())
1322 auto maskAttr = op.getMask();
1324 int64_t sourceSize = op.getV1().getType().getNumElements();
1325 if (maskSize > sourceSize)
1327 int64_t firstIndex = maskAttr[0];
1328 for (
int64_t i = 1; i < maskSize; ++i) {
1330 if (
index != firstIndex + i)
1332 if (
index >= sourceSize)
1346class HandleVectorExtractPattern
1348 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
1350 void initialize() { setHasBoundedRewriteRecursion(); }
1352 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1353 PatternRewriter &rewriter)
const override {
1355 if (!isExtractingContiguousSlice(op))
1358 auto mask = op.getMask();
1359 auto loc = op.getLoc();
1360 auto ty = op.getType();
1362 auto src = op.getV1();
1364 if (
auto srcOp = src.getDefiningOp()) {
1365 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1366 Value srcInput = srcOp->getOperand(0);
1368 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
1369 auto newShuffleVecTy =
1370 VectorType::get(mask.size(), srcVecTy.getElementType());
1371 auto newShuffle = LLVM::ShuffleVectorOp::create(
1372 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1375 if (isa<LLVM::FPExtOp>(srcOp)) {
1376 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1378 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1381 }
else if (isa<LLVM::BitcastOp>(srcOp)) {
1382 Value srcInput = srcOp->getOperand(0);
1384 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
1385 auto srcInputSize = srcInputVecTy.getNumElements();
1386 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1387 auto srcResSize = srcResVecTy.getNumElements();
1388 auto maskSize =
static_cast<int32_t
>(mask.size());
1389 if (srcInputSize > srcResSize) {
1392 if (srcResSize % srcInputSize != 0) {
1395 auto maskScale = srcResSize / srcInputSize;
1396 if (maskScale != 1) {
1397 if (mask[0] % maskScale != 0) {
1401 SmallVector<int32_t> newMask;
1402 int32_t newMaskSize = maskSize / maskScale;
1403 int32_t maskStart = mask[0] / maskScale;
1404 for (int32_t i = 0; i < newMaskSize; ++i) {
1405 newMask.push_back(maskStart + i);
1409 auto newShuffleVecTy =
1410 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1411 auto newShuffle = LLVM::ShuffleVectorOp::create(
1412 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1415 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1417 }
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1422 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1423 if (!isExtractingContiguousSlice(srcShuffle))
1425 auto srcMask = srcShuffle.getMask();
1426 SmallVector<int32_t> combinedMask;
1427 for (
auto index : mask) {
1428 combinedMask.push_back(srcMask[index]);
1430 auto newShuffle = LLVM::ShuffleVectorOp::create(
1431 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1434 }
else if (isa<LLVM::LoadOp>(srcOp)) {
1436 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1437 auto loadPtr = loadOp.getAddr();
1438 auto loadAddrSpace = loadPtr.getType().getAddressSpace();
1439 if (loadAddrSpace != 0)
1441 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1442 auto elemTy = loadTy.getElementType();
1443 auto firstIndex = mask[0];
1444 auto newVecTy = VectorType::get(mask.size(), elemTy);
1447 auto newPtr = LLVM::GEPOp::create(
1449 LLVM::LLVMPointerType::get(rewriter.
getContext(), loadAddrSpace),
1450 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1451 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1454 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1469struct ConvertXeVMToLLVMPass
1470 :
public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1473 void getDependentDialects(DialectRegistry ®istry)
const override {
1474 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
1477 void runOnOperation()
override {
1481 if (
failed(applyPartialConversion(getOperation(),
target,
1482 std::move(patterns))))
1483 signalPassFailure();
1487 RewritePatternSet vectorPatterns(&
getContext());
1488 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
1489 GreedyRewriteConfig config{};
1494 config.enableFolding(
false);
1511 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](
Operation *op) {
1515 if (isa<LLVM::AllocaOp>(op)) {
1516 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1517 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1518 auto addrSpace = pTy.getAddressSpace();
1519 return addrSpace != 3;
1522 return !op->hasAttr(
"cache_control");
1524 target.addIllegalDialect<XeVMDialect>();
1525 patterns.
add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1526 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1527 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1528 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1529 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1530 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1531 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1532 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1533 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1534 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1535 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1536 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1537 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1538 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1539 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1540 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1541 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1542 LaunchConfigOpToOCLPattern<GridDimXOp>,
1543 LaunchConfigOpToOCLPattern<GridDimYOp>,
1544 LaunchConfigOpToOCLPattern<GridDimZOp>,
1545 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1546 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1547 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1548 TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult applyPatternsGreedily(Region ®ion, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...