18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
38struct LLVMFuncAttributeOptions {
39 bool isConvergent =
false;
40 bool isNoUnwind =
false;
41 bool isWillReturn =
false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false,
true,
false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false,
true,
true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true,
true,
true, {}};
51std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
57 .Case([](Float16Type) -> std::string {
return "Dh"; })
58 .Case([](Float32Type) -> std::string {
return "f"; })
59 .Case([](Float64Type) -> std::string {
return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
63 return isUnsigned ?
"h" :
"c";
65 return isUnsigned ?
"t" :
"s";
67 return isUnsigned ?
"j" :
"i";
69 return isUnsigned ?
"m" :
"l";
71 llvm_unreachable(
"unhandled integer type");
74 .DefaultUnreachable(
"unhandled type for mangling");
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os <<
"_Z" << baseName.size() << baseName;
85 for (
auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
90 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
102std::string builtinElemType(ElemType elemType) {
115 return stringifyElemType(elemType).str();
119static int32_t getL1CacheControl(LoadCacheControl cc) {
122 case LoadCacheControl::USE_DEFAULT:
125 case LoadCacheControl::L1C_L2UC_L3UC:
126 case LoadCacheControl::L1C_L2UC_L3C:
127 case LoadCacheControl::L1C_L2C_L3UC:
128 case LoadCacheControl::L1C_L2C_L3C:
131 case LoadCacheControl::L1S_L2UC_L3UC:
132 case LoadCacheControl::L1S_L2UC_L3C:
133 case LoadCacheControl::L1S_L2C_L3UC:
134 case LoadCacheControl::L1S_L2C_L3C:
137 case LoadCacheControl::INVALIDATE_READ:
146static int32_t getL1CacheControl(StoreCacheControl cc) {
149 case StoreCacheControl::USE_DEFAULT:
152 case StoreCacheControl::L1WT_L2UC_L3UC:
153 case StoreCacheControl::L1WT_L2UC_L3WB:
154 case StoreCacheControl::L1WT_L2WB_L3UC:
155 case StoreCacheControl::L1WT_L2WB_L3WB:
158 case StoreCacheControl::L1WB_L2UC_L3UC:
159 case StoreCacheControl::L1WB_L2WB_L3UC:
160 case StoreCacheControl::L1WB_L2UC_L3WB:
163 case StoreCacheControl::L1S_L2UC_L3UC:
164 case StoreCacheControl::L1S_L2UC_L3WB:
165 case StoreCacheControl::L1S_L2WB_L3UC:
166 case StoreCacheControl::L1S_L2WB_L3WB:
175static int32_t getL3CacheControl(LoadCacheControl cc) {
178 case LoadCacheControl::USE_DEFAULT:
181 case LoadCacheControl::L1UC_L2UC_L3C:
182 case LoadCacheControl::L1UC_L2C_L3C:
183 case LoadCacheControl::L1C_L2UC_L3C:
184 case LoadCacheControl::L1C_L2C_L3C:
185 case LoadCacheControl::L1S_L2UC_L3C:
186 case LoadCacheControl::L1S_L2C_L3C:
189 case LoadCacheControl::INVALIDATE_READ:
198static int32_t getL3CacheControl(StoreCacheControl cc) {
201 case StoreCacheControl::USE_DEFAULT:
204 case StoreCacheControl::L1UC_L2UC_L3WB:
205 case StoreCacheControl::L1UC_L2WB_L3WB:
206 case StoreCacheControl::L1WT_L2UC_L3WB:
207 case StoreCacheControl::L1WT_L2WB_L3WB:
208 case StoreCacheControl::L1S_L2UC_L3WB:
209 case StoreCacheControl::L1S_L2WB_L3WB:
210 case StoreCacheControl::L1WB_L2UC_L3WB:
219static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
220 return op.getCacheControl();
223static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
224 return op.getCacheControl();
227static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
228 return op.getCacheControl();
231static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
232 return op.getCacheControl();
235static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
236 return op.getCacheControl();
239static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
240 return op.getCacheControl();
243static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
244 if (op->hasAttr(
"cache_control")) {
245 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
248 return std::optional<LoadCacheControl>(attr.getValue());
253static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
254 if (op->hasAttr(
"cache_control")) {
255 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
258 return std::optional<StoreCacheControl>(attr.getValue());
263template <
typename OpType>
264int32_t getL1CacheControl(OpType op) {
265 return getL1CacheControl(*getCacheControl(op));
268template <
typename OpType>
269int32_t getL3CacheControl(OpType op) {
270 return getL3CacheControl(*getCacheControl(op));
273template <
typename OpType>
274static std::optional<ArrayAttr>
275getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
276 if (!getCacheControl(op))
279 constexpr int32_t decorationCacheControlArity{3};
280 constexpr int32_t loadCacheControlKey{6442};
281 constexpr int32_t storeCacheControlKey{6443};
282 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
283 std::is_same_v<OpType, BlockPrefetch2dOp> ||
284 std::is_same_v<OpType, LLVM::LoadOp> ||
285 std::is_same_v<OpType, BlockLoadOp> ||
286 std::is_same_v<OpType, PrefetchOp>;
292 assert(((getL1CacheControl<OpType>(op) == -1) ==
293 (getL3CacheControl<OpType>(op) == -1)) &&
294 "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
297 if (getL1CacheControl<OpType>(op) == -1 &&
298 getL3CacheControl<OpType>(op) == -1)
300 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
302 controlKey, 0, getL1CacheControl<OpType>(op)};
304 controlKey, 1, getL3CacheControl<OpType>(op)};
305 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
306 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
309 return rewriter.getArrayAttr(combinedAttrs);
331 llvm::StringMap<bool> seen;
334 auto arr = dyn_cast<ArrayAttr>(a);
338 auto vals = arr.getValue();
339 assert(vals.size() == 3 &&
340 "Expected exactly 3 integer values (Token, CacheLevel, "
341 "ControlValue) in cache control attribute.");
343 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
344 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
345 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
347 if (!tokenAttr || !secondAttr || !thirdAttr)
353 llvm::formatv(
"{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
354 secondAttr.getValue().getZExtValue(),
355 thirdAttr.getValue().getZExtValue());
358 if (!seen.insert({entry, true}).second)
361 payloads.push_back(std::move(entry));
366static std::atomic<uint64_t> globalNameCounter{0};
371static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
373 StringRef value, StringRef nameHint) {
375 std::string strWithNull = value.str();
376 strWithNull.push_back(
'\0');
377 StringRef strRef(strWithNull.data(), strWithNull.size());
379 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
383 if (
auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
384 if (!existingGlobal.getSection() ||
385 *existingGlobal.getSection() !=
"llvm.metadata")
388 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
389 if (strAttr.getValue() == strRef) {
390 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
391 existingGlobal.getSymName());
398 auto i8Type = rewriter.getI8Type();
399 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
400 std::string globalName =
401 llvm::formatv(
"{0}.{1}", nameHint,
402 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
407 rewriter.setInsertionPointToStart(&moduleOp->
getRegion(0).
front());
410 LLVM::GlobalOp::create(rewriter, loc, arrayType,
411 true, LLVM::Linkage::Private,
412 globalName, rewriter.getStringAttr(strRef));
413 globalOp.setSection(StringRef(
"llvm.metadata"));
414 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
415 globalOp.setAlignment(1);
416 globalOp.setAddrSpace(1);
420 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
443static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
448 buildCacheControlPayloads(cacheControls.getValue());
449 if (payloads.empty())
452 auto ptrType = cast<LLVM::LLVMPointerType>(
ptr.getType());
453 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
454 auto i32Ty = rewriter.getI32Type();
458 createMetadataStringPtr(rewriter, moduleOp, loc,
"",
".str.file");
459 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
460 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
465 for (
const std::string &payload : payloads) {
466 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
467 ".str.cachecontrol");
468 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
469 annStr, fileStr, lineVal, nullAS1);
470 curPtr = annOp.getResult();
492template <
typename OpType>
494applyCacheControlAnnotation(ConversionPatternRewriter &rewriter,
Location loc,
496 Operation *moduleOp,
unsigned ptrIdx = 0) {
497 std::optional<ArrayAttr> optCacheControls =
498 getCacheControlMetadata(rewriter, op);
499 if (!optCacheControls)
502 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
503 *optCacheControls, moduleOp);
504 args[ptrIdx] = annotatedPtr;
511static LLVM::CallOp createDeviceFunctionCall(
512 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
515 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
517 assert(moduleOp &&
"Expecting module");
522 assert(!
failed(funcOpRes));
523 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
524 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
525 funcOp.setConvergent(funcAttributeOptions.isConvergent);
526 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
527 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
529 if (funcAttributeOptions.memEffectsAttr)
530 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
532 for (
auto [idx, attrName] : paramAttrs)
533 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
535 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
536 callOp->setAttrs(funcOp->getAttrs());
541static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
543 case xevm::ElemType::F32:
544 case xevm::ElemType::TF32:
546 case xevm::ElemType::BF16:
547 case xevm::ElemType::F16:
549 case xevm::ElemType::U8:
550 case xevm::ElemType::S8:
551 case xevm::ElemType::BF8:
552 case xevm::ElemType::F8:
554 case xevm::ElemType::E2M1:
555 case xevm::ElemType::U4:
556 case xevm::ElemType::S4:
559 llvm_unreachable(
"unsupported xevm::ElemType");
563class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
564 using OpConversionPattern::OpConversionPattern;
566 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
567 ConversionPatternRewriter &rewriter)
const override {
569 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
571 auto precisionA = op.getTypes().getA();
572 auto precisionB = op.getTypes().getB();
573 auto precisionC = op.getTypes().getC();
574 auto precisionD = op.getTypes().getD();
575 if (precisionC != precisionD) {
576 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
578 if (precisionC != xevm::ElemType::S32 &&
579 precisionC != xevm::ElemType::F32 &&
580 precisionC != xevm::ElemType::F16 &&
581 precisionC != xevm::ElemType::BF16) {
582 return rewriter.notifyMatchFailure(
583 op,
"type of C and D must be S32, F32, F16 or BF16");
585 if (precisionA == xevm::ElemType::S32 ||
586 precisionA == xevm::ElemType::F32) {
587 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
589 if (precisionB == xevm::ElemType::S32 ||
590 precisionB == xevm::ElemType::F32) {
591 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
593 constexpr uint32_t bitWidthPackedA{16};
594 constexpr uint32_t bitWidthPackedB{32};
595 auto loc = op.getLoc();
597 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
598 VectorType origTy = cast<VectorType>(val.
getType());
599 const uint32_t vecBitSize =
600 origTy.getNumElements() *
601 origTy.getElementType().getIntOrFloatBitWidth();
602 VectorType newTy = VectorType::get(
603 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
605 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
610 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
611 ? cast<Type>(rewriter.getF32Type())
612 : rewriter.getIntegerType(bitWidthPackedA);
613 a = castIfNeeded(a, packedAType);
616 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
617 ? cast<Type>(rewriter.getF32Type())
618 : rewriter.getIntegerType(bitWidthPackedB);
619 b = castIfNeeded(
b, packedBType);
622 VectorType cOrigTy = cast<VectorType>(c.
getType());
623 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
624 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
627 cOrigTy.getElementType().isBF16()
628 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
630 VectorType resTy = cTy;
632 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
634 constexpr int32_t systolicDepth{8};
636 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
637 stringifyElemType(op.getTypes().getA()).str(),
638 stringifyElemType(op.getTypes().getB()).str(),
640 getNumOperandsPerDword(op.getTypes().getA()))
642 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
643 fnName = mangle(fnName, argTypes);
644 SmallVector<Value> args{a,
b, c};
646 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
647 LLVM::ModRefInfo::NoModRef,
648 LLVM::ModRefInfo::NoModRef,
649 LLVM::ModRefInfo::NoModRef,
650 LLVM::ModRefInfo::NoModRef,
651 LLVM::ModRefInfo::NoModRef,
652 LLVM::ModRefInfo::NoModRef);
653 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
654 funcAttrs.memEffectsAttr = memAttr;
656 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
657 funcAttrs, op.getOperation())
660 if (resOrigTy != resTy)
661 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
663 rewriter.replaceOp(op,
result);
668class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
669 using OpConversionPattern::OpConversionPattern;
671 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
672 ConversionPatternRewriter &rewriter)
const override {
673 auto loc = op.getLoc();
676 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
678 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
679 SmallVector<Value> args{op.getPtr(), one};
682 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
685 SmallVector<Type> argTypes;
686 for (
auto arg : args)
687 argTypes.push_back(arg.getType());
688 auto funcAttr = noUnwindAttrs;
689 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
690 LLVM::ModRefInfo::NoModRef,
691 LLVM::ModRefInfo::Ref,
692 LLVM::ModRefInfo::NoModRef,
693 LLVM::ModRefInfo::NoModRef,
694 LLVM::ModRefInfo::NoModRef,
695 LLVM::ModRefInfo::NoModRef);
696 funcAttr.memEffectsAttr = memAttr;
698 createDeviceFunctionCall(rewriter, fnName,
699 LLVM::LLVMVoidType::get(rewriter.getContext()),
700 argTypes, args, {}, funcAttr, op.getOperation());
701 rewriter.eraseOp(op);
706class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
707 using OpConversionPattern::OpConversionPattern;
709 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter)
const override {
711 auto loc = op.getLoc();
712 const std::string fnName{
"atomic_work_item_fence"};
713 int memScope, addrSpace;
714 switch (op.getAddrspace()) {
715 case xevm::AddrSpace::SHARED:
718 case xevm::AddrSpace::GLOBAL:
723 return rewriter.notifyMatchFailure(
724 op,
"Fence only supports global and shared address spaces.");
726 switch (op.getScope()) {
727 case xevm::MemScope::WORKGROUP:
730 case xevm::MemScope::DEVICE:
735 return rewriter.notifyMatchFailure(
736 op,
"Fence only supports workgroup and device memory scopes.");
738 Type i32Type = rewriter.getI32Type();
739 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
740 Value memScopeConst =
741 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
742 Value addrSpaceConst =
743 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
744 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
745 SmallVector<Type> argTypes{3, i32Type};
746 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
747 LLVM::LLVMVoidType::get(rewriter.getContext()),
748 argTypes, args, {}, noUnwindAttrs,
750 rewriter.eraseOp(op);
754template <
typename OpType>
755class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
756 using OpConversionPattern<OpType>::OpConversionPattern;
758 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
759 ConversionPatternRewriter &rewriter)
const override {
760 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
761 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
763 auto loc = op.getLoc();
764 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
766 bool packReg =
false;
767 bool transpose =
false;
768 if constexpr (isLoad) {
769 vecType = op.getRes().getType();
770 packReg = op.getPackRegister();
771 transpose = op.getTranspose();
772 }
else if constexpr (!isPrefetch) {
773 vecType = op.getStoredVal().getType();
776 auto i32Type = rewriter.getI32Type();
778 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
779 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
780 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
781 byteCoord = LLVM::InsertElementOp::create(
782 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
783 byteCoord = LLVM::InsertElementOp::create(
784 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
785 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
786 op.getBasePitch(), byteCoord};
789 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
792 SmallVector<Type> retTypes;
794 std::string funcName{
"intel_sub_group_2d_block_"};
795 std::string bitWidthId;
796 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
797 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
798 if constexpr (isPrefetch) {
799 funcName +=
"prefetch";
800 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
801 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
802 LLVM::ModRefInfo::NoModRef,
803 LLVM::ModRefInfo::Ref,
804 LLVM::ModRefInfo::NoModRef,
805 LLVM::ModRefInfo::NoModRef,
806 LLVM::ModRefInfo::NoModRef,
807 LLVM::ModRefInfo::NoModRef);
808 funcAttr = noUnwindAttrs;
809 funcAttr.memEffectsAttr = memAttr;
811 auto vecElemType = vecType.getElementType();
812 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
813 auto vecNumElems = vecType.getNumElements();
819 if (op.getElemSizeInBits() == 8 && op.getTileWidth() == 32) {
820 vecElemBitWidth = 16;
821 vecElemType = rewriter.getI16Type();
822 vecNumElems = vecNumElems / 2;
825 LLVM::ConstantOp::create(rewriter, loc, i32Type, vecNumElems);
826 auto dstOrSrcPtr = LLVM::AllocaOp::create(
827 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
828 vecElemType, numElems);
829 args.push_back(dstOrSrcPtr);
830 if constexpr (isLoad) {
832 bitWidthId = getTypeMangling(vecElemType,
true);
834 funcName +=
"_transform";
836 funcName +=
"_transpose";
837 spvLoadDstPtr = dstOrSrcPtr;
838 retTypes.push_back(vecType);
840 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
841 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
842 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
843 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
847 bitWidthId = (vecElemBitWidth == 32)
849 : ((vecElemBitWidth == 16) ?
"t" :
"h");
850 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
852 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
853 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
854 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
855 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
861 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
862 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
864 std::string prefetchCode(
"");
867 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
868 funcName, prefetchCode, bitWidthId)
870 SmallVector<Type> argTypes;
871 for (
auto arg : args) {
872 argTypes.push_back(arg.getType());
874 createDeviceFunctionCall(
875 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
876 argTypes, args, paramAttrs, funcAttr, op.getOperation());
878 if constexpr (isLoad)
880 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
882 rewriter.eraseOp(op);
887template <
typename OpType>
888class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
889 using OpConversionPattern<OpType>::OpConversionPattern;
891 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
892 ConversionPatternRewriter &rewriter)
const override {
893 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
894 auto loc = op.getLoc();
895 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
900 std::string funcName{
"intel_sub_group_block_"};
903 if constexpr (isStore) {
904 funcName +=
"write_u";
905 valOrResTy = op.getVal().getType();
907 funcName +=
"read_u";
908 valOrResTy = op.getType();
911 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
912 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
913 funcName += getTypeMangling(elemType);
915 funcName += std::to_string(vecTy.getNumElements());
916 SmallVector<Type, 2> argTypes{};
920 SmallVector<bool, 2> isUnsigned{};
924 SmallVector<Value, 2> args{};
925 args.push_back(op.getPtr());
926 argTypes.push_back(op.getPtr().getType());
927 isUnsigned.push_back(
true);
930 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
934 argTypes[0] = args[0].getType();
937 if constexpr (isStore) {
938 args.push_back(op.getVal());
939 argTypes.push_back(op.getVal().getType());
940 isUnsigned.push_back(
true);
941 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
943 retType = valOrResTy;
945 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
947 std::to_string(op.getPtr().getType().getAddressSpace());
948 funcName += getTypeMangling(elemType,
true);
949 if constexpr (isStore)
950 funcName += getTypeMangling(valOrResTy,
true);
951 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
954 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
955 {}, funcAttr, op.getOperation());
957 if constexpr (isStore)
958 rewriter.eraseOp(op);
960 rewriter.replaceOp(op, call->getResult(0));
965template <
typename OpType>
966class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
967 using OpConversionPattern<OpType>::OpConversionPattern;
969 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
970 ConversionPatternRewriter &rewriter)
const override {
971 if (!op->hasAttr(
"cache_control"))
974 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
975 std::optional<ArrayAttr> optCacheControls =
976 getCacheControlMetadata(rewriter, op);
977 if (!optCacheControls) {
978 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr(
"cache_control"); });
983 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
984 unsigned ptrIdx = isStore ? 1 : 0;
985 Value ptr = op->getOperand(ptrIdx);
988 Value annotatedPtr = annotatePtrWithCacheControl(
989 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
992 rewriter.modifyOpInPlace(op, [&]() {
993 op->setOperand(ptrIdx, annotatedPtr);
994 op->removeAttr(
"cache_control");
1027static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
1028 return {
"get_local_id", 0};
1030static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
1031 return {
"get_local_id", 1};
1033static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
1034 return {
"get_local_id", 2};
1036static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
1037 return {
"get_local_size", 0};
1039static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
1040 return {
"get_local_size", 1};
1042static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
1043 return {
"get_local_size", 2};
1045static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
1046 return {
"get_group_id", 0};
1048static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
1049 return {
"get_group_id", 1};
1051static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
1052 return {
"get_group_id", 2};
1054static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
1055 return {
"get_num_groups", 0};
1057static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
1058 return {
"get_num_groups", 1};
1060static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1061 return {
"get_num_groups", 2};
1065template <
typename OpType>
1066class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
1067 using OpConversionPattern<OpType>::OpConversionPattern;
1069 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1070 ConversionPatternRewriter &rewriter)
const override {
1071 Location loc = op->getLoc();
1072 auto [baseName, dim] = getConfig(op);
1073 Type dimTy = rewriter.getI32Type();
1074 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1075 static_cast<int64_t
>(dim));
1076 std::string func = mangle(baseName, {dimTy}, {
true});
1077 Type resTy = op.getType();
1079 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1080 noUnwindWillReturnAttrs, op.getOperation());
1081 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1082 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1088 call.setMemoryEffectsAttr(memAttr);
1089 rewriter.replaceOp(op, call);
1106static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
1107static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
1108static StringRef getConfig(xevm::SubgroupSizeOp) {
1109 return "get_sub_group_size";
1111template <
typename OpType>
1112class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
1113 using OpConversionPattern<OpType>::OpConversionPattern;
1115 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
1116 ConversionPatternRewriter &rewriter)
const override {
1117 std::string func = mangle(getConfig(op).str(), {});
1118 Type resTy = op.getType();
1120 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1121 noUnwindWillReturnAttrs, op.getOperation());
1122 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1123 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1129 call.setMemoryEffectsAttr(memAttr);
1130 rewriter.replaceOp(op, call);
1135class TruncfToOCLPattern :
public OpConversionPattern<TruncfOp> {
1136 using OpConversionPattern::OpConversionPattern;
1138 matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
1139 ConversionPatternRewriter &rewriter)
const override {
1141 auto srcEtype = op.getSrcEtype().getEtype();
1142 auto dstEtype = op.getDstEtype().getEtype();
1161 auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
1163 return rewriter.notifyMatchFailure(op,
"Scalar src is not supported.");
1165 if (vecSrcTy.getNumElements() != 16)
1166 return rewriter.notifyMatchFailure(
1167 op,
"Only vector src of 16 elements is supported");
1168 auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
1170 return rewriter.notifyMatchFailure(op,
"Scalar dst is not supported.");
1171 Value src = op.getSrc();
1172 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1173 LLVM::ModRefInfo::NoModRef,
1174 LLVM::ModRefInfo::NoModRef,
1175 LLVM::ModRefInfo::NoModRef,
1176 LLVM::ModRefInfo::NoModRef,
1177 LLVM::ModRefInfo::NoModRef,
1178 LLVM::ModRefInfo::NoModRef);
1179 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1180 funcAttrs.memEffectsAttr = memAttr;
1183 if (dstEtype == TruncfDstElemTypes::E2M1) {
1190 Value cast = LLVM::BitcastOp::create(
1191 rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
1194 std::string fnName =
"__builtin_IB_dnscl_";
1195 fnName += (srcEtype == TruncfSrcElemTypes::F16) ?
"hf16" :
"bf16";
1196 auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
1197 Value mode) -> Value {
1199 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
1202 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
1205 dstTy.getType(), mode.getType()};
1206 SmallVector<Value> args{arg1, arg2, dstTy, mode};
1207 Value dnscl = createDeviceFunctionCall(
1208 rewriter, fnName, rewriter.getI32Type(), argTypes,
1209 args, {}, funcAttrs, op.getOperation())
1214 Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1215 rewriter.getI32Type(), 0);
1216 Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1217 rewriter.getI32Type(), 1);
1218 Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1219 rewriter.getI32Type(), 2);
1220 Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1221 rewriter.getI32Type(), 3);
1222 Value even = genDnscl(cast, zero, two, one, zero);
1223 Value odd = genDnscl(cast, one, three, one, two);
1224 Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
1225 Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1226 rewriter.getI32Type(), 4);
1227 Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1228 rewriter.getI32Type(), 5);
1229 Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1230 rewriter.getI32Type(), 6);
1231 Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1232 rewriter.getI32Type(), 7);
1233 even = genDnscl(cast, four, six, one, zero);
1234 odd = genDnscl(cast, five, seven, one, two);
1235 Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
1238 Value combined = LLVM::UndefOp::create(
1239 rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
1240 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1243 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1247 LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
1248 rewriter.replaceOp(op,
result);
1255 if (srcEtype == TruncfSrcElemTypes::BF16) {
1258 src = LLVM::BitcastOp::create(
1259 rewriter, op.getLoc(),
1260 VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
1261 std::string fnName =
"__builtin_IB_bftof_16";
1262 SmallVector<Type> argTypes{src.
getType()};
1263 SmallVector<Value> args{src};
1264 Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
1265 src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
1266 {}, funcAttrs, op.getOperation())
1270 std::string truncFnName =
"convert_half16";
1271 SmallVector<Type> truncArgTypes{src.
getType()};
1272 SmallVector<Value> truncArgs{src};
1273 truncFnName = mangle(truncFnName, truncArgTypes);
1274 resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
1276 createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
1277 truncArgs, {}, funcAttrs, op.getOperation())
1280 if (dstEtype == TruncfDstElemTypes::BF8) {
1282 std::string fnName =
"__builtin_IB_hftobf8_16";
1283 SmallVector<Type> argTypes{src.
getType()};
1284 SmallVector<Value> args{src};
1286 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1287 {}, funcAttrs, op.getOperation())
1290 rewriter.replaceOp(op,
result);
1291 }
else if (dstEtype == TruncfDstElemTypes::F8) {
1293 std::string fnName =
"__builtin_IB_hftohf8_16";
1294 SmallVector<Type> argTypes{src.
getType()};
1295 SmallVector<Value> args{src};
1297 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1298 {}, funcAttrs, op.getOperation())
1301 rewriter.replaceOp(op,
result);
1303 return rewriter.notifyMatchFailure(
1304 op,
"Unsupported src, dst element type pair.");
// Conversion pattern for MMAMxOp: rewrites the op into a call to a
// `__builtin_IB_sub_group16_bdpas_*` device function (operands: c, a, b,
// scaleA, scaleB) and replaces the op with the call's result.
// NOTE(review): this listing has gaps (the embedded line numbers skip), so
// guards, returns and closing braces that are not shown are not described.
1310class MMAMxToOCLPattern :
public OpConversionPattern<MMAMxOp> {
1311 using OpConversionPattern::OpConversionPattern;
1313 matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
1314 ConversionPatternRewriter &rewriter)
const override {
// Match failure; the guarding condition is not visible here (presumably a
// missing C operand, going by the message).
1316 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
// The element types declared for C (accumulator) and D (result) must agree.
1318 auto precisionC = op.getTypes().getC();
1319 auto precisionD = op.getTypes().getD();
1320 if (precisionC != precisionD) {
1321 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
// Integer element widths used below when repacking the A and B operands.
1324 constexpr uint32_t bitWidthPackedA{16};
1325 constexpr uint32_t bitWidthPackedB{32};
1326 auto loc = op.getLoc();
// Reinterprets `val` as a vector of `packedType` spanning the same total
// number of bits; a bitcast is emitted only when the vector type changes.
1328 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
1329 VectorType origTy = cast<VectorType>(val.
getType());
1330 const uint32_t vecBitSize =
1331 origTy.getNumElements() *
1332 origTy.getElementType().getIntOrFloatBitWidth();
1333 VectorType newTy = VectorType::get(
1334 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
1335 if (origTy != newTy)
1336 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
// A is packed as f32 when its declared element type is TF32, otherwise as
// 16-bit integers.
1340 Value a = op.getA();
1341 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
1342 ? cast<Type>(rewriter.getF32Type())
1343 : rewriter.getIntegerType(bitWidthPackedA);
1344 a = castIfNeeded(a, packedAType);
// B is packed as f32 when its declared element type is TF32, otherwise as
// 32-bit integers.
1346 Value
b = op.getB();
1347 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
1348 ? cast<Type>(rewriter.getF32Type())
1349 : rewriter.getIntegerType(bitWidthPackedB);
1350 b = castIfNeeded(
b, packedBType);
1352 Value c = op.getC();
1353 VectorType cOrigTy = cast<VectorType>(c.
getType());
1354 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
1355 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
// A bf16 accumulator is given to the builtin as a same-shaped i16 vector.
// The other arm of this conditional is not visible in this excerpt.
1358 cOrigTy.getElementType().isBF16()
1359 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
1361 VectorType resTy = cTy;
1363 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
// The builtin name encodes the D, C, A and B element types, in that order.
1365 std::string fnName =
1366 llvm::formatv(
"__builtin_IB_sub_group16_bdpas_{0}_{1}_{2}_{3}_8_8",
1367 builtinElemType(op.getTypes().getD()),
1368 builtinElemType(op.getTypes().getC()),
1369 builtinElemType(op.getTypes().getA()),
1370 builtinElemType(op.getTypes().getB()))
1372 auto scaleA = op.getScaleA();
1373 auto scaleB = op.getScaleB();
1374 SmallVector<Type> argTypes{cTy, a.
getType(),
b.getType(), scaleA.getType(),
1376 SmallVector<Value> args{c, a,
b, scaleA, scaleB};
// Memory-effects attribute with every ModRefInfo argument set to NoModRef.
1378 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1379 LLVM::ModRefInfo::NoModRef,
1380 LLVM::ModRefInfo::NoModRef,
1381 LLVM::ModRefInfo::NoModRef,
1382 LLVM::ModRefInfo::NoModRef,
1383 LLVM::ModRefInfo::NoModRef,
1384 LLVM::ModRefInfo::NoModRef);
// Start from the convergent + nounwind + willreturn preset, then attach the
// memory effects built above.
1385 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1386 funcAttrs.memEffectsAttr = memAttr;
1388 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
1389 funcAttrs, op.getOperation())
// Cast the call result back to the op's declared result type if the call
// used a different vector type.
1392 if (resOrigTy != resTy)
1393 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
1395 rewriter.replaceOp(op,
result);
// Conversion pattern for LLVM::AllocaOp: creates an internal-linkage
// LLVM::GlobalOp of array type at the start of the enclosing symbol-table
// op's body and replaces the alloca with an LLVM::AddressOfOp of that global.
// NOTE(review): this listing has gaps (the embedded line numbers skip), so
// guards and returns that are not shown are not described.
1400class AllocaToGlobalPattern :
public OpConversionPattern<LLVM::AllocaOp> {
1401 using OpConversionPattern::OpConversionPattern;
1403 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1404 ConversionPatternRewriter &rewriter)
const override {
// The global is created in the same address space as the alloca's pointer.
1405 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1406 auto addrSpace = ptrType.getAddressSpace();
// Closest enclosing op with the SymbolTable trait; the lines shown handle
// ModuleOp and gpu::GPUModuleOp bodies.
1409 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1413 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1414 moduleBody = mod.getBody();
1415 }
else if (gpu::GPUModuleOp gpuMod =
1416 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1417 moduleBody = gpuMod.getBody();
// `cst` is defined on a line not shown; presumably it is the constant
// matched from the alloca's array size. It gives the global's element count.
1421 auto val = op.getArraySize();
1425 auto loc = op.getLoc();
1426 auto globalType = LLVM::LLVMArrayType::get(
1427 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1428 LLVM::GlobalOp globalVar;
// The guard restores the rewriter's insertion point after the global has
// been created at the start of the module body.
1430 OpBuilder::InsertionGuard guard(rewriter);
1431 rewriter.setInsertionPointToStart(moduleBody);
1432 auto alignment = op.getAlignment();
// Symbol name is "__global_alloca_" followed by the next counter value;
// alignment falls back to 0 when the alloca does not specify one.
1433 globalVar = LLVM::GlobalOp::create(
1434 rewriter, loc, globalType,
false,
1435 LLVM::Linkage::Internal,
1436 std::string(
"__global_alloca_") +
1437 std::to_string(getNextGlobalIdx()),
1439 alignment ? *alignment : 0, addrSpace);
1441 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
// Supplies the numeric suffix for generated global names from a
// function-local static; the lines that update and return it are not shown.
// NOTE(review): a plain static counter is unsynchronized and persists for
// the whole process -- confirm this pattern is never applied concurrently.
1446 static unsigned getNextGlobalIdx() {
1447 static unsigned globalIdx = 0;
// Predicate over a shufflevector whose two operands are the same value:
// checks that the mask picks consecutive elements of the source vector.
// Only the conditions are visible in this excerpt; the statements they guard
// (presumably `return false`) and the definitions of `maskSize` and `index`
// sit on lines that are not shown.
1458static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
// Both shuffle operands must be identical.
1459 if (op.getV1() != op.getV2())
1461 auto maskAttr = op.getMask();
// The mask may not have more entries than the source vector has elements.
1463 int64_t sourceSize = op.getV1().getType().getNumElements();
1464 if (maskSize > sourceSize)
// Starting from the second entry, each `index` (presumably maskAttr[i]) is
// compared against firstIndex + i and against the source vector size.
1466 int64_t firstIndex = maskAttr[0];
1467 for (
int64_t i = 1; i < maskSize; ++i) {
1469 if (
index != firstIndex + i)
1471 if (
index >= sourceSize)
// Rewrite pattern on LLVM::ShuffleVectorOp for shuffles accepted by
// isExtractingContiguousSlice. Depending on the op that defines the shuffled
// vector (`srcOp`), the slice extraction is moved ahead of that op:
//  - FPExt / FPTrunc: shuffle the conversion's input, then convert the slice.
//  - Bitcast: shuffle the bitcast's input, then bitcast to the result type.
//  - ShuffleVector: compose both masks into a single shuffle.
//  - Load: emit a load of a mask-sized vector.
// NOTE(review): this listing has gaps (the embedded line numbers skip); early
// returns, the definition of `srcOp` and the final replacement of `op` are
// not visible and are not described here.
1485class HandleVectorExtractPattern
1487 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
// Marks the pattern as having bounded rewrite recursion (its rewrites create
// new ops of the kind it matches).
1489 void initialize() { setHasBoundedRewriteRecursion(); }
1491 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1492 PatternRewriter &rewriter)
const override {
1494 if (!isExtractingContiguousSlice(op))
1497 auto mask = op.getMask();
1498 auto loc = op.getLoc();
1499 auto ty = op.getType();
1501 auto src = op.getV1();
// FPExt/FPTrunc producer: shuffle the unconverted input with the same mask,
// then apply the same kind of conversion to the smaller vector.
1504 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1505 Value srcInput = srcOp->getOperand(0);
1507 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
1508 auto newShuffleVecTy =
1509 VectorType::get(mask.size(), srcVecTy.getElementType());
1510 auto newShuffle = LLVM::ShuffleVectorOp::create(
1511 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1514 if (isa<LLVM::FPExtOp>(srcOp)) {
1515 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1517 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1520 }
else if (isa<LLVM::BitcastOp>(srcOp)) {
// Bitcast producer: compare the element counts of the bitcast's input and
// result vectors.
1521 Value srcInput = srcOp->getOperand(0);
1523 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
1524 auto srcInputSize = srcInputVecTy.getNumElements();
1525 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1526 auto srcResSize = srcResVecTy.getNumElements();
1527 auto maskSize =
static_cast<int32_t
>(mask.size());
1528 if (srcInputSize > srcResSize) {
1531 if (srcResSize % srcInputSize != 0) {
// maskScale = number of result elements per input element. When it is not 1,
// a rescaled mask of maskSize / maskScale consecutive indices starting at
// mask[0] / maskScale is built.
1534 auto maskScale = srcResSize / srcInputSize;
1535 if (maskScale != 1) {
1536 if (mask[0] % maskScale != 0) {
1540 SmallVector<int32_t> newMask;
1541 int32_t newMaskSize = maskSize / maskScale;
1542 int32_t maskStart = mask[0] / maskScale;
1543 for (int32_t i = 0; i < newMaskSize; ++i) {
1544 newMask.push_back(maskStart + i);
// NOTE(review): the shuffle below is built with `mask`; the line that would
// route `newMask` into it is not visible in this excerpt -- confirm. Also
// confirm that a result type of `srcInputSize` elements (rather than the
// mask length) is intended.
1548 auto newShuffleVecTy =
1549 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1550 auto newShuffle = LLVM::ShuffleVectorOp::create(
1551 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1554 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1556 }
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
// Shuffle-of-shuffle: look each outer mask entry up in the inner mask so a
// single shuffle of the inner source yields the same elements.
1561 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1562 if (!isExtractingContiguousSlice(srcShuffle))
1564 auto srcMask = srcShuffle.getMask();
1565 SmallVector<int32_t> combinedMask;
1566 for (
auto index : mask) {
1567 combinedMask.push_back(srcMask[index]);
1569 auto newShuffle = LLVM::ShuffleVectorOp::create(
1570 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1573 }
else if (isa<LLVM::LoadOp>(srcOp)) {
// Load producer: the pointer's address space is compared against 0 (the
// guarded statement is not shown).
1575 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1576 auto loadPtr = loadOp.getAddr();
1577 auto loadAddrSpace = loadPtr.getType().getAddressSpace();
1578 if (loadAddrSpace != 0)
1580 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1581 auto elemTy = loadTy.getElementType();
1582 auto firstIndex = mask[0];
1583 auto newVecTy = VectorType::get(mask.size(), elemTy);
// Two variants are visible: a load through a GEP that advances the pointer
// by `firstIndex` elements, and a load straight from the original pointer.
// The condition selecting between them is not shown.
1586 auto newPtr = LLVM::GEPOp::create(
1588 LLVM::LLVMPointerType::get(rewriter.
getContext(), loadAddrSpace),
1589 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1590 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1593 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
// Pass that runs a partial dialect conversion over the current operation and
// then prepares a second pattern set holding only HandleVectorExtractPattern.
// NOTE(review): this listing has gaps (the embedded line numbers skip); the
// base class, the setup of `target`/`patterns` and the call that applies
// `vectorPatterns` are not visible and are not described here.
1611struct ConvertXeVMToLLVMPass
// Registers the dialects this pass depends on: LLVM and XeVM.
1615 void getDependentDialects(DialectRegistry &registry)
const override {
1616 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
1619 void runOnOperation()
override {
// Signal pass failure if the partial conversion does not succeed.
1623 if (
failed(applyPartialConversion(getOperation(),
target,
1624 std::move(patterns))))
1625 signalPassFailure();
// Second pattern set: only HandleVectorExtractPattern, with a greedy-driver
// config that has folding turned off.
1629 RewritePatternSet vectorPatterns(&
getContext());
1630 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
1631 GreedyRewriteConfig config{};
1636 config.enableFolding(
false);
// Conversion-target and pattern setup. The enclosing function's signature
// line is not part of this listing (presumably
// populateXeVMToLLVMConversionPatterns), and the final statement is cut off.
// LLVM dialect ops are dynamically legal, decided per op by the callback.
1653 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](
Operation *op) {
// An alloca is legal only when its result pointer is not in address space 3.
1657 if (isa<LLVM::AllocaOp>(op)) {
1658 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1659 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1660 auto addrSpace = pTy.getAddressSpace();
1661 return addrSpace != 3;
// Ops reaching this return are legal only without a "cache_control"
// attribute.
1664 return !op->hasAttr(
"cache_control");
// Every XeVM dialect op is illegal, i.e. must be converted.
1666 target.addIllegalDialect<XeVMDialect>();
// Register the *ToOCLPattern lowerings together with AllocaToGlobalPattern.
1667 patterns.
add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1668 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1669 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1670 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1671 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1672 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1673 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1674 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1675 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1676 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1677 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1678 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1679 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1680 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1681 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1682 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1683 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1684 LaunchConfigOpToOCLPattern<GridDimXOp>,
1685 LaunchConfigOpToOCLPattern<GridDimYOp>,
1686 LaunchConfigOpToOCLPattern<GridDimZOp>,
1687 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1688 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1689 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1690 TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
Attributes are known-constant values of operations.
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
RAII guard to reset the insertion point of the builder when destroyed.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...