18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
// Options describing which function attributes to attach to a device function
// declaration. The three boolean flags default to false and the memory-effects
// attribute is value-initialized.
// NOTE(review): the struct's closing `};` is not visible in this extract.
38struct LLVMFuncAttributeOptions {
39 bool isConvergent =
false;
40 bool isNoUnwind =
false;
41 bool isWillReturn =
false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
// Attribute presets. Initializers are positional, in field declaration order:
// isConvergent, isNoUnwind, isWillReturn, memEffectsAttr.

// Preset with only isNoUnwind set.
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false,
true,
false, {}};
// Preset with isNoUnwind and isWillReturn set.
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false,
true,
true, {}};
// Preset with isConvergent, isNoUnwind and isWillReturn all set.
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true,
true,
true, {}};
/// Returns the name-mangling fragment for `ty`, dispatching on the type kind.
/// Visible cases: a vector yields "Dv<numElements>_" followed by the mangling
/// of its element type; f16 yields "Dh", f32 "f" and f64 "d"; an integer
/// yields one letter out of a pair chosen by its bit width, taking the first
/// letter of the pair when `isUnsigned` is true. Any other type reaches
/// DefaultUnreachable.
/// NOTE(review): the type-switch head, the integer `case` labels and several
/// closing lines are not visible in this extract. The pairs h/c, t/s, j/i and
/// m/l presumably correspond to widths 8/16/32/64 -- confirm against the full
/// file.
51std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
// Vectors recurse on the element type and forward the signedness flag.
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
57 .Case([](Float16Type) -> std::string {
return "Dh"; })
58 .Case([](Float32Type) -> std::string {
return "f"; })
59 .Case([](Float64Type) -> std::string {
return "d"; })
// Integers: the letter depends on both the width and the signedness flag.
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
63 return isUnsigned ?
"h" :
"c";
65 return isUnsigned ?
"t" :
"s";
67 return isUnsigned ?
"j" :
"i";
69 return isUnsigned ?
"m" :
"l";
71 llvm_unreachable(
"unhandled integer type");
74 .DefaultUnreachable(
"unhandled type for mangling");
// NOTE(review): the signature of this function is not visible in this extract.
// From the visible uses it works on a base name (`baseName`), a list of types
// (`types`) and a per-type signedness list (`isUnsigned`); it is presumably the
// `mangle` helper called elsewhere in this file -- confirm.
// The signedness list must either be empty or hold one entry per type.
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
82 llvm::raw_string_ostream os(s);
// Remembers types that were already emitted, keyed to an index.
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
// Prefix: "_Z", then the length of the base name, then the base name itself.
84 os <<
"_Z" << baseName.size() << baseName;
85 for (
auto [idx, type] : llvm::enumerate(types)) {
// A type found in the map has been seen before; the text emitted for that
// case is on lines not visible here.
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
90 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
// Only types that are not scalar int/float are recorded in the map; the
// stored index is the map size at insertion time.
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
// With an empty signedness list every type is mangled with isUnsigned=false.
96 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
/// Returns a string name for `elemType`.
/// NOTE(review): the lines between the signature and the visible return are
/// missing from this extract; the visible statement returns
/// stringifyElemType(elemType) converted to std::string.
102std::string builtinElemType(ElemType elemType) {
115 return stringifyElemType(elemType).str();
/// Maps a load cache-control enumerator to an int32_t code for the L1 level.
/// The visible `case` labels fall into four groups: USE_DEFAULT, the L1C_*
/// values, the L1S_* values, and INVALIDATE_READ.
/// NOTE(review): the switch head and the value returned for each group are not
/// visible in this extract. Callers in this file compare the result against -1
/// in an assertion about USE_DEFAULT, so USE_DEFAULT presumably maps to -1 --
/// confirm.
119static int32_t getL1CacheControl(LoadCacheControl cc) {
122 case LoadCacheControl::USE_DEFAULT:
125 case LoadCacheControl::L1C_L2UC_L3UC:
126 case LoadCacheControl::L1C_L2UC_L3C:
127 case LoadCacheControl::L1C_L2C_L3UC:
128 case LoadCacheControl::L1C_L2C_L3C:
131 case LoadCacheControl::L1S_L2UC_L3UC:
132 case LoadCacheControl::L1S_L2UC_L3C:
133 case LoadCacheControl::L1S_L2C_L3UC:
134 case LoadCacheControl::L1S_L2C_L3C:
137 case LoadCacheControl::INVALIDATE_READ:
/// Maps a store cache-control enumerator to an int32_t code for the L1 level.
/// The visible `case` labels fall into four groups: USE_DEFAULT, the L1WT_*
/// values, the L1WB_* values, and the L1S_* values.
/// NOTE(review): the switch head and the value returned for each group are not
/// visible in this extract.
146static int32_t getL1CacheControl(StoreCacheControl cc) {
149 case StoreCacheControl::USE_DEFAULT:
152 case StoreCacheControl::L1WT_L2UC_L3UC:
153 case StoreCacheControl::L1WT_L2UC_L3WB:
154 case StoreCacheControl::L1WT_L2WB_L3UC:
155 case StoreCacheControl::L1WT_L2WB_L3WB:
158 case StoreCacheControl::L1WB_L2UC_L3UC:
159 case StoreCacheControl::L1WB_L2WB_L3UC:
160 case StoreCacheControl::L1WB_L2UC_L3WB:
163 case StoreCacheControl::L1S_L2UC_L3UC:
164 case StoreCacheControl::L1S_L2UC_L3WB:
165 case StoreCacheControl::L1S_L2WB_L3UC:
166 case StoreCacheControl::L1S_L2WB_L3WB:
/// Maps a load cache-control enumerator to an int32_t code for the L3 level.
/// The visible `case` labels fall into three groups: USE_DEFAULT, the values
/// whose name ends in L3C, and INVALIDATE_READ.
/// NOTE(review): the switch head, the returned values and any default branch
/// are not visible in this extract.
175static int32_t getL3CacheControl(LoadCacheControl cc) {
178 case LoadCacheControl::USE_DEFAULT:
181 case LoadCacheControl::L1UC_L2UC_L3C:
182 case LoadCacheControl::L1UC_L2C_L3C:
183 case LoadCacheControl::L1C_L2UC_L3C:
184 case LoadCacheControl::L1C_L2C_L3C:
185 case LoadCacheControl::L1S_L2UC_L3C:
186 case LoadCacheControl::L1S_L2C_L3C:
189 case LoadCacheControl::INVALIDATE_READ:
/// Maps a store cache-control enumerator to an int32_t code for the L3 level.
/// The visible `case` labels fall into two groups: USE_DEFAULT and the values
/// whose name ends in L3WB.
/// NOTE(review): the switch head, the returned values and any default branch
/// are not visible in this extract.
198static int32_t getL3CacheControl(StoreCacheControl cc) {
201 case StoreCacheControl::USE_DEFAULT:
204 case StoreCacheControl::L1UC_L2UC_L3WB:
205 case StoreCacheControl::L1UC_L2WB_L3WB:
206 case StoreCacheControl::L1WT_L2UC_L3WB:
207 case StoreCacheControl::L1WT_L2WB_L3WB:
208 case StoreCacheControl::L1S_L2UC_L3WB:
209 case StoreCacheControl::L1S_L2WB_L3WB:
210 case StoreCacheControl::L1WB_L2UC_L3WB:
// Uniform accessors so templated code can query the cache control of any
// supported op through one overloaded name. Each overload simply forwards to
// the op's own getCacheControl(); load-like ops yield LoadCacheControl and
// store-like ops yield StoreCacheControl.
// NOTE(review): the closing brace of each overload is not visible in this
// extract.

// PrefetchOp: load-side cache control.
219static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
220 return op.getCacheControl();
// BlockLoad2dOp: load-side cache control.
223static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
224 return op.getCacheControl();
// BlockLoadOp: load-side cache control.
227static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
228 return op.getCacheControl();
// BlockPrefetch2dOp: load-side cache control.
231static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
232 return op.getCacheControl();
// BlockStore2dOp: store-side cache control.
235static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
236 return op.getCacheControl();
// BlockStoreOp: store-side cache control.
239static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
240 return op.getCacheControl();
/// Reads the cache control of an LLVM load op from its "cache_control"
/// attribute. When the attribute is present it is fetched as
/// xevm::LoadCacheControlAttr and its value is returned wrapped in an optional.
/// NOTE(review): the handling of an absent or differently typed attribute is
/// on lines not visible in this extract.
243static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
244 if (op->hasAttr(
"cache_control")) {
245 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
248 return std::optional<LoadCacheControl>(attr.getValue());
/// Reads the cache control of an LLVM store op from its "cache_control"
/// attribute. When the attribute is present it is fetched as
/// xevm::StoreCacheControlAttr and its value is returned wrapped in an
/// optional.
/// NOTE(review): the handling of an absent or differently typed attribute is
/// on lines not visible in this extract.
253static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
254 if (op->hasAttr(
"cache_control")) {
255 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
258 return std::optional<StoreCacheControl>(attr.getValue());
/// Convenience overload: looks up the cache control of `op` and converts it to
/// the L1 code. The optional returned by getCacheControl(op) is dereferenced
/// without a check, so this must only be called when it holds a value.
263template <
typename OpType>
264int32_t getL1CacheControl(OpType op) {
265 return getL1CacheControl(*getCacheControl(op));
/// Convenience overload: looks up the cache control of `op` and converts it to
/// the L3 code. The optional returned by getCacheControl(op) is dereferenced
/// without a check, so this must only be called when it holds a value.
268template <
typename OpType>
269int32_t getL3CacheControl(OpType op) {
270 return getL3CacheControl(*getCacheControl(op));
/// Builds the cache-control metadata for `op` as an ArrayAttr.
/// Two i32 triples are formed from the visible initializers:
/// (controlKey, 0, L1 code) and (controlKey, 1, L3 code), where controlKey is
/// 6442 for load-like ops and 6443 otherwise. The function bails out early
/// when the op carries no cache control and when both codes equal -1.
/// NOTE(review): the early-return statements, the declarations of
/// decorationsL1/decorationsL3/combinedAttrs and the end of the assert message
/// are on lines not visible in this extract. The keys 6442/6443 presumably
/// match the SPIR-V CacheControlLoadINTEL/CacheControlStoreINTEL decoration
/// ids -- confirm.
273template <
typename OpType>
274static std::optional<ArrayAttr>
275getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
276 if (!getCacheControl(op))
279 constexpr int32_t decorationCacheControlArity{3};
280 constexpr int32_t loadCacheControlKey{6442};
281 constexpr int32_t storeCacheControlKey{6443};
// Load-like ops: 2D block load, 2D block prefetch, LLVM load, 1D block load
// and prefetch. Every other op type is treated as a store.
282 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
283 std::is_same_v<OpType, BlockPrefetch2dOp> ||
284 std::is_same_v<OpType, LLVM::LoadOp> ||
285 std::is_same_v<OpType, BlockLoadOp> ||
286 std::is_same_v<OpType, PrefetchOp>;
292 assert(((getL1CacheControl<OpType>(op) == -1) ==
293 (getL3CacheControl<OpType>(op) == -1)) &&
294 "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
297 if (getL1CacheControl<OpType>(op) == -1 &&
298 getL3CacheControl<OpType>(op) == -1)
300 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
302 controlKey, 0, getL1CacheControl<OpType>(op)};
304 controlKey, 1, getL3CacheControl<OpType>(op)};
305 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
306 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
309 return rewriter.getArrayAttr(combinedAttrs);
// NOTE(review): the signature and loop head of this function are not visible
// in this extract; it is presumably the `buildCacheControlPayloads` helper
// called elsewhere in this file -- confirm.
// Visible behaviour: each attribute `a` is expected to be an ArrayAttr of
// exactly three IntegerAttr values, which are formatted into a text payload
// and appended to `payloads` unless an identical payload was already seen.

// Set of payload strings already produced, used to drop duplicates.
331 llvm::StringMap<bool> seen;
334 auto arr = dyn_cast<ArrayAttr>(a);
338 auto vals = arr.getValue();
339 assert(vals.size() == 3 &&
340 "Expected exactly 3 integer values (Token, CacheLevel, "
341 "ControlValue) in cache control attribute.");
343 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
344 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
345 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
347 if (!tokenAttr || !secondAttr || !thirdAttr)
// "{{" is formatv's escape for a literal '{', so the payload should read
// {<token>:"<second>,<third>"} with all three values printed as unsigned.
353 llvm::formatv(
"{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
354 secondAttr.getValue().getZExtValue(),
355 thirdAttr.getValue().getZExtValue());
// insert() reports whether the key was new; the action taken for a repeat is
// on a line not visible here.
358 if (!seen.insert({entry, true}).second)
361 payloads.push_back(std::move(entry));
// File-local atomic counter, starting at 0. createMetadataStringPtr increments
// it (relaxed fetch_add) to build the numeric suffix of new global names.
366static std::atomic<uint64_t> globalNameCounter{0};
/// Returns an address-space-1 pointer to a null-terminated copy of `value`
/// stored in a global placed in the "llvm.metadata" section.
/// An existing global in that section whose string value matches is reused;
/// otherwise a new private global named "<nameHint>.<counter>" is created at
/// the start of the first block of the module op's region 0.
/// NOTE(review): part of the parameter list (the names `moduleOp` and `loc`
/// are used below), the loop head over existing ops and several closing lines
/// are not visible in this extract.
371static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
373 StringRef value, StringRef nameHint) {
// Append an explicit terminator; strRef spans the terminator as well.
375 std::string strWithNull = value.str();
376 strWithNull.push_back(
'\0');
377 StringRef strRef(strWithNull.data(), strWithNull.size());
379 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
// Reuse path: only globals in the "llvm.metadata" section are candidates.
383 if (
auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
384 if (!existingGlobal.getSection() ||
385 *existingGlobal.getSection() !=
"llvm.metadata")
388 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
389 if (strAttr.getValue() == strRef) {
390 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
391 existingGlobal.getSymName());
// Creation path: an i8 array sized to the string including its terminator.
398 auto i8Type = rewriter.getI8Type();
399 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
400 std::string globalName =
401 llvm::formatv(
"{0}.{1}", nameHint,
402 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
407 rewriter.setInsertionPointToStart(&moduleOp->
getRegion(0).
front());
// NOTE(review): the literal `true` is presumably the isConstant flag of the
// GlobalOp builder -- confirm.
410 LLVM::GlobalOp::create(rewriter, loc, arrayType,
411 true, LLVM::Linkage::Private,
412 globalName, rewriter.getStringAttr(strRef));
413 globalOp.setSection(StringRef(
"llvm.metadata"));
414 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
415 globalOp.setAlignment(1);
416 globalOp.setAddrSpace(1);
420 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
/// Wraps a pointer in a chain of LLVM pointer-annotation ops, one per
/// cache-control payload string built from `cacheControls`.
/// Each annotation takes the previous annotation's result as its pointer
/// operand, together with the payload string, an empty file-name string, a
/// line value of 0 and a null address-space-1 pointer.
/// NOTE(review): the rest of the parameter list (`loc`, `ptr`,
/// `cacheControls` and `moduleOp` are used below), the early return for empty
/// payloads, the initialization of `curPtr` and the final return are on lines
/// not visible in this extract.
443static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
448 buildCacheControlPayloads(cacheControls.getValue());
449 if (payloads.empty())
452 auto ptrType = cast<LLVM::LLVMPointerType>(
ptr.getType());
453 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
454 auto i32Ty = rewriter.getI32Type();
// Operands shared by every annotation: empty file name, line 0, null pointer.
458 createMetadataStringPtr(rewriter, moduleOp, loc,
"",
".str.file");
459 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
460 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
// Chain the annotations: each one annotates the result of the previous one.
465 for (
const std::string &payload : payloads) {
466 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
467 ".str.cachecontrol");
468 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
469 annStr, fileStr, lineVal, nullAS1);
470 curPtr = annOp.getResult();
/// Replaces the pointer argument at `args[ptrIdx]` with a cache-control
/// annotated version of itself when `op` yields cache-control metadata.
/// `ptrIdx` defaults to 0.
/// NOTE(review): the return type, the `op` and `args` parameters, the action
/// taken when no metadata is available and the final return are on lines not
/// visible in this extract.
492template <
typename OpType>
494applyCacheControlAnnotation(ConversionPatternRewriter &rewriter,
Location loc,
496 Operation *moduleOp,
unsigned ptrIdx = 0) {
497 std::optional<ArrayAttr> optCacheControls =
498 getCacheControlMetadata(rewriter, op);
499 if (!optCacheControls)
// Annotate the selected pointer and store it back into the argument list.
502 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
503 *optCacheControls, moduleOp);
504 args[ptrIdx] = annotatedPtr;
/// Obtains an LLVM function for `funcName`, configures it, and emits a call.
/// The function gets the SPIR_FUNC calling convention and the convergent,
/// nounwind and willreturn flags from `funcAttributeOptions`, plus the memory
/// effects attribute when one is provided. Each (index, name) pair in
/// `paramAttrs` sets a unit attribute on that argument. The created call op
/// then has its attributes replaced with the function op's attributes.
/// NOTE(review): part of the parameter list (`argTypes`, `args` and
/// `paramAttrs` are used below), the lookup that produces `moduleOp`, `loc`
/// and `funcOpRes`, and the final return are on lines not visible in this
/// extract.
511static LLVM::CallOp createDeviceFunctionCall(
512 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
515 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
517 assert(moduleOp &&
"Expecting module");
522 assert(!
failed(funcOpRes));
523 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
524 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
525 funcOp.setConvergent(funcAttributeOptions.isConvergent);
526 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
527 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
// A default-constructed (null) memory-effects attribute is left unset.
529 if (funcAttributeOptions.memEffectsAttr)
530 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
532 for (
auto [idx, attrName] : paramAttrs)
533 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
// Mirror the declaration's attributes onto the call site.
535 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
536 callOp->setAttrs(funcOp->getAttrs());
/// Returns a per-element-type operand count for `pTy`.
/// The visible `case` labels fall into four groups: F32/TF32, BF16/F16,
/// U8/S8/BF8/F8 and E2M1/U4/S4; anything else reaches llvm_unreachable.
/// NOTE(review): the switch head and the value returned for each group are not
/// visible in this extract. Going by the function name the groups presumably
/// return 1, 2, 4 and 8 (elements per 32-bit word) -- confirm.
541static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
543 case xevm::ElemType::F32:
544 case xevm::ElemType::TF32:
546 case xevm::ElemType::BF16:
547 case xevm::ElemType::F16:
549 case xevm::ElemType::U8:
550 case xevm::ElemType::S8:
551 case xevm::ElemType::BF8:
552 case xevm::ElemType::F8:
554 case xevm::ElemType::E2M1:
555 case xevm::ElemType::U4:
556 case xevm::ElemType::S4:
559 llvm_unreachable(
"unsupported xevm::ElemType");
/// Conversion pattern that rewrites xevm::MMAOp into a call to a device
/// function whose name is built from the template
/// "intel_sub_group_{0}_{1}_matrix_mad_k{2}" and then mangled.
/// The match is rejected when C is absent, when the C and D element types
/// differ, when C/D is not one of S32, F32, F16 or BF16, or when A or B is S32
/// or F32. Operands A and B are bitcast to packed vectors, a bf16 accumulator
/// is bitcast to i16, and the call result is bitcast back to the original
/// result type when the two differ.
/// NOTE(review): a number of lines (access specifier, several declarations,
/// conditions, closing braces and the final return) are not visible in this
/// extract.
563class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
564 using OpConversionPattern::OpConversionPattern;
566 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
567 ConversionPatternRewriter &rewriter)
const override {
569 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
571 auto precisionA = op.getTypes().getA();
572 auto precisionB = op.getTypes().getB();
573 auto precisionC = op.getTypes().getC();
574 auto precisionD = op.getTypes().getD();
575 if (precisionC != precisionD) {
576 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
578 if (precisionC != xevm::ElemType::S32 &&
579 precisionC != xevm::ElemType::F32 &&
580 precisionC != xevm::ElemType::F16 &&
581 precisionC != xevm::ElemType::BF16) {
582 return rewriter.notifyMatchFailure(
583 op,
"type of C and D must be S32, F32, F16 or BF16");
585 if (precisionA == xevm::ElemType::S32 ||
586 precisionA == xevm::ElemType::F32) {
587 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
589 if (precisionB == xevm::ElemType::S32 ||
590 precisionB == xevm::ElemType::F32) {
591 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
// Bit widths of the packed integer element used for A and for B.
593 constexpr uint32_t bitWidthPackedA{16};
594 constexpr uint32_t bitWidthPackedB{32};
595 auto loc = op.getLoc();
// Reinterprets `val` as a vector of `packedType` with the same total number
// of bits.
597 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
598 VectorType origTy = cast<VectorType>(val.
getType());
599 const uint32_t vecBitSize =
600 origTy.getNumElements() *
601 origTy.getElementType().getIntOrFloatBitWidth();
602 VectorType newTy = VectorType::get(
603 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
605 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
// TF32 operands are packed as f32; all other element types as integers.
610 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
611 ? cast<Type>(rewriter.getF32Type())
612 : rewriter.getIntegerType(bitWidthPackedA);
613 a = castIfNeeded(a, packedAType);
616 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
617 ? cast<Type>(rewriter.getF32Type())
618 : rewriter.getIntegerType(bitWidthPackedB);
619 b = castIfNeeded(
b, packedBType);
622 VectorType cOrigTy = cast<VectorType>(c.
getType());
623 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
624 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
// A bf16 accumulator is passed to the callee as a vector of 16-bit integers.
627 cOrigTy.getElementType().isBF16()
628 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
630 VectorType resTy = cTy;
632 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
634 constexpr int32_t systolicDepth{8};
636 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
637 stringifyElemType(op.getTypes().getA()).str(),
638 stringifyElemType(op.getTypes().getB()).str(),
640 getNumOperandsPerDword(op.getTypes().getA()))
642 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
643 fnName = mangle(fnName, argTypes);
644 SmallVector<Value> args{a,
b, c};
// Every ModRefInfo argument is NoModRef.
646 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
647 LLVM::ModRefInfo::NoModRef,
648 LLVM::ModRefInfo::NoModRef,
649 LLVM::ModRefInfo::NoModRef,
650 LLVM::ModRefInfo::NoModRef,
651 LLVM::ModRefInfo::NoModRef,
652 LLVM::ModRefInfo::NoModRef);
653 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
654 funcAttrs.memEffectsAttr = memAttr;
656 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
657 funcAttrs, op.getOperation())
// Restore the original result element type if it was changed for the call.
660 if (resOrigTy != resTy)
661 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
663 rewriter.replaceOp(op,
result);
/// Conversion pattern that rewrites PrefetchOp into a call to the device
/// function named "_Z8prefetchPU3AS1Kcm", passing the op's pointer and an i64
/// constant 1, then erases the op. Cache-control annotation is applied to the
/// argument list before the argument types are collected.
/// NOTE(review): the access specifier, the `moduleOp` lookup, the tail of the
/// applyCacheControlAnnotation call and the final return are on lines not
/// visible in this extract. The function name looks like a pre-mangled
/// `prefetch` taking an address-space-1 pointer and an unsigned long --
/// confirm.
668class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
669 using OpConversionPattern::OpConversionPattern;
671 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
672 ConversionPatternRewriter &rewriter)
const override {
673 auto loc = op.getLoc();
676 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
678 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
679 SmallVector<Value> args{op.getPtr(), one};
682 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
// Types are taken after annotation so they reflect the final argument values.
685 SmallVector<Type> argTypes;
686 for (
auto arg : args)
687 argTypes.push_back(arg.getType());
688 auto funcAttr = noUnwindAttrs;
// All ModRefInfo arguments are NoModRef except the second, which is Ref.
689 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
690 LLVM::ModRefInfo::NoModRef,
691 LLVM::ModRefInfo::Ref,
692 LLVM::ModRefInfo::NoModRef,
693 LLVM::ModRefInfo::NoModRef,
694 LLVM::ModRefInfo::NoModRef,
695 LLVM::ModRefInfo::NoModRef);
696 funcAttr.memEffectsAttr = memAttr;
698 createDeviceFunctionCall(rewriter, fnName,
699 LLVM::LLVMVoidType::get(rewriter.getContext()),
700 argTypes, args, {}, funcAttr, op.getOperation());
701 rewriter.eraseOp(op);
/// Conversion pattern that rewrites MemfenceOp into a call to the mangled
/// "atomic_work_item_fence" device function with three i32 arguments, then
/// erases the op. Only the SHARED and GLOBAL address spaces and the WORKGROUP
/// and DEVICE scopes are accepted; other values fail the match.
/// NOTE(review): the access specifier, the values assigned to `memScope` and
/// `addrSpace` in each case, the tail of the call and the final return are on
/// lines not visible in this extract.
706class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
707 using OpConversionPattern::OpConversionPattern;
709 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 auto loc = op.getLoc();
712 const std::string fnName{
"atomic_work_item_fence"};
713 int memScope, addrSpace;
714 switch (op.getAddrspace()) {
715 case xevm::AddrSpace::SHARED:
718 case xevm::AddrSpace::GLOBAL:
723 return rewriter.notifyMatchFailure(
724 op,
"Fence only supports global and shared address spaces.");
726 switch (op.getScope()) {
727 case xevm::MemScope::WORKGROUP:
730 case xevm::MemScope::DEVICE:
735 return rewriter.notifyMatchFailure(
736 op,
"Fence only supports workgroup and device memory scopes.");
738 Type i32Type = rewriter.getI32Type();
// The value named acqRel is the i32 constant 4.
739 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
740 Value memScopeConst =
741 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
742 Value addrSpaceConst =
743 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
// Argument order: address-space value, acqRel constant, memory-scope value.
744 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
745 SmallVector<Type> argTypes{3, i32Type};
746 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
747 LLVM::LLVMVoidType::get(rewriter.getContext()),
748 argTypes, args, {}, noUnwindAttrs,
750 rewriter.eraseOp(op);
/// Conversion pattern shared by the 2D block load, store and prefetch ops.
/// It emits a call to a device function whose name starts with
/// "intel_sub_group_2d_block_" and is completed from the op's element size,
/// tile height, tile width and v-blocks, then wrapped in a hand-written
/// "_Z..." mangling. The common arguments are the pointer, base width, base
/// height, base pitch and a 2 x i32 coordinate vector built from X and Y.
/// For load and store a stack slot is allocated and passed as an extra
/// argument: a store first writes the stored value into it, a load reads the
/// result back from it after the call. A prefetch only emits the call.
/// NOTE(review): many lines (access specifier, several declarations, `else`
/// branches, conditions around the name suffixes, closing braces and the
/// final return) are not visible in this extract.
754template <
typename OpType>
755class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
756 using OpConversionPattern<OpType>::OpConversionPattern;
758 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
759 ConversionPatternRewriter &rewriter)
const override {
// Anything that is neither a load nor a prefetch is handled as a store.
760 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
761 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
763 auto loc = op.getLoc();
764 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
766 bool packReg =
false;
767 bool transpose =
false;
768 if constexpr (isLoad) {
769 vecType = op.getRes().getType();
770 packReg = op.getPackRegister();
771 transpose = op.getTranspose();
772 }
else if constexpr (!isPrefetch) {
773 vecType = op.getStoredVal().getType();
776 auto i32Type = rewriter.getI32Type();
// Build the 2 x i32 coordinate vector: element 0 is X, element 1 is Y.
778 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
779 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
780 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
781 byteCoord = LLVM::InsertElementOp::create(
782 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
783 byteCoord = LLVM::InsertElementOp::create(
784 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
785 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
786 op.getBasePitch(), byteCoord};
789 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
792 SmallVector<Type> retTypes;
794 std::string funcName{
"intel_sub_group_2d_block_"};
795 std::string bitWidthId;
796 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
797 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
798 if constexpr (isPrefetch) {
799 funcName +=
"prefetch";
// Prefetch: argument 0 is marked nonnull and willreturn is dropped.
800 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
801 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
802 LLVM::ModRefInfo::NoModRef,
803 LLVM::ModRefInfo::Ref,
804 LLVM::ModRefInfo::NoModRef,
805 LLVM::ModRefInfo::NoModRef,
806 LLVM::ModRefInfo::NoModRef,
807 LLVM::ModRefInfo::NoModRef);
808 funcAttr = noUnwindAttrs;
809 funcAttr.memEffectsAttr = memAttr;
811 auto vecElemType = vecType.getElementType();
812 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
// Stack slot of vecType.getNumElements() elements, passed as the last
// argument of the call.
813 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
814 vecType.getNumElements());
815 auto dstOrSrcPtr = LLVM::AllocaOp::create(
816 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
817 vecElemType, numElems);
818 args.push_back(dstOrSrcPtr);
819 if constexpr (isLoad) {
821 bitWidthId = getTypeMangling(vecElemType,
true);
823 funcName +=
"_transform";
825 funcName +=
"_transpose";
826 spvLoadDstPtr = dstOrSrcPtr;
827 retTypes.push_back(vecType);
// Load: argument 0 is nonnull/readonly, argument 5 is nonnull/writeonly.
829 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
830 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
831 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
832 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
836 bitWidthId = (vecElemBitWidth == 32)
838 : ((vecElemBitWidth == 16) ?
"t" :
"h");
// Store: spill the value to the stack slot so the callee can read it.
839 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
// Store: argument 0 is nonnull/writeonly, argument 5 is nonnull/readonly.
841 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
842 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
843 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
844 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
850 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
851 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
853 std::string prefetchCode(
"");
// Hand-assembled mangled name: "_Z", name length, name, a fixed parameter
// encoding, then the optional prefetch code and bit-width id.
856 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
857 funcName, prefetchCode, bitWidthId)
859 SmallVector<Type> argTypes;
860 for (
auto arg : args) {
861 argTypes.push_back(arg.getType());
863 createDeviceFunctionCall(
864 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
865 argTypes, args, paramAttrs, funcAttr, op.getOperation());
// A load is replaced by reading the result back from the stack slot.
867 if constexpr (isLoad)
869 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
871 rewriter.eraseOp(op);
/// Conversion pattern shared by the 1D block load and block store ops.
/// It emits a call to a device function whose base name is
/// "intel_sub_group_block_" plus "write_u" (store) or "read_u" (load), the
/// element type mangling and, for vectors, the element count. The final
/// symbol is assembled by hand starting from "_Z" and the base name length.
/// A store passes the pointer and the value and is erased afterwards; a load
/// passes only the pointer and is replaced by result 0 of the call.
/// NOTE(review): several lines (access specifier, declarations, `else`
/// branches, a condition before the element-count suffix, part of the
/// mangled-name expression and the final return) are not visible in this
/// extract.
876template <
typename OpType>
877class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
878 using OpConversionPattern<OpType>::OpConversionPattern;
880 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
881 ConversionPatternRewriter &rewriter)
const override {
882 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
883 auto loc = op.getLoc();
884 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
889 std::string funcName{
"intel_sub_group_block_"};
892 if constexpr (isStore) {
893 funcName +=
"write_u";
894 valOrResTy = op.getVal().getType();
896 funcName +=
"read_u";
897 valOrResTy = op.getType();
// The value/result may be a scalar or a vector; use the element type either
// way for the name suffix.
900 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
901 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
902 funcName += getTypeMangling(elemType);
904 funcName += std::to_string(vecTy.getNumElements());
905 SmallVector<Type, 2> argTypes{};
909 SmallVector<bool, 2> isUnsigned{};
913 SmallVector<Value, 2> args{};
914 args.push_back(op.getPtr());
915 argTypes.push_back(op.getPtr().getType());
916 isUnsigned.push_back(
true);
919 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
// Refresh the pointer type from the (possibly annotated) first argument.
923 argTypes[0] = args[0].getType();
926 if constexpr (isStore) {
927 args.push_back(op.getVal());
928 argTypes.push_back(op.getVal().getType());
929 isUnsigned.push_back(
true);
930 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
932 retType = valOrResTy;
// Assemble the mangled symbol: "_Z", base name length, base name, a pointer
// encoding that includes the address space, then unsigned type manglings.
934 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
936 std::to_string(op.getPtr().getType().getAddressSpace());
937 funcName += getTypeMangling(elemType,
true);
938 if constexpr (isStore)
939 funcName += getTypeMangling(valOrResTy,
true);
940 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
943 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
944 {}, funcAttr, op.getOperation());
946 if constexpr (isStore)
947 rewriter.eraseOp(op);
949 rewriter.replaceOp(op, call->getResult(0));
/// Conversion pattern for LLVM load/store ops that carry a "cache_control"
/// attribute. The pointer operand (index 1 for stores, index 0 otherwise) is
/// replaced in place by a cache-control annotated pointer and the attribute
/// is removed. When no metadata can be derived, only the attribute is
/// removed.
/// NOTE(review): the access specifier, the result returned for ops without
/// the attribute, and the return statements after each in-place update are on
/// lines not visible in this extract.
954template <
typename OpType>
955class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
956 using OpConversionPattern<OpType>::OpConversionPattern;
958 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
959 ConversionPatternRewriter &rewriter)
const override {
960 if (!op->hasAttr(
"cache_control"))
963 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
964 std::optional<ArrayAttr> optCacheControls =
965 getCacheControlMetadata(rewriter, op);
966 if (!optCacheControls) {
967 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr(
"cache_control"); });
// LLVM store takes the value first, so its pointer is operand 1.
972 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
973 unsigned ptrIdx = isStore ? 1 : 0;
974 Value ptr = op->getOperand(ptrIdx);
977 Value annotatedPtr = annotatePtrWithCacheControl(
978 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
// Swap in the annotated pointer and drop the now-consumed attribute.
981 rewriter.modifyOpInPlace(op, [&]() {
982 op->setOperand(ptrIdx, annotatedPtr);
983 op->removeAttr(
"cache_control");
// Lookup table, expressed as overloads, from a launch-configuration op type
// to the pair (builtin base name, dimension index). X/Y/Z map to dimension
// 0/1/2.
// NOTE(review): the closing brace of each overload is not visible in this
// extract.

// Workitem id ops use "get_local_id".
1016static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
1017 return {
"get_local_id", 0};
1019static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
1020 return {
"get_local_id", 1};
1022static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
1023 return {
"get_local_id", 2};
// Workgroup dimension ops use "get_local_size".
1025static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
1026 return {
"get_local_size", 0};
1028static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
1029 return {
"get_local_size", 1};
1031static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
1032 return {
"get_local_size", 2};
// Workgroup id ops use "get_group_id".
1034static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
1035 return {
"get_group_id", 0};
1037static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
1038 return {
"get_group_id", 1};
1040static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
1041 return {
"get_group_id", 2};
// Grid dimension ops use "get_num_groups".
1043static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
1044 return {
"get_num_groups", 0};
1046static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
1047 return {
"get_num_groups", 1};
1049static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1050 return {
"get_num_groups", 2};
/// Conversion pattern for the launch-configuration ops. It looks up the
/// builtin base name and dimension via getConfig(op), emits a call to the
/// mangled builtin with the dimension as a single i32 argument (mangled as
/// unsigned), sets a memory-effects attribute on the call and replaces the op
/// with the call.
/// NOTE(review): the access specifier, the declaration receiving the call,
/// the arguments of the memory-effects attribute and the final return are on
/// lines not visible in this extract.
1054template <
typename OpType>
1055class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
1056 using OpConversionPattern<OpType>::OpConversionPattern;
1058 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1059 ConversionPatternRewriter &rewriter)
const override {
1060 Location loc = op->getLoc();
1061 auto [baseName, dim] = getConfig(op);
1062 Type dimTy = rewriter.getI32Type();
1063 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1064 static_cast<int64_t
>(dim));
1065 std::string func = mangle(baseName, {dimTy}, {
true});
1066 Type resTy = op.getType();
1068 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1069 noUnwindWillReturnAttrs, op.getOperation());
1070 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1071 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1077 call.setMemoryEffectsAttr(memAttr);
1078 rewriter.replaceOp(op, call);
// Overloads mapping each subgroup-related op type to the base name of the
// builtin it lowers to.
// NOTE(review): the closing brace of the last overload is not visible in this
// extract.
1095static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
1096static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
1097static StringRef getConfig(xevm::SubgroupSizeOp) {
1098 return "get_sub_group_size";
/// Conversion pattern for the subgroup query ops. It emits a call with no
/// arguments to the mangled builtin named by getConfig(op), sets a
/// memory-effects attribute on the call and replaces the op with the call.
/// NOTE(review): the access specifier, the declaration receiving the call,
/// the arguments of the memory-effects attribute and the final return are on
/// lines not visible in this extract.
1100template <
typename OpType>
1101class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
1102 using OpConversionPattern<OpType>::OpConversionPattern;
1104 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1105 ConversionPatternRewriter &rewriter)
const override {
// Mangled with an empty type list.
1106 std::string func = mangle(getConfig(op).str(), {});
1107 Type resTy = op.getType();
1109 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1110 noUnwindWillReturnAttrs, op.getOperation());
1111 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1112 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1118 call.setMemoryEffectsAttr(memAttr);
1119 rewriter.replaceOp(op, call);
/// Conversion pattern that rewrites TruncfOp into calls to device builtins.
/// Only vector sources of exactly 16 elements with a vector destination are
/// accepted. Two paths are visible:
///  - E2M1 destination: the source is bitcast to 8 x i32, four
///    "__builtin_IB_dnscl_{hf16,bf16}" calls are combined pairwise with OR into
///    two i32 values, which are packed into a 2 x i32 vector and bitcast to
///    the destination type.
///  - BF8 / F8 destination: a bf16 source is first widened to f32 through
///    "__builtin_IB_bftof_16" and narrowed to f16 through the mangled
///    "convert_half16"; the f16 vector is then converted with
///    "__builtin_IB_hftobf8_16" or "__builtin_IB_hftohf8_16".
/// Any other combination fails the match.
/// NOTE(review): many lines (access specifier, early validation between the
/// element-type reads and the vector checks, several declarations, operands
/// of some calls, closing braces and return statements) are not visible in
/// this extract.
1124class TruncfToOCLPattern :
public OpConversionPattern<TruncfOp> {
1125 using OpConversionPattern::OpConversionPattern;
1127 matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
1128 ConversionPatternRewriter &rewriter)
const override {
1130 auto srcEtype = op.getSrcEtype().getEtype();
1131 auto dstEtype = op.getDstEtype().getEtype();
1150 auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
1152 return rewriter.notifyMatchFailure(op,
"Scalar src is not supported.");
1154 if (vecSrcTy.getNumElements() != 16)
1155 return rewriter.notifyMatchFailure(
1156 op,
"Only vector src of 16 elements is supported");
1157 auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
1159 return rewriter.notifyMatchFailure(op,
"Scalar dst is not supported.");
1160 Value src = op.getSrc();
// Every builtin called below shares these attributes; all ModRefInfo
// arguments are NoModRef.
1161 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1162 LLVM::ModRefInfo::NoModRef,
1163 LLVM::ModRefInfo::NoModRef,
1164 LLVM::ModRefInfo::NoModRef,
1165 LLVM::ModRefInfo::NoModRef,
1166 LLVM::ModRefInfo::NoModRef,
1167 LLVM::ModRefInfo::NoModRef);
1168 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1169 funcAttrs.memEffectsAttr = memAttr;
1172 if (dstEtype == TruncfDstElemTypes::E2M1) {
1179 Value cast = LLVM::BitcastOp::create(
1180 rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
// The builtin name suffix depends on the source element type.
1183 std::string fnName =
"__builtin_IB_dnscl_";
1184 fnName += (srcEtype == TruncfSrcElemTypes::F16) ?
"hf16" :
"bf16";
// Extracts elements idx0 and idx1 of `input` and passes them, with dstTy and
// mode, to the builtin, which returns an i32.
1185 auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
1186 Value mode) -> Value {
1188 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
1191 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
1194 dstTy.getType(), mode.getType()};
1195 SmallVector<Value> args{arg1, arg2, dstTy, mode};
1196 Value dnscl = createDeviceFunctionCall(
1197 rewriter, fnName, rewriter.getI32Type(), argTypes,
1198 args, {}, funcAttrs, op.getOperation())
1203 Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1204 rewriter.getI32Type(), 0);
1205 Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1206 rewriter.getI32Type(), 1);
1207 Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1208 rewriter.getI32Type(), 2);
1209 Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1210 rewriter.getI32Type(), 3);
// First half: elements (0, 2) with mode 0 and elements (1, 3) with mode 2,
// both with dstTy 1, merged with OR.
1211 Value even = genDnscl(cast, zero, two, one, zero);
1212 Value odd = genDnscl(cast, one, three, one, two);
1213 Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
1214 Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1215 rewriter.getI32Type(), 4);
1216 Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1217 rewriter.getI32Type(), 5);
1218 Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1219 rewriter.getI32Type(), 6);
1220 Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1221 rewriter.getI32Type(), 7);
// Second half: the same scheme on elements (4, 6) and (5, 7).
1222 even = genDnscl(cast, four, six, one, zero);
1223 odd = genDnscl(cast, five, seven, one, two);
1224 Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
// Pack the two halves into a 2 x i32 vector and reinterpret it as the
// destination vector type.
1227 Value combined = LLVM::UndefOp::create(
1228 rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
1229 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1232 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1236 LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
1237 rewriter.replaceOp(op,
result);
1244 if (srcEtype == TruncfSrcElemTypes::BF16) {
// bf16 is handed to the builtin as a vector of i16 with the same shape.
1247 src = LLVM::BitcastOp::create(
1248 rewriter, op.getLoc(),
1249 VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
1250 std::string fnName =
"__builtin_IB_bftof_16";
1251 SmallVector<Type> argTypes{src.
getType()};
1252 SmallVector<Value> args{src};
1253 Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
1254 src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
1255 {}, funcAttrs, op.getOperation())
// Narrow the f32 vector to f16 with the mangled "convert_half16" builtin.
1259 std::string truncFnName =
"convert_half16";
1260 SmallVector<Type> truncArgTypes{src.
getType()};
1261 SmallVector<Value> truncArgs{src};
1262 truncFnName = mangle(truncFnName, truncArgTypes);
1263 resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
1265 createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
1266 truncArgs, {}, funcAttrs, op.getOperation())
1269 if (dstEtype == TruncfDstElemTypes::BF8) {
1271 std::string fnName =
"__builtin_IB_hftobf8_16";
1272 SmallVector<Type> argTypes{src.
getType()};
1273 SmallVector<Value> args{src};
1275 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1276 {}, funcAttrs, op.getOperation())
1279 rewriter.replaceOp(op,
result);
1280 }
else if (dstEtype == TruncfDstElemTypes::F8) {
1282 std::string fnName =
"__builtin_IB_hftohf8_16";
1283 SmallVector<Type> argTypes{src.
getType()};
1284 SmallVector<Value> args{src};
1286 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1287 {}, funcAttrs, op.getOperation())
1290 rewriter.replaceOp(op,
result);
1292 return rewriter.notifyMatchFailure(
1293 op,
"Unsupported src, dst element type pair.");
/// Conversion pattern for MMAMxOp. The visible code repacks the A, B and C
/// operands, builds the name of a `__builtin_IB_sub_group16_bdpas_*` device
/// function from the element types, emits a call to it, and replaces the op
/// with the (possibly bitcast) call result.
/// NOTE(review): several lines of this class are missing from this listing
/// (the embedded line numbers skip); the comments cover only what is visible.
1299class MMAMxToOCLPattern :
public OpConversionPattern<MMAMxOp> {
1300 using OpConversionPattern::OpConversionPattern;
1302 matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
1303 ConversionPatternRewriter &rewriter)
const override {
// Match failure for a missing C operand; the guarding condition is on a
// line not visible here.
1305 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
// The C and D element types reported by op.getTypes() must be equal.
1307 auto precisionC = op.getTypes().getC();
1308 auto precisionD = op.getTypes().getD();
1309 if (precisionC != precisionD) {
1310 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
// Integer widths used to repack A (16 bits) and B (32 bits) when their
// element type is not TF32 (see packedAType / packedBType below).
1313 constexpr uint32_t bitWidthPackedA{16};
1314 constexpr uint32_t bitWidthPackedB{32};
1315 auto loc = op.getLoc();
// Reinterprets `val` as a vector of `packedType` with
// vecBitSize / bitwidth(packedType) elements; a bitcast is emitted only when
// the resulting type differs from the original one.
// NOTE(review): assumes the total bit size is a multiple of the packed
// element width (integer division) - TODO confirm.
1317 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
1318 VectorType origTy = cast<VectorType>(val.
getType());
1319 const uint32_t vecBitSize =
1320 origTy.getNumElements() *
1321 origTy.getElementType().getIntOrFloatBitWidth();
1322 VectorType newTy = VectorType::get(
1323 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
1324 if (origTy != newTy)
1325 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
// A is viewed as f32 when its element type is TF32, otherwise as i16.
1329 Value a = op.getA();
1330 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
1331 ? cast<Type>(rewriter.getF32Type())
1332 : rewriter.getIntegerType(bitWidthPackedA);
1333 a = castIfNeeded(a, packedAType);
// B is viewed as f32 when its element type is TF32, otherwise as i32.
1335 Value
b = op.getB();
1336 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
1337 ? cast<Type>(rewriter.getF32Type())
1338 : rewriter.getIntegerType(bitWidthPackedB);
1339 b = castIfNeeded(
b, packedBType);
// The accumulator type is asserted to equal the first result type of the op.
1341 Value c = op.getC();
1342 VectorType cOrigTy = cast<VectorType>(c.
getType());
1343 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
1344 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
// For a bf16 accumulator an i16 vector of the same shape is selected; the
// start of this declaration and the other ternary branch are not visible.
1347 cOrigTy.getElementType().isBF16()
1348 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
// The call result type is the (possibly repacked) accumulator type.
1350 VectorType resTy = cTy;
1352 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
// The builtin name encodes the D, C, A and B element types, in that order.
1354 std::string fnName =
1355 llvm::formatv(
"__builtin_IB_sub_group16_bdpas_{0}_{1}_{2}_{3}_8_8",
1356 builtinElemType(op.getTypes().getD()),
1357 builtinElemType(op.getTypes().getC()),
1358 builtinElemType(op.getTypes().getA()),
1359 builtinElemType(op.getTypes().getB()))
// Call operands: accumulator, A, B, then the two scale values.
1361 auto scaleA = op.getScaleA();
1362 auto scaleB = op.getScaleB();
1363 SmallVector<Type> argTypes{cTy, a.
getType(),
b.getType(), scaleA.getType(),
1365 SmallVector<Value> args{c, a,
b, scaleA, scaleB};
// Memory-effects attribute built with every ModRefInfo argument set to
// NoModRef.
1367 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1368 LLVM::ModRefInfo::NoModRef,
1369 LLVM::ModRefInfo::NoModRef,
1370 LLVM::ModRefInfo::NoModRef,
1371 LLVM::ModRefInfo::NoModRef,
1372 LLVM::ModRefInfo::NoModRef,
1373 LLVM::ModRefInfo::NoModRef);
// Start from the convergent + nounwind + willreturn preset and attach the
// memory-effects attribute to it.
1374 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1375 funcAttrs.memEffectsAttr = memAttr;
1377 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
1378 funcAttrs, op.getOperation())
// Cast the call result back to the original result type when they differ.
1381 if (resOrigTy != resTy)
1382 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
1384 rewriter.replaceOp(op,
result);
/// Conversion pattern for LLVM::AllocaOp. The visible code creates an
/// internal-linkage LLVM::GlobalOp of array type at the start of the enclosing
/// module body and replaces the alloca with an LLVM::AddressOfOp of that
/// global.
/// NOTE(review): some lines (for example the definition of `cst` and the
/// return statements) are missing from this listing.
1389class AllocaToGlobalPattern :
public OpConversionPattern<LLVM::AllocaOp> {
1390 using OpConversionPattern::OpConversionPattern;
1392 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1393 ConversionPatternRewriter &rewriter)
const override {
// Address space of the alloca result pointer; reused for the global below.
1394 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1395 auto addrSpace = ptrType.getAddressSpace();
// Closest enclosing operation that carries the SymbolTable trait.
1398 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
// The global is placed in the body of a builtin ModuleOp or of a
// gpu::GPUModuleOp; handling of other parents is not visible here.
1402 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1403 moduleBody = mod.getBody();
1404 }
else if (gpu::GPUModuleOp gpuMod =
1405 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1406 moduleBody = gpuMod.getBody();
// Array-size operand of the alloca. `cst`, used below, is defined on lines
// not visible here - presumably the constant value of this operand (TODO
// confirm).
1410 auto val = op.getArraySize();
1414 auto loc = op.getLoc();
// Type of the global: an array of the alloca element type with
// cst.getZExtValue() elements.
1415 auto globalType = LLVM::LLVMArrayType::get(
1416 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1417 LLVM::GlobalOp globalVar;
// The guard restores the rewriter insertion point when it is destroyed.
1419 OpBuilder::InsertionGuard guard(rewriter);
1420 rewriter.setInsertionPointToStart(moduleBody);
1421 auto alignment = op.getAlignment();
// The symbol name is the prefix followed by the next counter value; the
// alignment falls back to 0 when the alloca specifies none.
1422 globalVar = LLVM::GlobalOp::create(
1423 rewriter, loc, globalType,
false,
1424 LLVM::Linkage::Internal,
1425 std::string(
"__global_alloca_") +
1426 std::to_string(getNextGlobalIdx()),
1428 alignment ? *alignment : 0, addrSpace);
1430 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
// Function-local static counter backing the numbering of generated globals.
// NOTE(review): a plain static is not synchronized - confirm this pattern is
// never applied from several threads at once.
1435 static unsigned getNextGlobalIdx() {
1436 static unsigned globalIdx = 0;
/// Predicate over an LLVM::ShuffleVectorOp, used by
/// HandleVectorExtractPattern. Visible checks: both shuffle operands are the
/// same value, the mask is not longer than the source vector, and every mask
/// entry from position 1 onward is compared against firstIndex + i and against
/// the source element count.
/// NOTE(review): the return statements and the definitions of `maskSize` and
/// `index` are on lines missing from this listing. In the visible code the
/// mask entry at position 0 is read without an emptiness check and is not
/// itself compared against sourceSize - TODO confirm against the full source.
1447static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1448 if (op.getV1() != op.getV2())
1450 auto maskAttr = op.getMask();
1452 int64_t sourceSize = op.getV1().getType().getNumElements();
1453 if (maskSize > sourceSize)
// The slice starts at the first mask entry.
1455 int64_t firstIndex = maskAttr[0];
1456 for (
int64_t i = 1; i < maskSize; ++i) {
// Entry i is checked for being exactly firstIndex + i (consecutive indices).
1458 if (
index != firstIndex + i)
// Entry is checked against the number of source elements.
1460 if (
index >= sourceSize)
/// Rewrite pattern on LLVM::ShuffleVectorOp. Once isExtractingContiguousSlice
/// accepts the shuffle, the visible code branches on `srcOp` and moves the
/// slice extraction in front of it. There are four branches: fpext/fptrunc,
/// bitcast, another shufflevector, and load.
/// NOTE(review): `srcOp` is defined on a line missing from this listing
/// (presumably the op defining V1); the return statements are missing too.
1474class HandleVectorExtractPattern
1476 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
// Marks the pattern as having bounded rewrite recursion.
1478 void initialize() { setHasBoundedRewriteRecursion(); }
1480 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1481 PatternRewriter &rewriter)
const override {
1483 if (!isExtractingContiguousSlice(op))
1486 auto mask = op.getMask();
1487 auto loc = op.getLoc();
1488 auto ty = op.getType();
1490 auto src = op.getV1();
// fpext / fptrunc: shuffle the conversion input with the same mask, then
// re-apply the same conversion to produce the original result type `ty`.
1493 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1494 Value srcInput = srcOp->getOperand(0);
// NOTE(review): the dyn_cast result is used without a visible null check.
1496 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
1497 auto newShuffleVecTy =
1498 VectorType::get(mask.size(), srcVecTy.getElementType());
1499 auto newShuffle = LLVM::ShuffleVectorOp::create(
1500 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1503 if (isa<LLVM::FPExtOp>(srcOp)) {
1504 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1506 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1509 }
// bitcast: shuffle the bitcast input, then bitcast the shuffled value to `ty`.
else if (isa<LLVM::BitcastOp>(srcOp)) {
1510 Value srcInput = srcOp->getOperand(0);
// NOTE(review): both dyn_cast results are dereferenced on the next line with
// no visible null check - confirm the bitcast input is always a vector here.
1512 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
1513 auto srcInputSize = srcInputVecTy.getNumElements();
1514 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1515 auto srcResSize = srcResVecTy.getNumElements();
1516 auto maskSize =
static_cast<int32_t
>(mask.size());
// The bodies of the next two checks are not visible; they test for an input
// with more elements than the result and for a result element count that is
// not a multiple of the input element count.
1517 if (srcInputSize > srcResSize) {
1520 if (srcResSize % srcInputSize != 0) {
// Number of bitcast result elements per bitcast input element.
1523 auto maskScale = srcResSize / srcInputSize;
1524 if (maskScale != 1) {
// The slice start is tested for alignment to maskScale.
1525 if (mask[0] % maskScale != 0) {
// Rescaled mask expressed in input elements: maskSize / maskScale
// consecutive indices starting at mask[0] / maskScale.
1529 SmallVector<int32_t> newMask;
1530 int32_t newMaskSize = maskSize / maskScale;
1531 int32_t maskStart = mask[0] / maskScale;
1532 for (int32_t i = 0; i < newMaskSize; ++i) {
1533 newMask.push_back(maskStart + i);
// NOTE(review): the shuffle below is created with `mask`, and its result type
// has srcInputSize elements. Whether `mask` is reassigned from newMask
// happens on lines not visible here, and a shufflevector result normally has
// as many elements as its mask - verify both against the full source.
1537 auto newShuffleVecTy =
1538 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1539 auto newShuffle = LLVM::ShuffleVectorOp::create(
1540 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1543 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1545 }
// shufflevector of a shufflevector: compose the two masks and shuffle the
// inner source directly.
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1550 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1551 if (!isExtractingContiguousSlice(srcShuffle))
1553 auto srcMask = srcShuffle.getMask();
// combinedMask[i] = srcMask[mask[i]].
1554 SmallVector<int32_t> combinedMask;
1555 for (
auto index : mask) {
1556 combinedMask.push_back(srcMask[index]);
1558 auto newShuffle = LLVM::ShuffleVectorOp::create(
1559 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1562 }
// load: load only the extracted slice instead of the whole vector.
else if (isa<LLVM::LoadOp>(srcOp)) {
1564 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1565 auto loadPtr = loadOp.getAddr();
1566 auto loadAddrSpace = loadPtr.getType().getAddressSpace();
// Branch on a non-zero address space; its body is not visible here.
1567 if (loadAddrSpace != 0)
// NOTE(review): the dyn_cast result is used without a visible null check.
1569 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1570 auto elemTy = loadTy.getElementType();
1571 auto firstIndex = mask[0];
1572 auto newVecTy = VectorType::get(mask.size(), elemTy);
// Pointer advanced by firstIndex elements of elemTy, then a load of the
// narrower vector type from it.
1575 auto newPtr = LLVM::GEPOp::create(
1577 LLVM::LLVMPointerType::get(rewriter.
getContext(), loadAddrSpace),
1578 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1579 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
// Alternative path loading straight from the original pointer; the condition
// selecting between the two loads is not visible here.
1582 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
/// Pass struct for the XeVM-to-LLVM conversion. The visible code runs a
/// partial dialect conversion on the pass operation, then prepares a second
/// pattern set containing HandleVectorExtractPattern together with a greedy
/// rewrite configuration.
/// NOTE(review): the base class, the setup of `target` / `patterns` and the
/// use of `vectorPatterns` / `config` are on lines missing from this listing.
1600struct ConvertXeVMToLLVMPass
// Declares the dialects this pass depends on: LLVM and XeVM.
1604 void getDependentDialects(DialectRegistry &registry)
const override {
1605 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
1608 void runOnOperation()
override {
// Stage 1: partial conversion; the pass is marked as failed when it fails.
1612 if (
failed(applyPartialConversion(getOperation(),
target,
1613 std::move(patterns))))
1614 signalPassFailure();
// Stage 2 setup: a separate pattern set holding only
// HandleVectorExtractPattern.
1618 RewritePatternSet vectorPatterns(&
getContext());
1619 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
// Greedy rewrite configuration with folding switched off.
1620 GreedyRewriteConfig config{};
1625 config.enableFolding(
false);
// Body fragment that configures `target` and `patterns` for the XeVM-to-LLVM
// conversion. NOTE(review): the enclosing function header is missing from
// this listing; presumably populateXeVMToLLVMConversionPatterns - confirm.
// LLVM dialect ops are dynamically legal, decided by the callback below.
1642 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](
Operation *op) {
// An alloca is legal unless its result pointer is in address space 3.
1646 if (isa<LLVM::AllocaOp>(op)) {
1647 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1648 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1649 auto addrSpace = pTy.getAddressSpace();
1650 return addrSpace != 3;
// Ops reaching this return are legal only without a cache_control attribute;
// the condition guarding this line is not visible here.
1653 return !op->hasAttr(
"cache_control");
// Every op of the XeVM dialect must be converted.
1655 target.addIllegalDialect<XeVMDialect>();
// Registers the conversion patterns; the constructor arguments follow on
// lines not visible here.
1656 patterns.
add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1657 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1658 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1659 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1660 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1661 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1662 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1663 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1664 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1665 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1666 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1667 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1668 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1669 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1670 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1671 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1672 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1673 LaunchConfigOpToOCLPattern<GridDimXOp>,
1674 LaunchConfigOpToOCLPattern<GridDimYOp>,
1675 LaunchConfigOpToOCLPattern<GridDimZOp>,
1676 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1677 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1678 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1679 TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...