17#include "llvm/Support/FormatVariadic.h"
23#include "llvm/ADT/TypeSwitch.h"
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
35struct LLVMFuncAttributeOptions {
36 bool isConvergent =
false;
37 bool isNoUnwind =
false;
38 bool isWillReturn =
false;
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false,
true,
false, {}};
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false,
true,
true, {}};
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true,
true,
true, {}};
48std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
54 .Case([](Float16Type) -> std::string {
return "Dh"; })
55 .Case([](Float32Type) -> std::string {
return "f"; })
56 .Case([](Float64Type) -> std::string {
return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
60 return isUnsigned ?
"h" :
"c";
62 return isUnsigned ?
"t" :
"s";
64 return isUnsigned ?
"j" :
"i";
66 return isUnsigned ?
"m" :
"l";
68 llvm_unreachable(
"unhandled integer type");
71 .DefaultUnreachable(
"unhandled type for mangling");
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
79 llvm::raw_string_ostream os(s);
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os <<
"_Z" << baseName.size() << baseName;
82 for (
auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
87 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
93 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
99static int32_t getL1CacheControl(LoadCacheControl cc) {
102 case LoadCacheControl::L1C_L2UC_L3UC:
103 case LoadCacheControl::L1C_L2UC_L3C:
104 case LoadCacheControl::L1C_L2C_L3UC:
105 case LoadCacheControl::L1C_L2C_L3C:
108 case LoadCacheControl::L1S_L2UC_L3UC:
109 case LoadCacheControl::L1S_L2UC_L3C:
110 case LoadCacheControl::L1S_L2C_L3UC:
111 case LoadCacheControl::L1S_L2C_L3C:
114 case LoadCacheControl::INVALIDATE_READ:
123static int32_t getL1CacheControl(StoreCacheControl cc) {
126 case StoreCacheControl::L1WT_L2UC_L3UC:
127 case StoreCacheControl::L1WT_L2UC_L3WB:
128 case StoreCacheControl::L1WT_L2WB_L3UC:
129 case StoreCacheControl::L1WT_L2WB_L3WB:
132 case StoreCacheControl::L1WB_L2UC_L3UC:
133 case StoreCacheControl::L1WB_L2WB_L3UC:
134 case StoreCacheControl::L1WB_L2UC_L3WB:
137 case StoreCacheControl::L1S_L2UC_L3UC:
138 case StoreCacheControl::L1S_L2UC_L3WB:
139 case StoreCacheControl::L1S_L2WB_L3UC:
140 case StoreCacheControl::L1S_L2WB_L3WB:
149static int32_t getL3CacheControl(LoadCacheControl cc) {
152 case LoadCacheControl::L1UC_L2UC_L3C:
153 case LoadCacheControl::L1UC_L2C_L3C:
154 case LoadCacheControl::L1C_L2UC_L3C:
155 case LoadCacheControl::L1C_L2C_L3C:
156 case LoadCacheControl::L1S_L2UC_L3C:
157 case LoadCacheControl::L1S_L2C_L3C:
160 case LoadCacheControl::INVALIDATE_READ:
169static int32_t getL3CacheControl(StoreCacheControl cc) {
172 case StoreCacheControl::L1UC_L2UC_L3WB:
173 case StoreCacheControl::L1UC_L2WB_L3WB:
174 case StoreCacheControl::L1WT_L2UC_L3WB:
175 case StoreCacheControl::L1WT_L2WB_L3WB:
176 case StoreCacheControl::L1S_L2UC_L3WB:
177 case StoreCacheControl::L1S_L2WB_L3WB:
178 case StoreCacheControl::L1WB_L2UC_L3WB:
187static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
188 return op.getCacheControl();
191static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
192 return op.getCacheControl();
195static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
196 return op.getCacheControl();
199static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
200 return op.getCacheControl();
203static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
204 return op.getCacheControl();
207static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
208 return op.getCacheControl();
211static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
212 if (op->hasAttr(
"cache_control")) {
213 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
216 return std::optional<LoadCacheControl>(attr.getValue());
221static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
222 if (op->hasAttr(
"cache_control")) {
223 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
226 return std::optional<StoreCacheControl>(attr.getValue());
231template <
typename OpType>
232int32_t getL1CacheControl(OpType op) {
233 return getL1CacheControl(*getCacheControl(op));
236template <
typename OpType>
237int32_t getL3CacheControl(OpType op) {
238 return getL3CacheControl(*getCacheControl(op));
241template <
typename OpType>
242static std::optional<ArrayAttr>
243getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
244 if (!getCacheControl(op))
246 constexpr int32_t decorationCacheControlArity{3};
247 constexpr int32_t loadCacheControlKey{6442};
248 constexpr int32_t storeCacheControlKey{6443};
249 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
250 std::is_same_v<OpType, BlockPrefetch2dOp> ||
251 std::is_same_v<OpType, LLVM::LoadOp> ||
252 std::is_same_v<OpType, BlockLoadOp> ||
253 std::is_same_v<OpType, PrefetchOp>;
254 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
256 controlKey, 0, getL1CacheControl<OpType>(op)};
258 controlKey, 1, getL3CacheControl<OpType>(op)};
259 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
260 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
263 return rewriter.getArrayAttr(combinedAttrs);
266static LLVM::CallOp createDeviceFunctionCall(
267 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
270 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
272 assert(moduleOp &&
"Expecting module");
277 assert(!
failed(funcOpRes));
278 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
279 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
280 funcOp.setConvergent(funcAttributeOptions.isConvergent);
281 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
282 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
284 if (funcAttributeOptions.memEffectsAttr)
285 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
287 for (
auto [idx, attrName] : paramAttrs)
288 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
290 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
291 callOp->setAttrs(funcOp->getAttrs());
296class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
297 using OpConversionPattern::OpConversionPattern;
299 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
300 ConversionPatternRewriter &rewriter)
const override {
302 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
304 auto precisionA = op.getTypes().getA();
305 auto precisionB = op.getTypes().getB();
306 auto precisionC = op.getTypes().getC();
307 auto precisionD = op.getTypes().getD();
308 if (precisionC != precisionD) {
309 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
311 if (precisionC != xevm::ElemType::S32 &&
312 precisionC != xevm::ElemType::F32 &&
313 precisionC != xevm::ElemType::F16 &&
314 precisionC != xevm::ElemType::BF16) {
315 return rewriter.notifyMatchFailure(
316 op,
"type of C and D must be S32, F32, F16 or BF16");
318 if (precisionA == xevm::ElemType::S32 ||
319 precisionA == xevm::ElemType::F32) {
320 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
322 if (precisionB == xevm::ElemType::S32 ||
323 precisionB == xevm::ElemType::F32) {
324 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
326 constexpr uint32_t bitWidthPackedA{16};
327 constexpr uint32_t bitWidthPackedB{32};
328 auto loc = op.getLoc();
330 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
331 VectorType origTy = cast<VectorType>(val.
getType());
332 const uint32_t vecBitSize =
333 origTy.getNumElements() *
334 origTy.getElementType().getIntOrFloatBitWidth();
335 VectorType newTy = VectorType::get(
336 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
338 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
343 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
344 ? cast<Type>(rewriter.getF32Type())
345 : rewriter.getIntegerType(bitWidthPackedA);
346 a = castIfNeeded(a, packedAType);
349 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
350 ? cast<Type>(rewriter.getF32Type())
351 : rewriter.getIntegerType(bitWidthPackedB);
352 b = castIfNeeded(
b, packedBType);
355 VectorType cOrigTy = cast<VectorType>(c.
getType());
356 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
357 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
360 cOrigTy.getElementType().isBF16()
361 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
363 VectorType resTy = cTy;
365 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
367 constexpr int32_t systolicDepth{8};
369 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
370 stringifyElemType(op.getTypes().getA()).str(),
371 stringifyElemType(op.getTypes().getB()).str(),
373 getNumOperandsPerDword(op.getTypes().getA()))
375 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
376 fnName = mangle(fnName, argTypes);
377 SmallVector<Value> args{a,
b, c};
379 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
380 LLVM::ModRefInfo::NoModRef,
381 LLVM::ModRefInfo::NoModRef,
382 LLVM::ModRefInfo::NoModRef,
383 LLVM::ModRefInfo::NoModRef,
384 LLVM::ModRefInfo::NoModRef,
385 LLVM::ModRefInfo::NoModRef);
386 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
387 funcAttrs.memEffectsAttr = memAttr;
389 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
390 funcAttrs, op.getOperation())
393 if (resOrigTy != resTy)
394 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
396 rewriter.replaceOp(op,
result);
401 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
403 case xevm::ElemType::TF32:
405 case xevm::ElemType::BF16:
406 case xevm::ElemType::F16:
408 case xevm::ElemType::U8:
409 case xevm::ElemType::S8:
412 llvm_unreachable(
"unsupported xevm::ElemType");
417class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
418 using OpConversionPattern::OpConversionPattern;
420 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
421 ConversionPatternRewriter &rewriter)
const override {
422 auto loc = op.getLoc();
423 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
425 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
426 SmallVector<Value> args{op.getPtr(), one};
427 SmallVector<Type> argTypes;
428 for (
auto arg : args)
429 argTypes.push_back(arg.getType());
430 auto funcAttr = noUnwindAttrs;
431 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
432 LLVM::ModRefInfo::NoModRef,
433 LLVM::ModRefInfo::Ref,
434 LLVM::ModRefInfo::NoModRef,
435 LLVM::ModRefInfo::NoModRef,
436 LLVM::ModRefInfo::NoModRef,
437 LLVM::ModRefInfo::NoModRef);
438 funcAttr.memEffectsAttr = memAttr;
440 LLVM::CallOp call = createDeviceFunctionCall(
441 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
442 argTypes, args, {}, funcAttr, op.getOperation());
443 if (std::optional<ArrayAttr> optCacheControls =
444 getCacheControlMetadata(rewriter, op))
445 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
446 rewriter.eraseOp(op);
451class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
452 using OpConversionPattern::OpConversionPattern;
454 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
455 ConversionPatternRewriter &rewriter)
const override {
456 auto loc = op.getLoc();
457 const std::string fnName{
"atomic_work_item_fence"};
458 int memScope, addrSpace;
459 switch (op.getAddrspace()) {
460 case xevm::AddrSpace::SHARED:
463 case xevm::AddrSpace::GLOBAL:
468 return rewriter.notifyMatchFailure(
469 op,
"Fence only supports global and shared address spaces.");
471 switch (op.getScope()) {
472 case xevm::MemScope::WORKGROUP:
475 case xevm::MemScope::DEVICE:
480 return rewriter.notifyMatchFailure(
481 op,
"Fence only supports workgroup and device memory scopes.");
483 Type i32Type = rewriter.getI32Type();
484 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
485 Value memScopeConst =
486 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
487 Value addrSpaceConst =
488 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
489 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
490 SmallVector<Type> argTypes{3, i32Type};
491 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
492 LLVM::LLVMVoidType::get(rewriter.getContext()),
493 argTypes, args, {}, noUnwindAttrs,
495 rewriter.eraseOp(op);
499template <
typename OpType>
500class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
501 using OpConversionPattern<OpType>::OpConversionPattern;
503 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
504 ConversionPatternRewriter &rewriter)
const override {
505 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
506 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
508 auto loc = op.getLoc();
510 bool packReg =
false;
511 bool transpose =
false;
512 if constexpr (isLoad) {
513 vecType = op.getRes().getType();
514 packReg = op.getPackRegister();
515 transpose = op.getTranspose();
516 }
else if constexpr (!isPrefetch) {
517 vecType = op.getStoredVal().getType();
520 auto i32Type = rewriter.getI32Type();
522 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
523 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
524 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
525 byteCoord = LLVM::InsertElementOp::create(
526 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
527 byteCoord = LLVM::InsertElementOp::create(
528 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
529 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
530 op.getBasePitch(), byteCoord};
531 SmallVector<Type> retTypes;
533 std::string funcName{
"intel_sub_group_2d_block_"};
534 std::string bitWidthId;
535 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
536 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
537 if constexpr (isPrefetch) {
538 funcName +=
"prefetch";
539 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
540 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
541 LLVM::ModRefInfo::NoModRef,
542 LLVM::ModRefInfo::Ref,
543 LLVM::ModRefInfo::NoModRef,
544 LLVM::ModRefInfo::NoModRef,
545 LLVM::ModRefInfo::NoModRef,
546 LLVM::ModRefInfo::NoModRef);
547 funcAttr = noUnwindAttrs;
548 funcAttr.memEffectsAttr = memAttr;
550 auto vecElemType = vecType.getElementType();
551 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
552 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
553 vecType.getNumElements());
554 auto dstOrSrcPtr = LLVM::AllocaOp::create(
555 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
556 vecElemType, numElems);
557 args.push_back(dstOrSrcPtr);
558 if constexpr (isLoad) {
560 bitWidthId = getTypeMangling(vecElemType,
true);
562 funcName +=
"_transform";
564 funcName +=
"_transpose";
565 spvLoadDstPtr = dstOrSrcPtr;
566 retTypes.push_back(vecType);
568 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
569 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
570 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
571 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
575 bitWidthId = (vecElemBitWidth == 32)
577 : ((vecElemBitWidth == 16) ?
"t" :
"h");
578 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
580 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
581 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
582 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
583 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
589 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
590 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
592 std::string prefetchCode(
"");
595 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
596 funcName, prefetchCode, bitWidthId)
598 SmallVector<Type> argTypes;
599 for (
auto arg : args) {
600 argTypes.push_back(arg.getType());
602 LLVM::CallOp call = createDeviceFunctionCall(
603 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
604 argTypes, args, paramAttrs, funcAttr, op.getOperation());
605 if (std::optional<ArrayAttr> optCacheControls =
606 getCacheControlMetadata(rewriter, op)) {
607 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
609 if constexpr (isLoad)
611 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
613 rewriter.eraseOp(op);
618template <
typename OpType>
619class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
620 using OpConversionPattern<OpType>::OpConversionPattern;
622 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
623 ConversionPatternRewriter &rewriter)
const override {
624 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
628 std::string funcName{
"intel_sub_group_block_"};
631 if constexpr (isStore) {
632 funcName +=
"write_u";
633 valOrResTy = op.getVal().getType();
635 funcName +=
"read_u";
636 valOrResTy = op.getType();
639 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
640 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
641 funcName += getTypeMangling(elemType);
643 funcName += std::to_string(vecTy.getNumElements());
644 SmallVector<Type, 2> argTypes{};
648 SmallVector<bool, 2> isUnsigned{};
652 SmallVector<Value, 2> args{};
653 args.push_back(op.getPtr());
654 argTypes.push_back(op.getPtr().getType());
655 isUnsigned.push_back(
true);
657 if constexpr (isStore) {
658 args.push_back(op.getVal());
659 argTypes.push_back(op.getVal().getType());
660 isUnsigned.push_back(
true);
661 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
663 retType = valOrResTy;
665 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
667 std::to_string(op.getPtr().getType().getAddressSpace());
668 funcName += getTypeMangling(elemType,
true);
669 if constexpr (isStore)
670 funcName += getTypeMangling(valOrResTy,
true);
671 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
674 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
675 {}, funcAttr, op.getOperation());
676 if (std::optional<ArrayAttr> optCacheControls =
677 getCacheControlMetadata(rewriter, op)) {
678 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
680 if constexpr (isStore)
681 rewriter.eraseOp(op);
683 rewriter.replaceOp(op, call->getResult(0));
688template <
typename OpType>
689class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
690 using OpConversionPattern<OpType>::OpConversionPattern;
692 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
693 ConversionPatternRewriter &rewriter)
const override {
694 if (!op->hasAttr(
"cache_control"))
696 std::optional<ArrayAttr> optCacheControls =
697 getCacheControlMetadata(rewriter, op);
698 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
699 op->removeAttr(
"cache_control");
731static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
732 return {
"get_local_id", 0};
734static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
735 return {
"get_local_id", 1};
737static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
738 return {
"get_local_id", 2};
740static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
741 return {
"get_local_size", 0};
743static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
744 return {
"get_local_size", 1};
746static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
747 return {
"get_local_size", 2};
749static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
750 return {
"get_group_id", 0};
752static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
753 return {
"get_group_id", 1};
755static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
756 return {
"get_group_id", 2};
758static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
759 return {
"get_num_groups", 0};
761static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
762 return {
"get_num_groups", 1};
764static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
765 return {
"get_num_groups", 2};
769template <
typename OpType>
770class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
771 using OpConversionPattern<OpType>::OpConversionPattern;
773 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
774 ConversionPatternRewriter &rewriter)
const override {
775 Location loc = op->getLoc();
776 auto [baseName, dim] = getConfig(op);
777 Type dimTy = rewriter.getI32Type();
778 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
779 static_cast<int64_t
>(dim));
780 std::string func = mangle(baseName, {dimTy}, {
true});
781 Type resTy = op.getType();
783 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
784 noUnwindWillReturnAttrs, op.getOperation());
785 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
786 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
792 call.setMemoryEffectsAttr(memAttr);
793 rewriter.replaceOp(op, call);
810static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
811static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
812static StringRef getConfig(xevm::SubgroupSizeOp) {
813 return "get_sub_group_size";
815template <
typename OpType>
816class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
817 using OpConversionPattern<OpType>::OpConversionPattern;
819 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
820 ConversionPatternRewriter &rewriter)
const override {
821 std::string func = mangle(getConfig(op).str(), {});
822 Type resTy = op.getType();
824 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
825 noUnwindWillReturnAttrs, op.getOperation());
826 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
827 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
833 call.setMemoryEffectsAttr(memAttr);
834 rewriter.replaceOp(op, call);
839static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
840 if (op.getV1() != op.getV2())
842 auto maskAttr = op.getMask();
843 int64_t firstIndex = maskAttr[0];
844 for (
int64_t i = 1; i < static_cast<int64_t>(maskAttr.size()); ++i) {
846 if (
index != firstIndex + i)
860class HandleVectorExtractPattern
862 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
864 void initialize() { setHasBoundedRewriteRecursion(); }
866 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
867 PatternRewriter &rewriter)
const override {
869 if (!isExtractingContiguousSlice(op))
872 auto mask = op.getMask();
873 auto loc = op.getLoc();
874 auto ty = op.getType();
876 auto src = op.getV1();
878 if (
auto srcOp = src.getDefiningOp()) {
879 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
880 Value srcInput = srcOp->getOperand(0);
882 auto srcVecTy = dyn_cast<VectorType>(srcInput.
getType());
883 auto newShuffleVecTy =
884 VectorType::get(mask.size(), srcVecTy.getElementType());
885 auto newShuffle = LLVM::ShuffleVectorOp::create(
886 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
889 if (isa<LLVM::FPExtOp>(srcOp)) {
890 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
892 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
895 }
else if (isa<LLVM::BitcastOp>(srcOp)) {
896 Value srcInput = srcOp->getOperand(0);
898 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.
getType());
899 auto srcInputSize = srcInputVecTy.getNumElements();
900 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
901 auto srcResSize = srcResVecTy.getNumElements();
902 auto maskSize =
static_cast<int32_t
>(mask.size());
903 if (srcInputSize > srcResSize) {
906 if (srcResSize % srcInputSize != 0) {
909 auto maskScale = srcResSize / srcInputSize;
910 if (maskScale != 1) {
911 if (mask[0] % maskScale != 0) {
915 SmallVector<int32_t> newMask;
916 int32_t newMaskSize = maskSize / maskScale;
917 int32_t maskStart = mask[0] / maskScale;
918 for (int32_t i = 0; i < newMaskSize; ++i) {
919 newMask.push_back(maskStart + i);
923 auto newShuffleVecTy =
924 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
925 auto newShuffle = LLVM::ShuffleVectorOp::create(
926 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
929 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
931 }
else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
933 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
934 auto srcMask = srcShuffle.getMask();
935 SmallVector<int32_t> combinedMask;
936 for (
auto index : mask) {
937 combinedMask.push_back(srcMask[index]);
939 auto newShuffle = LLVM::ShuffleVectorOp::create(
940 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
943 }
else if (isa<LLVM::LoadOp>(srcOp)) {
945 auto loadOp = cast<LLVM::LoadOp>(srcOp);
946 auto loadPtr = loadOp.getAddr();
947 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
948 auto elemTy = loadTy.getElementType();
949 auto firstIndex = mask[0];
950 auto newVecTy = VectorType::get(mask.size(), elemTy);
953 auto newPtr = LLVM::GEPOp::create(
955 LLVM::LLVMPointerType::get(rewriter.
getContext(),
956 loadPtr.getType().getAddressSpace()),
957 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
958 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
961 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
976struct ConvertXeVMToLLVMPass
980 void getDependentDialects(DialectRegistry ®istry)
const override {
981 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
984 void runOnOperation()
override {
988 if (
failed(applyPartialConversion(getOperation(),
target,
994 RewritePatternSet vectorPatterns(&
getContext());
995 vectorPatterns.add<HandleVectorExtractPattern>(&
getContext());
996 GreedyRewriteConfig
config{};
1001 config.enableFolding(
false);
1017 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
1018 [](
Operation *op) {
return !op->hasAttr(
"cache_control"); });
1019 target.addIllegalDialect<XeVMDialect>();
1020 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1021 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1022 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1023 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1024 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1025 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1026 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1027 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1028 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1029 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1030 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1031 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1032 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1033 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1034 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1035 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1036 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1037 LaunchConfigOpToOCLPattern<GridDimXOp>,
1038 LaunchConfigOpToOCLPattern<GridDimYOp>,
1039 LaunchConfigOpToOCLPattern<GridDimZOp>,
1040 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1041 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1042 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
const FrozenRewritePatternSet GreedyRewriteConfig config
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...