18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
26#include "llvm/ADT/TypeSwitch.h"
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
38struct LLVMFuncAttributeOptions {
39 bool isConvergent =
false;
40 bool isNoUnwind =
false;
41 bool isWillReturn =
false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false,
true,
false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false,
true,
true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true,
true,
true, {}};
51std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
57 .Case([](Float16Type) -> std::string {
return "Dh"; })
58 .Case([](Float32Type) -> std::string {
return "f"; })
59 .Case([](Float64Type) -> std::string {
return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
63 return isUnsigned ?
"h" :
"c";
65 return isUnsigned ?
"t" :
"s";
67 return isUnsigned ?
"j" :
"i";
69 return isUnsigned ?
"m" :
"l";
71 llvm_unreachable(
"unhandled integer type");
74 .DefaultUnreachable(
"unhandled type for mangling");
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os <<
"_Z" << baseName.size() << baseName;
85 for (
auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
90 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
102static int32_t getL1CacheControl(LoadCacheControl cc) {
105 case LoadCacheControl::L1C_L2UC_L3UC:
106 case LoadCacheControl::L1C_L2UC_L3C:
107 case LoadCacheControl::L1C_L2C_L3UC:
108 case LoadCacheControl::L1C_L2C_L3C:
111 case LoadCacheControl::L1S_L2UC_L3UC:
112 case LoadCacheControl::L1S_L2UC_L3C:
113 case LoadCacheControl::L1S_L2C_L3UC:
114 case LoadCacheControl::L1S_L2C_L3C:
117 case LoadCacheControl::INVALIDATE_READ:
126static int32_t getL1CacheControl(StoreCacheControl cc) {
129 case StoreCacheControl::L1WT_L2UC_L3UC:
130 case StoreCacheControl::L1WT_L2UC_L3WB:
131 case StoreCacheControl::L1WT_L2WB_L3UC:
132 case StoreCacheControl::L1WT_L2WB_L3WB:
135 case StoreCacheControl::L1WB_L2UC_L3UC:
136 case StoreCacheControl::L1WB_L2WB_L3UC:
137 case StoreCacheControl::L1WB_L2UC_L3WB:
140 case StoreCacheControl::L1S_L2UC_L3UC:
141 case StoreCacheControl::L1S_L2UC_L3WB:
142 case StoreCacheControl::L1S_L2WB_L3UC:
143 case StoreCacheControl::L1S_L2WB_L3WB:
152static int32_t getL3CacheControl(LoadCacheControl cc) {
155 case LoadCacheControl::L1UC_L2UC_L3C:
156 case LoadCacheControl::L1UC_L2C_L3C:
157 case LoadCacheControl::L1C_L2UC_L3C:
158 case LoadCacheControl::L1C_L2C_L3C:
159 case LoadCacheControl::L1S_L2UC_L3C:
160 case LoadCacheControl::L1S_L2C_L3C:
163 case LoadCacheControl::INVALIDATE_READ:
172static int32_t getL3CacheControl(StoreCacheControl cc) {
175 case StoreCacheControl::L1UC_L2UC_L3WB:
176 case StoreCacheControl::L1UC_L2WB_L3WB:
177 case StoreCacheControl::L1WT_L2UC_L3WB:
178 case StoreCacheControl::L1WT_L2WB_L3WB:
179 case StoreCacheControl::L1S_L2UC_L3WB:
180 case StoreCacheControl::L1S_L2WB_L3WB:
181 case StoreCacheControl::L1WB_L2UC_L3WB:
190static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
191 return op.getCacheControl();
194static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
195 return op.getCacheControl();
198static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
199 return op.getCacheControl();
202static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
203 return op.getCacheControl();
206static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
207 return op.getCacheControl();
210static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
211 return op.getCacheControl();
214static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
215 if (op->hasAttr(
"cache_control")) {
216 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
219 return std::optional<LoadCacheControl>(attr.getValue());
224static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
225 if (op->hasAttr(
"cache_control")) {
226 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
229 return std::optional<StoreCacheControl>(attr.getValue());
234template <
typename OpType>
235int32_t getL1CacheControl(OpType op) {
236 return getL1CacheControl(*getCacheControl(op));
239template <
typename OpType>
240int32_t getL3CacheControl(OpType op) {
241 return getL3CacheControl(*getCacheControl(op));
244template <
typename OpType>
245static std::optional<ArrayAttr>
246getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
247 if (!getCacheControl(op))
249 constexpr int32_t decorationCacheControlArity{3};
250 constexpr int32_t loadCacheControlKey{6442};
251 constexpr int32_t storeCacheControlKey{6443};
252 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
253 std::is_same_v<OpType, BlockPrefetch2dOp> ||
254 std::is_same_v<OpType, LLVM::LoadOp> ||
255 std::is_same_v<OpType, BlockLoadOp> ||
256 std::is_same_v<OpType, PrefetchOp>;
257 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
259 controlKey, 0, getL1CacheControl<OpType>(op)};
261 controlKey, 1, getL3CacheControl<OpType>(op)};
262 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
263 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
266 return rewriter.getArrayAttr(combinedAttrs);
269static LLVM::CallOp createDeviceFunctionCall(
270 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
273 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
275 assert(moduleOp &&
"Expecting module");
280 assert(!
failed(funcOpRes));
281 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
282 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
283 funcOp.setConvergent(funcAttributeOptions.isConvergent);
284 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
285 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
287 if (funcAttributeOptions.memEffectsAttr)
288 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
290 for (
auto [idx, attrName] : paramAttrs)
291 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
293 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
294 callOp->setAttrs(funcOp->getAttrs());
299class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
300 using OpConversionPattern::OpConversionPattern;
302 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
303 ConversionPatternRewriter &rewriter)
const override {
305 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
307 auto precisionA = op.getTypes().getA();
308 auto precisionB = op.getTypes().getB();
309 auto precisionC = op.getTypes().getC();
310 auto precisionD = op.getTypes().getD();
311 if (precisionC != precisionD) {
312 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
314 if (precisionC != xevm::ElemType::S32 &&
315 precisionC != xevm::ElemType::F32 &&
316 precisionC != xevm::ElemType::F16 &&
317 precisionC != xevm::ElemType::BF16) {
318 return rewriter.notifyMatchFailure(
319 op,
"type of C and D must be S32, F32, F16 or BF16");
321 if (precisionA == xevm::ElemType::S32 ||
322 precisionA == xevm::ElemType::F32) {
323 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
325 if (precisionB == xevm::ElemType::S32 ||
326 precisionB == xevm::ElemType::F32) {
327 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
329 constexpr uint32_t bitWidthPackedA{16};
330 constexpr uint32_t bitWidthPackedB{32};
331 auto loc = op.getLoc();
333 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
334 VectorType origTy = cast<VectorType>(val.
getType());
335 const uint32_t vecBitSize =
336 origTy.getNumElements() *
337 origTy.getElementType().getIntOrFloatBitWidth();
338 VectorType newTy = VectorType::get(
339 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
341 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
346 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
347 ? cast<Type>(rewriter.getF32Type())
348 : rewriter.getIntegerType(bitWidthPackedA);
349 a = castIfNeeded(a, packedAType);
352 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
353 ? cast<Type>(rewriter.getF32Type())
354 : rewriter.getIntegerType(bitWidthPackedB);
355 b = castIfNeeded(
b, packedBType);
358 VectorType cOrigTy = cast<VectorType>(c.
getType());
359 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
360 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
363 cOrigTy.getElementType().isBF16()
364 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
366 VectorType resTy = cTy;
368 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
370 constexpr int32_t systolicDepth{8};
372 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
373 stringifyElemType(op.getTypes().getA()).str(),
374 stringifyElemType(op.getTypes().getB()).str(),
376 getNumOperandsPerDword(op.getTypes().getA()))
378 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
379 fnName = mangle(fnName, argTypes);
380 SmallVector<Value> args{a,
b, c};
382 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
383 LLVM::ModRefInfo::NoModRef,
384 LLVM::ModRefInfo::NoModRef,
385 LLVM::ModRefInfo::NoModRef,
386 LLVM::ModRefInfo::NoModRef,
387 LLVM::ModRefInfo::NoModRef,
388 LLVM::ModRefInfo::NoModRef);
389 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
390 funcAttrs.memEffectsAttr = memAttr;
392 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
393 funcAttrs, op.getOperation())
396 if (resOrigTy != resTy)
397 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
399 rewriter.replaceOp(op,
result);
404 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
406 case xevm::ElemType::TF32:
408 case xevm::ElemType::BF16:
409 case xevm::ElemType::F16:
411 case xevm::ElemType::U8:
412 case xevm::ElemType::S8:
415 llvm_unreachable(
"unsupported xevm::ElemType");
420class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
421 using OpConversionPattern::OpConversionPattern;
423 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
424 ConversionPatternRewriter &rewriter)
const override {
425 auto loc = op.getLoc();
426 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
428 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
429 SmallVector<Value> args{op.getPtr(), one};
430 SmallVector<Type> argTypes;
431 for (
auto arg : args)
432 argTypes.push_back(arg.getType());
433 auto funcAttr = noUnwindAttrs;
434 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
435 LLVM::ModRefInfo::NoModRef,
436 LLVM::ModRefInfo::Ref,
437 LLVM::ModRefInfo::NoModRef,
438 LLVM::ModRefInfo::NoModRef,
439 LLVM::ModRefInfo::NoModRef,
440 LLVM::ModRefInfo::NoModRef);
441 funcAttr.memEffectsAttr = memAttr;
443 LLVM::CallOp call = createDeviceFunctionCall(
444 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
445 argTypes, args, {}, funcAttr, op.getOperation());
446 if (std::optional<ArrayAttr> optCacheControls =
447 getCacheControlMetadata(rewriter, op))
448 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
449 rewriter.eraseOp(op);
454class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
455 using OpConversionPattern::OpConversionPattern;
457 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
458 ConversionPatternRewriter &rewriter)
const override {
459 auto loc = op.getLoc();
460 const std::string fnName{
"atomic_work_item_fence"};
461 int memScope, addrSpace;
462 switch (op.getAddrspace()) {
463 case xevm::AddrSpace::SHARED:
466 case xevm::AddrSpace::GLOBAL:
471 return rewriter.notifyMatchFailure(
472 op,
"Fence only supports global and shared address spaces.");
474 switch (op.getScope()) {
475 case xevm::MemScope::WORKGROUP:
478 case xevm::MemScope::DEVICE:
483 return rewriter.notifyMatchFailure(
484 op,
"Fence only supports workgroup and device memory scopes.");
486 Type i32Type = rewriter.getI32Type();
487 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
488 Value memScopeConst =
489 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
490 Value addrSpaceConst =
491 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
492 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
493 SmallVector<Type> argTypes{3, i32Type};
494 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
495 LLVM::LLVMVoidType::get(rewriter.getContext()),
496 argTypes, args, {}, noUnwindAttrs,
498 rewriter.eraseOp(op);
502template <
typename OpType>
503class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
504 using OpConversionPattern<OpType>::OpConversionPattern;
506 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
507 ConversionPatternRewriter &rewriter)
const override {
508 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
509 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
511 auto loc = op.getLoc();
513 bool packReg =
false;
514 bool transpose =
false;
515 if constexpr (isLoad) {
516 vecType = op.getRes().getType();
517 packReg = op.getPackRegister();
518 transpose = op.getTranspose();
519 }
else if constexpr (!isPrefetch) {
520 vecType = op.getStoredVal().getType();
523 auto i32Type = rewriter.getI32Type();
525 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
526 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
527 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
528 byteCoord = LLVM::InsertElementOp::create(
529 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
530 byteCoord = LLVM::InsertElementOp::create(
531 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
532 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
533 op.getBasePitch(), byteCoord};
534 SmallVector<Type> retTypes;
536 std::string funcName{
"intel_sub_group_2d_block_"};
537 std::string bitWidthId;
538 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
539 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
540 if constexpr (isPrefetch) {
541 funcName +=
"prefetch";
542 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
543 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
544 LLVM::ModRefInfo::NoModRef,
545 LLVM::ModRefInfo::Ref,
546 LLVM::ModRefInfo::NoModRef,
547 LLVM::ModRefInfo::NoModRef,
548 LLVM::ModRefInfo::NoModRef,
549 LLVM::ModRefInfo::NoModRef);
550 funcAttr = noUnwindAttrs;
551 funcAttr.memEffectsAttr = memAttr;
553 auto vecElemType = vecType.getElementType();
554 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
555 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
556 vecType.getNumElements());
557 auto dstOrSrcPtr = LLVM::AllocaOp::create(
558 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
559 vecElemType, numElems);
560 args.push_back(dstOrSrcPtr);
561 if constexpr (isLoad) {
563 bitWidthId = getTypeMangling(vecElemType,
true);
565 funcName +=
"_transform";
567 funcName +=
"_transpose";
568 spvLoadDstPtr = dstOrSrcPtr;
569 retTypes.push_back(vecType);
571 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
572 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
573 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
574 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
578 bitWidthId = (vecElemBitWidth == 32)
580 : ((vecElemBitWidth == 16) ?
"t" :
"h");
581 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
583 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
584 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
585 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
586 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
592 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
593 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
595 std::string prefetchCode(
"");
598 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
599 funcName, prefetchCode, bitWidthId)
601 SmallVector<Type> argTypes;
602 for (
auto arg : args) {
603 argTypes.push_back(arg.getType());
605 LLVM::CallOp call = createDeviceFunctionCall(
606 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
607 argTypes, args, paramAttrs, funcAttr, op.getOperation());
608 if (std::optional<ArrayAttr> optCacheControls =
609 getCacheControlMetadata(rewriter, op)) {
610 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
612 if constexpr (isLoad)
614 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
616 rewriter.eraseOp(op);
621template <
typename OpType>
622class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
623 using OpConversionPattern<OpType>::OpConversionPattern;
625 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
626 ConversionPatternRewriter &rewriter)
const override {
627 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
631 std::string funcName{
"intel_sub_group_block_"};
634 if constexpr (isStore) {
635 funcName +=
"write_u";
636 valOrResTy = op.getVal().getType();
638 funcName +=
"read_u";
639 valOrResTy = op.getType();
642 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
643 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
644 funcName += getTypeMangling(elemType);
646 funcName += std::to_string(vecTy.getNumElements());
647 SmallVector<Type, 2> argTypes{};
651 SmallVector<bool, 2> isUnsigned{};
655 SmallVector<Value, 2> args{};
656 args.push_back(op.getPtr());
657 argTypes.push_back(op.getPtr().getType());
658 isUnsigned.push_back(
true);
660 if constexpr (isStore) {
661 args.push_back(op.getVal());
662 argTypes.push_back(op.getVal().getType());
663 isUnsigned.push_back(
true);
664 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
666 retType = valOrResTy;
668 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
670 std::to_string(op.getPtr().getType().getAddressSpace());
671 funcName += getTypeMangling(elemType,
true);
672 if constexpr (isStore)
673 funcName += getTypeMangling(valOrResTy,
true);
674 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
677 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
678 {}, funcAttr, op.getOperation());
679 if (std::optional<ArrayAttr> optCacheControls =
680 getCacheControlMetadata(rewriter, op)) {
681 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
683 if constexpr (isStore)
684 rewriter.eraseOp(op);
686 rewriter.replaceOp(op, call->getResult(0));
691template <
typename OpType>
692class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
693 using OpConversionPattern<OpType>::OpConversionPattern;
695 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
696 ConversionPatternRewriter &rewriter)
const override {
697 if (!op->hasAttr(
"cache_control"))
699 std::optional<ArrayAttr> optCacheControls =
700 getCacheControlMetadata(rewriter, op);
701 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
702 op->removeAttr(
"cache_control");
734static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
735 return {
"get_local_id", 0};
737static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
738 return {
"get_local_id", 1};
740static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
741 return {
"get_local_id", 2};
743static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
744 return {
"get_local_size", 0};
746static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
747 return {
"get_local_size", 1};
749static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
750 return {
"get_local_size", 2};
752static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
753 return {
"get_group_id", 0};
755static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
756 return {
"get_group_id", 1};
758static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
759 return {
"get_group_id", 2};
761static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
762 return {
"get_num_groups", 0};
764static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
765 return {
"get_num_groups", 1};
767static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
768 return {
"get_num_groups", 2};
772template <
typename OpType>
773class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
774 using OpConversionPattern<OpType>::OpConversionPattern;
776 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
777 ConversionPatternRewriter &rewriter)
const override {
778 Location loc = op->getLoc();
779 auto [baseName, dim] = getConfig(op);
780 Type dimTy = rewriter.getI32Type();
781 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
782 static_cast<int64_t
>(dim));
783 std::string func = mangle(baseName, {dimTy}, {
true});
784 Type resTy = op.getType();
786 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
787 noUnwindWillReturnAttrs, op.getOperation());
788 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
789 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
795 call.setMemoryEffectsAttr(memAttr);
796 rewriter.replaceOp(op, call);
// Returns the OpenCL builtin name used to lower xevm::LaneIdOp; the name is
// later mangled and emitted as a SPIR_FUNC device-function call.
813static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
// Returns the OpenCL builtin name used to lower xevm::SubgroupIdOp; the name
// is later mangled and emitted as a SPIR_FUNC device-function call.
814static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
815static StringRef getConfig(xevm::SubgroupSizeOp) {
816 return "get_sub_group_size";
818template <
typename OpType>
819class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
820 using OpConversionPattern<OpType>::OpConversionPattern;
822 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
823 ConversionPatternRewriter &rewriter)
const override {
824 std::string func = mangle(getConfig(op).str(), {});
825 Type resTy = op.getType();
827 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
828 noUnwindWillReturnAttrs, op.getOperation());
829 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
830 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
836 call.setMemoryEffectsAttr(memAttr);
837 rewriter.replaceOp(op, call);
842class AllocaToGlobalPattern :
public OpConversionPattern<LLVM::AllocaOp> {
843 using OpConversionPattern::OpConversionPattern;
845 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
846 ConversionPatternRewriter &rewriter)
const override {
847 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
848 auto addrSpace = ptrType.getAddressSpace();
851 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
855 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
856 moduleBody = mod.getBody();
857 }
else if (gpu::GPUModuleOp gpuMod =
858 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
859 moduleBody = gpuMod.getBody();
863 auto val = op.getArraySize();
867 auto loc = op.getLoc();
868 auto globalType = LLVM::LLVMArrayType::get(
869 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
870 LLVM::GlobalOp globalVar;
872 OpBuilder::InsertionGuard guard(rewriter);
873 rewriter.setInsertionPointToStart(moduleBody);
874 auto alignment = op.getAlignment();
875 globalVar = LLVM::GlobalOp::create(
876 rewriter, loc, globalType,
false,
877 LLVM::Linkage::Internal,
878 std::string(
"__global_alloca_") +
879 std::to_string(getNextGlobalIdx()),
881 alignment ? *alignment : 0, addrSpace);
883 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
888 static unsigned getNextGlobalIdx() {
889 static unsigned globalIdx = 0;
894static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
895 if (op.getV1() != op.getV2())
897 auto maskAttr = op.getMask();
898 int64_t firstIndex = maskAttr[0];
899 for (
int64_t i = 1; i < static_cast<int64_t>(maskAttr.size()); ++i) {
901 if (
index != firstIndex + i)
915class HandleVectorExtractPattern
917 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
919 void initialize() { setHasBoundedRewriteRecursion(); }
921 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
922 PatternRewriter &rewriter)
const override {
924 if (!isExtractingContiguousSlice(op))
927 auto mask = op.getMask();
928 auto loc = op.getLoc();
929 auto ty = op.getType();
931 auto src = op.getV1();
933 if (
auto srcOp = src.getDefiningOp()) {
934 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
935 Value srcInput = srcOp->getOperand(0);
937 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
938 auto newShuffleVecTy =
939 VectorType::get(mask.size(), srcVecTy.getElementType());
940 auto newShuffle = LLVM::ShuffleVectorOp::create(
941 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
944 if (isa<LLVM::FPExtOp>(srcOp)) {
945 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
947 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
950 }
else if (isa<LLVM::BitcastOp>(srcOp)) {
951 Value srcInput = srcOp->getOperand(0);
953 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
954 auto srcInputSize = srcInputVecTy.getNumElements();
955 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
956 auto srcResSize = srcResVecTy.getNumElements();
957 auto maskSize =
static_cast<int32_t
>(mask.size());
958 if (srcInputSize > srcResSize) {
961 if (srcResSize % srcInputSize != 0) {
964 auto maskScale = srcResSize / srcInputSize;
965 if (maskScale != 1) {
966 if (mask[0] % maskScale != 0) {
970 SmallVector<int32_t> newMask;
971 int32_t newMaskSize = maskSize / maskScale;
972 int32_t maskStart = mask[0] / maskScale;
973 for (int32_t i = 0; i < newMaskSize; ++i) {
974 newMask.push_back(maskStart + i);
978 auto newShuffleVecTy =
979 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
980 auto newShuffle = LLVM::ShuffleVectorOp::create(
981 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
984 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
986 }
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
988 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
989 auto srcMask = srcShuffle.getMask();
990 SmallVector<int32_t> combinedMask;
991 for (
auto index : mask) {
992 combinedMask.push_back(srcMask[index]);
994 auto newShuffle = LLVM::ShuffleVectorOp::create(
995 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
998 }
else if (isa<LLVM::LoadOp>(srcOp)) {
1000 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1001 auto loadPtr = loadOp.getAddr();
1002 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1003 auto elemTy = loadTy.getElementType();
1004 auto firstIndex = mask[0];
1005 auto newVecTy = VectorType::get(mask.size(), elemTy);
1008 auto newPtr = LLVM::GEPOp::create(
1010 LLVM::LLVMPointerType::get(rewriter.
getContext(),
1011 loadPtr.getType().getAddressSpace()),
1012 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1013 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1016 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1031struct ConvertXeVMToLLVMPass
1035 void getDependentDialects(DialectRegistry ®istry)
const override {
1036 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
1039 void runOnOperation()
override {
1043 if (
failed(applyPartialConversion(getOperation(),
target,
1044 std::move(patterns))))
1045 signalPassFailure();
1049 RewritePatternSet vectorPatterns(&
getContext());
1050 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
1051 GreedyRewriteConfig config{};
1056 config.enableFolding(
false);
1073 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](
Operation *op) {
1077 if (isa<LLVM::AllocaOp>(op)) {
1078 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1079 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1080 auto addrSpace = pTy.getAddressSpace();
1081 return addrSpace != 3;
1084 return !op->hasAttr(
"cache_control");
1086 target.addIllegalDialect<XeVMDialect>();
1087 patterns.
add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1088 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1089 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1090 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1091 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1092 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1093 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1094 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1095 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1096 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1097 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1098 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1099 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1100 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1101 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1102 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1103 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1104 LaunchConfigOpToOCLPattern<GridDimXOp>,
1105 LaunchConfigOpToOCLPattern<GridDimYOp>,
1106 LaunchConfigOpToOCLPattern<GridDimZOp>,
1107 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1108 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1109 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1110 AllocaToGlobalPattern>(patterns.
getContext());
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around an Attribute.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements).
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name `name`.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bind_value.
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highest benefit patterns in a greedy worklist-driven manner until a fixed point is reached.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.