18#include "llvm/Support/FormatVariadic.h"
23#include "llvm/ADT/TypeSwitch.h"
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
/// Set of function-level attributes for an LLVM function declaration: three
/// boolean flags, each defaulting to false, plus a memory-effects attribute
/// that is value-initialized (empty) unless a caller assigns one.
35struct LLVMFuncAttributeOptions {
36 bool isConvergent =
false;
37 bool isNoUnwind =
false;
38 bool isWillReturn =
false;
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
// Preset attribute bundles. The positional initializers follow the member
// order of LLVMFuncAttributeOptions:
//   {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.
// All three leave memEffectsAttr empty.
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false,
true,
false, {}};
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false,
true,
true, {}};
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true,
true,
true, {}};
/// Returns the mangled type code for `ty`.
///
/// - Vector types produce "Dv<numElements>_" followed by the code of the
///   element type (computed recursively, propagating `isUnsigned`).
/// - Float16 / Float32 / Float64 produce "Dh" / "f" / "d".
/// - Integer types produce one of the pairs "h"/"c", "t"/"s", "j"/"i",
///   "m"/"l" (unsigned/signed), selected by the integer width.
///   NOTE(review): the `case` labels are not visible in this view; the pairs
///   presumably correspond to 8/16/32/64-bit widths - confirm.
///
/// Any other type, or an unlisted integer width, hits an unreachable.
/// `isUnsigned` only affects integer codes and defaults to false.
48std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
54 .Case([](Float16Type) -> std::string {
return "Dh"; })
55 .Case([](Float32Type) -> std::string {
return "f"; })
56 .Case([](Float64Type) -> std::string {
return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
60 return isUnsigned ?
"h" :
"c";
62 return isUnsigned ?
"t" :
"s";
64 return isUnsigned ?
"j" :
"i";
66 return isUnsigned ?
"m" :
"l";
68 llvm_unreachable(
"unhandled integer type");
71 .DefaultUnreachable(
"unhandled type for mangling");
// NOTE(review): the enclosing function header is not visible in this view;
// from the call sites this is presumably the body of
// mangle(baseName, types, isUnsigned) - confirm.
//
// Builds a mangled name into `s`: "_Z" + length of baseName + baseName,
// followed by one entry per element of `types`. `isUnsigned` must be either
// empty (every type treated as signed) or the same length as `types`.
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
79 llvm::raw_string_ostream os(s);
// Remembers types that were already emitted so that a repeat can be written
// as a substitution reference (the emission itself is on elided lines).
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os <<
"_Z" << baseName.size() << baseName;
82 for (
auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
87 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
// Only types that are not plain int/float are recorded for substitution.
// In C++17 the right-hand side of an assignment is sequenced before the
// left, so size() is read before operator[] inserts the new key.
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
93 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
/// Maps a LoadCacheControl value to an int32_t code for the L1 level.
/// The cases are grouped by their L1 component (L1UC, L1C, L1S), with
/// INVALIDATE_READ handled as its own group. The value returned for each
/// group is on lines not visible in this view.
99static int32_t getL1CacheControl(LoadCacheControl cc) {
102 case LoadCacheControl::L1UC_L2UC_L3UC:
103 case LoadCacheControl::L1UC_L2UC_L3C:
104 case LoadCacheControl::L1UC_L2C_L3UC:
105 case LoadCacheControl::L1UC_L2C_L3C:
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
120 case LoadCacheControl::INVALIDATE_READ:
/// Maps a StoreCacheControl value to an int32_t code for the L1 level.
/// The cases are grouped by their L1 component (L1UC, L1WT, L1S, L1WB).
/// The value returned for each group is on lines not visible in this view.
127static int32_t getL1CacheControl(StoreCacheControl cc) {
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
// Only three L1WB combinations are listed here (no L1WB_L2WB_L3WB).
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
/// Maps a LoadCacheControl value to an int32_t code for the L3 level.
/// The cases are grouped by their L3 component (L3UC, then L3C), with
/// INVALIDATE_READ handled as its own group. The value returned for each
/// group is on lines not visible in this view.
157static int32_t getL3CacheControl(LoadCacheControl cc) {
160 case LoadCacheControl::L1UC_L2UC_L3UC:
161 case LoadCacheControl::L1UC_L2C_L3UC:
162 case LoadCacheControl::L1C_L2UC_L3UC:
163 case LoadCacheControl::L1C_L2C_L3UC:
164 case LoadCacheControl::L1S_L2UC_L3UC:
165 case LoadCacheControl::L1S_L2C_L3UC:
168 case LoadCacheControl::L1UC_L2UC_L3C:
169 case LoadCacheControl::L1UC_L2C_L3C:
170 case LoadCacheControl::L1C_L2UC_L3C:
171 case LoadCacheControl::L1C_L2C_L3C:
172 case LoadCacheControl::L1S_L2UC_L3C:
173 case LoadCacheControl::L1S_L2C_L3C:
176 case LoadCacheControl::INVALIDATE_READ:
/// Maps a StoreCacheControl value to an int32_t code for the L3 level.
/// The cases are grouped by their L3 component (L3UC, then L3WB).
/// The value returned for each group is on lines not visible in this view.
183static int32_t getL3CacheControl(StoreCacheControl cc) {
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
// Overload set giving every XeVM memory op a uniform getCacheControl(op)
// entry point. Each overload simply forwards to the op's own
// getCacheControl() accessor; load-like ops (prefetch, block load, 2D block
// load/prefetch) yield an optional LoadCacheControl, store-like ops an
// optional StoreCacheControl.
209static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210 return op.getCacheControl();
213static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214 return op.getCacheControl();
217static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218 return op.getCacheControl();
221static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222 return op.getCacheControl();
225static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226 return op.getCacheControl();
229static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230 return op.getCacheControl();
/// getCacheControl overload for plain LLVM loads: when the op carries a
/// "cache_control" attribute, it is read as an xevm::LoadCacheControlAttr and
/// its value is returned wrapped in an optional.
/// NOTE(review): the lines between the attribute lookup and the return, and
/// the path taken when the attribute is absent, are not visible in this view
/// (presumably a null check and `return std::nullopt` - confirm).
233static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234 if (op->hasAttr(
"cache_control")) {
235 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
238 return std::optional<LoadCacheControl>(attr.getValue());
/// Store counterpart of the overload above: reads "cache_control" as an
/// xevm::StoreCacheControlAttr. The same elided-line caveats apply.
243static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244 if (op->hasAttr(
"cache_control")) {
245 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
248 return std::optional<StoreCacheControl>(attr.getValue());
/// Op-level convenience wrapper: looks up the op's cache control and maps it
/// to the L1 code. The optional is dereferenced without a check, so the
/// caller must make sure getCacheControl(op) holds a value first.
253template <
typename OpType>
254int32_t getL1CacheControl(OpType op) {
255 return getL1CacheControl(*getCacheControl(op));
/// Same as above for the L3 code; the same precondition applies.
258template <
typename OpType>
259int32_t getL3CacheControl(OpType op) {
260 return getL3CacheControl(*getCacheControl(op));
/// Builds the cache-control metadata for `op` as an ArrayAttr combining two
/// i32 array attributes, one for L1 and one for L3.
///
/// Each inner array has four entries: {controlKey, level, value, 0}, where
/// level is 0 for the L1 entry and 1 for the L3 entry, and controlKey is
/// 6442 for load-like ops and 6443 otherwise.
///
/// When the op has no cache control the function bails out early (the
/// returned value is on an elided line - presumably an empty optional).
263template <
typename OpType>
264static std::optional<ArrayAttr>
265getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266 if (!getCacheControl(op))
268 constexpr int32_t decorationCacheControlArity{4};
269 constexpr int32_t loadCacheControlKey{6442};
270 constexpr int32_t storeCacheControlKey{6443};
// Load-like ops are enumerated explicitly; every other OpType is treated as
// a store.
271 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272 std::is_same_v<OpType, BlockPrefetch2dOp> ||
273 std::is_same_v<OpType, LLVM::LoadOp> ||
274 std::is_same_v<OpType, BlockLoadOp> ||
275 std::is_same_v<OpType, PrefetchOp>;
276 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278 controlKey, 0, getL1CacheControl<OpType>(op), 0};
280 controlKey, 1, getL3CacheControl<OpType>(op), 0};
281 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
282 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
285 return rewriter.getArrayAttr(combinedAttrs);
/// Obtains the LLVM function named `funcName`, configures it, and emits a
/// call to it with `args`.
///
/// The function is given the SPIR_FUNC calling convention, the
/// convergent / nounwind / willreturn flags from `funcAttributeOptions`, the
/// memory-effects attribute when one is set, and a unit attribute for every
/// (argument index, attribute name) pair in `paramAttrs`. After the call op
/// is created, the function's attributes are installed on the call op via
/// setAttrs.
///
/// NOTE(review): part of the parameter list and the statements that find the
/// module and look up / create the function are on lines not visible here.
/// Both preconditions (module found, lookup succeeded) are enforced only by
/// `assert`.
288static LLVM::CallOp createDeviceFunctionCall(
289 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
292 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
294 assert(moduleOp &&
"Expecting module");
299 assert(!
failed(funcOpRes));
300 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302 funcOp.setConvergent(funcAttributeOptions.isConvergent);
303 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
// An empty memEffectsAttr means "leave the function's memory effects alone".
306 if (funcAttributeOptions.memEffectsAttr)
307 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
309 for (
auto [idx, attrName] : paramAttrs)
310 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
312 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
313 callOp->setAttrs(funcOp->getAttrs());
/// Converts xevm::MMAOp into a call to a mangled
/// "intel_sub_group_<A>_<B>_matrix_mad_k<...>" device function.
///
/// The match is rejected (notifyMatchFailure) when:
///  - the C operand is missing (the condition itself is on an elided line),
///  - the C and D element types differ,
///  - C/D is not one of S32, F32, F16, BF16,
///  - A or B is S32 or F32.
///
/// Operands A and B are bitcast to "packed" vectors before the call, the
/// accumulator is bitcast when its element type is BF16, and the call result
/// is bitcast back to the original result type when the two differ.
318class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
319 using OpConversionPattern::OpConversionPattern;
321 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
324 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
326 auto precisionA = op.getTypes().getA();
327 auto precisionB = op.getTypes().getB();
328 auto precisionC = op.getTypes().getC();
329 auto precisionD = op.getTypes().getD();
330 if (precisionC != precisionD) {
331 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
333 if (precisionC != xevm::ElemType::S32 &&
334 precisionC != xevm::ElemType::F32 &&
335 precisionC != xevm::ElemType::F16 &&
336 precisionC != xevm::ElemType::BF16) {
337 return rewriter.notifyMatchFailure(
338 op,
"type of C and D must be S32, F32, F16 or BF16");
340 if (precisionA == xevm::ElemType::S32 ||
341 precisionA == xevm::ElemType::F32) {
342 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
344 if (precisionB == xevm::ElemType::S32 ||
345 precisionB == xevm::ElemType::F32) {
346 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
// Element widths of the packed operand vectors: A uses 16-bit lanes, B uses
// 32-bit lanes (TF32 operands use f32 lanes instead, see below).
348 constexpr uint32_t bitWidthPackedA{16};
349 constexpr uint32_t bitWidthPackedB{32};
350 auto loc = op.getLoc();
// Reinterprets `val` as a vector of `packedType` holding the same total
// number of bits. The condition guarding the bitcast is on an elided line.
352 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
353 VectorType origTy = cast<VectorType>(val.
getType());
354 const uint32_t vecBitSize =
355 origTy.getNumElements() *
356 origTy.getElementType().getIntOrFloatBitWidth();
357 VectorType newTy = VectorType::get(
358 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
360 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
365 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
366 ? cast<Type>(rewriter.getF32Type())
367 : rewriter.getIntegerType(bitWidthPackedA);
368 a = castIfNeeded(a, packedAType);
371 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
372 ? cast<Type>(rewriter.getF32Type())
373 : rewriter.getIntegerType(bitWidthPackedB);
374 b = castIfNeeded(
b, packedBType);
377 VectorType cOrigTy = cast<VectorType>(c.
getType());
378 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
// BF16 accumulators are passed as 16-bit integer vectors of the same shape;
// the alternative branch of this conditional is on an elided line.
382 cOrigTy.getElementType().isBF16()
383 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
385 VectorType resTy = cTy;
387 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
389 constexpr int32_t systolicDepth{8};
391 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392 stringifyElemType(op.getTypes().getA()).str(),
393 stringifyElemType(op.getTypes().getB()).str(),
395 getNumOperandsPerDword(op.getTypes().getA()))
397 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
398 fnName = mangle(fnName, argTypes);
399 SmallVector<Value> args{a,
b, c};
// The callee is declared convergent/nounwind/willreturn with all three
// memory-effect slots set to NoModRef.
401 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
402 LLVM::ModRefInfo::NoModRef,
403 LLVM::ModRefInfo::NoModRef,
404 LLVM::ModRefInfo::NoModRef);
405 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
406 funcAttrs.memEffectsAttr = memAttr;
408 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
409 funcAttrs, op.getOperation())
// Undo the accumulator reinterpretation on the way out, if there was one.
412 if (resOrigTy != resTy)
413 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
415 rewriter.replaceOp(op,
result);
// Helper used when building the "k<...>" suffix of the function name.
// Groups element types as TF32 / {BF16, F16} / {U8, S8}; the value returned
// for each group is on elided lines. Other element types are unreachable.
420 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
422 case xevm::ElemType::TF32:
424 case xevm::ElemType::BF16:
425 case xevm::ElemType::F16:
427 case xevm::ElemType::U8:
428 case xevm::ElemType::S8:
431 llvm_unreachable(
"unsupported xevm::ElemType");
/// Converts PrefetchOp into a call to the pre-mangled function
/// "_Z8prefetchPU3AS1Kcm", passing the op's pointer and an i64 constant 1.
/// The callee is declared nounwind with memory effects
/// (NoModRef, Ref, NoModRef). If the op carries a cache control, the
/// corresponding metadata is attached to the call. The op is then erased.
436class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
437 using OpConversionPattern::OpConversionPattern;
439 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
440 ConversionPatternRewriter &rewriter)
const override {
441 auto loc = op.getLoc();
442 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
444 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
445 SmallVector<Value> args{op.getPtr(), one};
// Argument types are taken from the values themselves.
446 SmallVector<Type> argTypes;
447 for (
auto arg : args)
448 argTypes.push_back(arg.getType());
449 auto funcAttr = noUnwindAttrs;
450 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
451 LLVM::ModRefInfo::NoModRef,
452 LLVM::ModRefInfo::Ref,
453 LLVM::ModRefInfo::NoModRef);
454 funcAttr.memEffectsAttr = memAttr;
456 LLVM::CallOp call = createDeviceFunctionCall(
457 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
458 argTypes, args, {}, funcAttr, op.getOperation());
459 if (std::optional<ArrayAttr> optCacheControls =
460 getCacheControlMetadata(rewriter, op))
461 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
462 rewriter.eraseOp(op);
/// Converts MemfenceOp into a call to the mangled "atomic_work_item_fence"
/// function taking three i32 arguments, in this order: address-space flag,
/// the constant 4 (named acqRel here), and memory scope.
///
/// Only the SHARED and GLOBAL address spaces and the WORKGROUP and DEVICE
/// scopes are accepted; anything else fails the match. The integer values
/// assigned to `memScope` / `addrSpace` in each case are on elided lines.
467class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
468 using OpConversionPattern::OpConversionPattern;
470 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
471 ConversionPatternRewriter &rewriter)
const override {
472 auto loc = op.getLoc();
473 const std::string fnName{
"atomic_work_item_fence"};
474 int memScope, addrSpace;
475 switch (op.getAddrspace()) {
476 case xevm::AddrSpace::SHARED:
479 case xevm::AddrSpace::GLOBAL:
484 return rewriter.notifyMatchFailure(
485 op,
"Fence only supports global and shared address spaces.");
487 switch (op.getScope()) {
488 case xevm::MemScope::WORKGROUP:
491 case xevm::MemScope::DEVICE:
496 return rewriter.notifyMatchFailure(
497 op,
"Fence only supports workgroup and device memory scopes.");
499 Type i32Type = rewriter.getI32Type();
500 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
501 Value memScopeConst =
502 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
503 Value addrSpaceConst =
504 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
505 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
// Three copies of i32Type (count/value constructor).
506 SmallVector<Type> argTypes{3, i32Type};
507 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
508 LLVM::LLVMVoidType::get(rewriter.getContext()),
509 argTypes, args, {}, noUnwindAttrs,
511 rewriter.eraseOp(op);
/// Shared lowering for the 2D block ops (BlockLoad2dOp, BlockStore2dOp,
/// BlockPrefetch2dOp) to "intel_sub_group_2d_block_*" device function calls.
///
/// Common arguments are the pointer, base width, base height, base pitch and
/// a 2 x i32 vector holding (x, y). Loads and stores additionally pass a
/// pointer to a stack slot (alloca) sized for the vector type: stores write
/// the value into it before the call, loads read the result from it after
/// the call. Prefetches pass no such pointer.
///
/// The function name encodes the element size in bits, tile height, tile
/// width and v_blocks, and is then wrapped in a hand-built "_Z..." mangled
/// form. Cache-control metadata is attached to the call when present.
515template <
typename OpType>
516class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
517 using OpConversionPattern<OpType>::OpConversionPattern;
519 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
520 ConversionPatternRewriter &rewriter)
const override {
// Anything that is neither a load nor a prefetch is handled as a store.
521 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
522 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
524 auto loc = op.getLoc();
526 bool packReg =
false;
527 bool transpose =
false;
528 if constexpr (isLoad) {
529 vecType = op.getRes().getType();
530 packReg = op.getPackRegister();
531 transpose = op.getTranspose();
532 }
else if constexpr (!isPrefetch) {
533 vecType = op.getStoredVal().getType();
536 auto i32Type = rewriter.getI32Type();
// Build the <2 x i32> coordinate vector: element 0 = x, element 1 = y.
538 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
539 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
540 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
541 byteCoord = LLVM::InsertElementOp::create(
542 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
543 byteCoord = LLVM::InsertElementOp::create(
544 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
545 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
546 op.getBasePitch(), byteCoord};
547 SmallVector<Type> retTypes;
549 std::string funcName{
"intel_sub_group_2d_block_"};
550 std::string bitWidthId;
551 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
552 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
553 if constexpr (isPrefetch) {
554 funcName +=
"prefetch";
555 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
556 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
557 LLVM::ModRefInfo::NoModRef,
558 LLVM::ModRefInfo::Ref,
559 LLVM::ModRefInfo::NoModRef);
560 funcAttr = noUnwindAttrs;
561 funcAttr.memEffectsAttr = memAttr;
// Load / store path: allocate a stack slot holding getNumElements()
// elements of the vector's element type and pass it as an extra argument.
563 auto vecElemType = vecType.getElementType();
564 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
565 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
566 vecType.getNumElements());
567 auto dstOrSrcPtr = LLVM::AllocaOp::create(
568 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
569 vecElemType, numElems);
570 args.push_back(dstOrSrcPtr);
571 if constexpr (isLoad) {
573 bitWidthId = getTypeMangling(vecElemType,
true);
// The conditions selecting these suffixes are on elided lines; presumably
// they test packReg and transpose respectively - confirm.
575 funcName +=
"_transform";
577 funcName +=
"_transpose";
578 spvLoadDstPtr = dstOrSrcPtr;
579 retTypes.push_back(vecType);
// Load: argument 0 is nonnull + readonly, argument 5 is nonnull + writeonly.
581 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
582 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
583 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
584 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
588 bitWidthId = (vecElemBitWidth == 32)
590 : ((vecElemBitWidth == 16) ?
"t" :
"h");
591 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
// Store: argument 0 is nonnull + writeonly, argument 5 is nonnull + readonly.
593 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
594 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
595 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
596 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
602 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
603 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
605 std::string prefetchCode(
"");
// Hand-assembled mangled name: length-prefixed base name, a fixed parameter
// encoding, then the optional pointer code and element-type code.
608 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
609 funcName, prefetchCode, bitWidthId)
611 SmallVector<Type> argTypes;
612 for (
auto arg : args) {
613 argTypes.push_back(arg.getType());
615 LLVM::CallOp call = createDeviceFunctionCall(
616 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
617 argTypes, args, paramAttrs, funcAttr, op.getOperation());
618 if (std::optional<ArrayAttr> optCacheControls =
619 getCacheControlMetadata(rewriter, op)) {
620 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
// Loads are replaced by a load from the stack slot; the erase below
// presumably sits in the `else` branch (elided) - confirm.
622 if constexpr (isLoad)
624 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
626 rewriter.eraseOp(op);
/// Shared lowering for the 1D block ops (BlockLoadOp, xevm::BlockStoreOp) to
/// "intel_sub_group_block_read_u*" / "intel_sub_group_block_write_u*" calls.
///
/// The base name is extended with the element type's code and, for vector
/// values, the element count; it is then mangled by hand with the pointer's
/// address space and unsigned type codes. Stores return void and are erased;
/// loads are replaced by the call's result. Cache-control metadata is
/// attached to the call when present.
631template <
typename OpType>
632class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
633 using OpConversionPattern<OpType>::OpConversionPattern;
635 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
636 ConversionPatternRewriter &rewriter)
const override {
637 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
641 std::string funcName{
"intel_sub_group_block_"};
644 if constexpr (isStore) {
645 funcName +=
"write_u";
646 valOrResTy = op.getVal().getType();
648 funcName +=
"read_u";
649 valOrResTy = op.getType();
// The value may be a scalar or a vector; a scalar is its own element type.
652 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
653 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
654 funcName += getTypeMangling(elemType);
// NOTE(review): vecTy is null for scalar values; the guard for the next
// statement, if any, is on an elided line - confirm it exists.
656 funcName += std::to_string(vecTy.getNumElements());
657 SmallVector<Type, 2> argTypes{};
661 SmallVector<bool, 2> isUnsigned{};
665 SmallVector<Value, 2> args{};
666 args.push_back(op.getPtr());
667 argTypes.push_back(op.getPtr().getType());
668 isUnsigned.push_back(
true);
670 if constexpr (isStore) {
671 args.push_back(op.getVal());
672 argTypes.push_back(op.getVal().getType());
673 isUnsigned.push_back(
true);
674 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
676 retType = valOrResTy;
// Hand-assembled mangled name: "_Z" + length + base name, a pointer
// encoding (partly on an elided line) with the address space, then the
// unsigned element code and, for stores, the unsigned value-type code.
678 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
680 std::to_string(op.getPtr().getType().getAddressSpace());
681 funcName += getTypeMangling(elemType,
true);
682 if constexpr (isStore)
683 funcName += getTypeMangling(valOrResTy,
true);
684 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
687 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
688 {}, funcAttr, op.getOperation());
689 if (std::optional<ArrayAttr> optCacheControls =
690 getCacheControlMetadata(rewriter, op)) {
691 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
693 if constexpr (isStore)
694 rewriter.eraseOp(op);
696 rewriter.replaceOp(op, call->getResult(0));
/// Handles plain LLVM load/store ops that carry a "cache_control" attribute:
/// the attribute is translated into XeVM cache-control metadata, which is set
/// on the op under XeVMDialect::getCacheControlsAttrName(), and the original
/// "cache_control" attribute is removed.
///
/// Ops without the attribute take an early exit (the returned value is on an
/// elided line).
/// NOTE(review): `optCacheControls` is dereferenced without a check; this is
/// safe only if getCacheControl(op) always yields a value when the attribute
/// is present - confirm against the elided lines of that helper.
701template <
typename OpType>
702class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
703 using OpConversionPattern<OpType>::OpConversionPattern;
705 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
706 ConversionPatternRewriter &rewriter)
const override {
707 if (!op->hasAttr(
"cache_control"))
709 std::optional<ArrayAttr> optCacheControls =
710 getCacheControlMetadata(rewriter, op);
711 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
712 op->removeAttr(
"cache_control");
// Overload set mapping each launch-configuration op type to a
// (function base name, dimension index) pair. The X/Y/Z variants of an op
// family share the base name and differ only in the index 0/1/2:
//   WorkitemId*   -> "get_local_id"
//   WorkgroupDim* -> "get_local_size"
//   WorkgroupId*  -> "get_group_id"
//   GridDim*      -> "get_num_groups"
// The parameter is unnamed; it is used purely for overload selection.
744static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
745 return {
"get_local_id", 0};
747static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
748 return {
"get_local_id", 1};
750static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
751 return {
"get_local_id", 2};
753static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
754 return {
"get_local_size", 0};
756static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
757 return {
"get_local_size", 1};
759static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
760 return {
"get_local_size", 2};
762static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
763 return {
"get_group_id", 0};
765static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
766 return {
"get_group_id", 1};
768static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
769 return {
"get_group_id", 2};
771static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
772 return {
"get_num_groups", 0};
774static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
775 return {
"get_num_groups", 1};
777static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
778 return {
"get_num_groups", 2};
/// Converts a launch-configuration op (work-item id, work-group size/id,
/// grid size) into a call to the function named by getConfig(op), passing
/// the dimension index as an i32 constant. The name is mangled with the
/// argument marked unsigned. The call is given a memory-effects attribute
/// built from NoModRef (its arguments are on elided lines) and replaces the
/// original op.
782template <
typename OpType>
783class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
784 using OpConversionPattern<OpType>::OpConversionPattern;
786 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
787 ConversionPatternRewriter &rewriter)
const override {
788 Location loc = op->getLoc();
789 auto [baseName, dim] = getConfig(op);
790 Type dimTy = rewriter.getI32Type();
791 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
792 static_cast<int64_t
>(dim));
793 std::string func = mangle(baseName, {dimTy}, {
true});
794 Type resTy = op.getType();
796 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
797 noUnwindWillReturnAttrs, op.getOperation());
798 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
799 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
802 call.setMemoryEffectsAttr(memAttr);
803 rewriter.replaceOp(op, call);
// Overload set mapping each subgroup-related op type to the base name of the
// function it lowers to. The parameter is unnamed; it is used purely for
// overload selection.
820static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
821static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
822static StringRef getConfig(xevm::SubgroupSizeOp) {
823 return "get_sub_group_size";
/// Converts a subgroup query op (lane id, subgroup id, subgroup size) into a
/// call to the zero-argument function named by getConfig(op), mangled with
/// an empty parameter list. The call is given a memory-effects attribute
/// built from NoModRef (its arguments are on elided lines) and replaces the
/// original op.
825template <
typename OpType>
826class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
827 using OpConversionPattern<OpType>::OpConversionPattern;
829 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
830 ConversionPatternRewriter &rewriter)
const override {
831 std::string func = mangle(getConfig(op).str(), {});
832 Type resTy = op.getType();
834 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
835 noUnwindWillReturnAttrs, op.getOperation());
836 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
837 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
840 call.setMemoryEffectsAttr(memAttr);
841 rewriter.replaceOp(op, call);
/// Pass that converts XeVM ops to LLVM. It declares the LLVM and XeVM
/// dialects as dependencies and runs a partial conversion over the operation
/// it is applied to. The base-class list, the setup of `target` and the
/// patterns, and the failure handling are on lines not visible in this view.
850struct ConvertXeVMToLLVMPass
// NOTE(review): "®istry" below looks like an HTML-entity corruption of
// "&registry" introduced by text extraction - restore before compiling.
854 void getDependentDialects(DialectRegistry ®istry)
const override {
855 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
858 void runOnOperation()
override {
862 if (
failed(applyPartialConversion(getOperation(),
target,
// NOTE(review): the enclosing struct header is not visible in this view;
// these are presumably the overrides of the dialect's
// ConvertToLLVMPatternInterface implementation - confirm.

/// Loads the LLVM dialect into `context`.
877 void loadDependentDialects(MLIRContext *context)
const final {
878 context->loadDialect<LLVM::LLVMDialect>();
/// Hook for contributing this dialect's conversion patterns and legality
/// rules. The body is on lines not visible in this view.
883 void populateConvertToLLVMConversionPatterns(
884 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
885 RewritePatternSet &
patterns)
const final {
// NOTE(review): the enclosing function header is not visible in this view;
// this is presumably the body of populateXeVMToLLVMConversionPatterns(target,
// patterns) - confirm.
//
// Legality: an LLVM dialect op is legal unless it still carries a
// "cache_control" attribute; every XeVM dialect op is illegal.
897 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
898 [](
Operation *op) {
return !op->hasAttr(
"cache_control"); });
899 target.addIllegalDialect<XeVMDialect>();
// Register one pattern instantiation per supported op type.
900 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
901 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
902 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
903 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
904 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
905 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
906 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
907 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917 LaunchConfigOpToOCLPattern<GridDimXOp>,
918 LaunchConfigOpToOCLPattern<GridDimYOp>,
919 LaunchConfigOpToOCLPattern<GridDimZOp>,
920 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
// The statement below belongs to a separate, elided function that attaches
// XeVMToLLVMDialectInterface to the dialect.
928 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around a LocationAttr.
MLIRContext is the top-level object for a collection of MLIR operations.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable component.
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch