18#include "llvm/Support/FormatVariadic.h"
23#include "llvm/ADT/TypeSwitch.h"
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
/// Bundle of function-level attribute settings: three boolean flags that all
/// default to false, plus a memory-effects attribute that defaults to null.
35struct LLVMFuncAttributeOptions {
36 bool isConvergent =
false;
37 bool isNoUnwind =
false;
38 bool isWillReturn =
false;
// Value-initialized (null) unless a user of this struct assigns one.
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
// The presets below use aggregate initialization in member order:
// {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.

/// Preset with only isNoUnwind set and a null memory-effects attribute.
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false,
true,
false, {}};
/// Preset with isNoUnwind and isWillReturn set and a null memory-effects
/// attribute.
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false,
true,
true, {}};
/// Preset with isConvergent, isNoUnwind and isWillReturn all set and a null
/// memory-effects attribute.
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true,
true,
true, {}};
/// Returns the mangled-name fragment for `ty`.
///
/// Vector types yield "Dv<numElements>_" followed by the mangling of their
/// element type; f16/f32/f64 yield "Dh"/"f"/"d"; integer types yield one of
/// two letters per handled width, selected by `isUnsigned`. An unhandled
/// integer width or any other type reaches an unreachable.
/// NOTE(review): the letters look like Itanium C++ ABI builtin-type codes; the
/// integer `case` labels are not visible in this excerpt, so which width maps
/// to which letter pair should be confirmed against the full file.
48std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
// Vectors: the signedness flag is forwarded to the element type.
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
// Floating-point types ignore `isUnsigned`.
54 .Case([](Float16Type) -> std::string {
return "Dh"; })
55 .Case([](Float32Type) -> std::string {
return "f"; })
56 .Case([](Float64Type) -> std::string {
return "d"; })
// Integers: the first letter of each pair is used when unsigned, the second
// when signed.
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
60 return isUnsigned ?
"h" :
"c";
62 return isUnsigned ?
"t" :
"s";
64 return isUnsigned ?
"j" :
"i";
66 return isUnsigned ?
"m" :
"l";
68 llvm_unreachable(
"unhandled integer type");
71 .DefaultUnreachable(
"unhandled type for mangling");
// Body of the name-mangling helper (its signature is not visible in this
// excerpt). It writes "_Z<len(baseName)><baseName>" and then processes each
// entry of `types` in order.
// Either no signedness info is given, or there is exactly one flag per type.
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
79 llvm::raw_string_ostream os(s);
// Remembers types that were already emitted, keyed by type, with the index
// assigned at first insertion.
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os <<
"_Z" << baseName.size() << baseName;
82 for (
auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
// Previously seen type: handled by the substitution branch, whose output
// lines are not shown here.
84 if (it != substitutions.end()) {
87 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
// Only types that are not plain int/float are recorded for later
// substitution.
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
// With no signedness info every type is mangled as signed.
93 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
/// Maps a load cache-control enumerator to an int32_t L1 setting.
/// Enumerators are grouped by their L1 prefix (L1UC, L1C, L1S), with
/// INVALIDATE_READ handled separately; the value returned for each group is on
/// lines not shown in this excerpt.
99static int32_t getL1CacheControl(LoadCacheControl cc) {
// L1UC_* enumerators.
102 case LoadCacheControl::L1UC_L2UC_L3UC:
103 case LoadCacheControl::L1UC_L2UC_L3C:
104 case LoadCacheControl::L1UC_L2C_L3UC:
105 case LoadCacheControl::L1UC_L2C_L3C:
// L1C_* enumerators.
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
// L1S_* enumerators.
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
120 case LoadCacheControl::INVALIDATE_READ:
/// Maps a store cache-control enumerator to an int32_t L1 setting.
/// Enumerators are grouped by their L1 prefix (L1UC, L1WT, L1S, L1WB); the
/// value returned for each group is on lines not shown in this excerpt.
127static int32_t getL1CacheControl(StoreCacheControl cc) {
// L1UC_* enumerators.
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
// L1WT_* enumerators.
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
// L1S_* enumerators.
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
// L1WB_* enumerators.
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
/// Maps a load cache-control enumerator to an int32_t L3 setting.
/// Enumerators are grouped by their L3 suffix (L3UC, L3C), with
/// INVALIDATE_READ handled separately; the value returned for each group is on
/// lines not shown in this excerpt.
157static int32_t getL3CacheControl(LoadCacheControl cc) {
// *_L3UC enumerators.
160 case LoadCacheControl::L1UC_L2UC_L3UC:
161 case LoadCacheControl::L1UC_L2C_L3UC:
162 case LoadCacheControl::L1C_L2UC_L3UC:
163 case LoadCacheControl::L1C_L2C_L3UC:
164 case LoadCacheControl::L1S_L2UC_L3UC:
165 case LoadCacheControl::L1S_L2C_L3UC:
// *_L3C enumerators.
168 case LoadCacheControl::L1UC_L2UC_L3C:
169 case LoadCacheControl::L1UC_L2C_L3C:
170 case LoadCacheControl::L1C_L2UC_L3C:
171 case LoadCacheControl::L1C_L2C_L3C:
172 case LoadCacheControl::L1S_L2UC_L3C:
173 case LoadCacheControl::L1S_L2C_L3C:
176 case LoadCacheControl::INVALIDATE_READ:
/// Maps a store cache-control enumerator to an int32_t L3 setting.
/// Enumerators are grouped by their L3 suffix (L3UC, L3WB); the value returned
/// for each group is on lines not shown in this excerpt.
183static int32_t getL3CacheControl(StoreCacheControl cc) {
// *_L3UC enumerators.
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
// *_L3WB enumerators.
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
// Overload set giving a uniform free-function spelling for reading an op's
// optional cache-control setting; each overload forwards to the op's own
// getCacheControl() accessor. Load-like ops yield LoadCacheControl and
// store-like ops yield StoreCacheControl.

/// Cache control of a PrefetchOp.
209static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210 return op.getCacheControl();
/// Cache control of a BlockLoad2dOp.
213static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214 return op.getCacheControl();
/// Cache control of a BlockLoadOp.
217static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218 return op.getCacheControl();
/// Cache control of a BlockPrefetch2dOp.
221static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222 return op.getCacheControl();
/// Cache control of a BlockStore2dOp.
225static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226 return op.getCacheControl();
/// Cache control of a BlockStoreOp.
229static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230 return op.getCacheControl();
/// Reads the "cache_control" attribute of an LLVM load as a
/// xevm::LoadCacheControlAttr and returns its value. What is returned when the
/// attribute is absent or has a different attribute type is on lines not shown
/// in this excerpt.
233static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234 if (op->hasAttr(
"cache_control")) {
235 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>(
"cache_control");
238 return std::optional<LoadCacheControl>(attr.getValue());
/// Reads the "cache_control" attribute of an LLVM store as a
/// xevm::StoreCacheControlAttr and returns its value. What is returned when
/// the attribute is absent or has a different attribute type is on lines not
/// shown in this excerpt.
243static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244 if (op->hasAttr(
"cache_control")) {
245 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>(
"cache_control");
248 return std::optional<StoreCacheControl>(attr.getValue());
/// Returns the L1 setting for the cache control attached to `op`.
/// The optional from getCacheControl(op) is dereferenced without a check, so
/// the caller must ensure it holds a value.
253template <
typename OpType>
254int32_t getL1CacheControl(OpType op) {
255 return getL1CacheControl(*getCacheControl(op));
/// Returns the L3 setting for the cache control attached to `op`.
/// The optional from getCacheControl(op) is dereferenced without a check, so
/// the caller must ensure it holds a value.
258template <
typename OpType>
259int32_t getL3CacheControl(OpType op) {
260 return getL3CacheControl(*getCacheControl(op));
/// Builds the cache-control metadata for `op` as an ArrayAttr combining two
/// i32 array attributes, one for L1 and one for L3.
///
/// Each entry has four i32 values: {controlKey, 0 or 1, cache setting, 0},
/// where the second value is 0 for the L1 entry and 1 for the L3 entry.
/// When `op` carries no cache control the function exits early (the returned
/// value is on a line not shown in this excerpt).
263template <
typename OpType>
264static std::optional<ArrayAttr>
265getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266 if (!getCacheControl(op))
268 constexpr int32_t decorationCacheControlArity{4};
// NOTE(review): 6442 / 6443 are presumably the SPIR-V load / store
// cache-control decoration ids — confirm against the extension spec.
269 constexpr int32_t loadCacheControlKey{6442};
270 constexpr int32_t storeCacheControlKey{6443};
// Op types that use the load key; every other OpType uses the store key.
271 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272 std::is_same_v<OpType, BlockPrefetch2dOp> ||
273 std::is_same_v<OpType, LLVM::LoadOp> ||
274 std::is_same_v<OpType, BlockLoadOp> ||
275 std::is_same_v<OpType, PrefetchOp>;
276 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278 controlKey, 0, getL1CacheControl<OpType>(op), 0};
280 controlKey, 1, getL3CacheControl<OpType>(op), 0};
281 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
282 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
285 return rewriter.getArrayAttr(combinedAttrs);
/// Emits an LLVM call to the device function named `funcName`.
///
/// The callee function op is given the SPIR_FUNC calling convention, the
/// convergent / nounwind / willreturn flags from `funcAttributeOptions`, its
/// memory-effects attribute when one is set, and a unit attribute for every
/// (argument index, attribute name) pair in `paramAttrs`. The created call then
/// takes over all attributes of the function op.
/// Part of the parameter list and the lookup that produces `funcOpRes` are on
/// lines not shown in this excerpt.
288static LLVM::CallOp createDeviceFunctionCall(
289 ConversionPatternRewriter &rewriter, StringRef funcName,
Type retType,
292 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
294 assert(moduleOp &&
"Expecting module");
// NOTE(review): failure of the lookup is only guarded by an assert, so it is
// unchecked when assertions are compiled out.
299 assert(!
failed(funcOpRes));
300 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302 funcOp.setConvergent(funcAttributeOptions.isConvergent);
303 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
// A null memory-effects attribute means "leave the function's setting alone".
306 if (funcAttributeOptions.memEffectsAttr)
307 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
309 for (
auto [idx, attrName] : paramAttrs)
310 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
312 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
// Replace the call's attributes with the callee's full attribute list.
313 callOp->setAttrs(funcOp->getAttrs());
/// Conversion pattern that replaces xevm::MMAOp with a call to a mangled
/// "intel_sub_group_<A>_<B>_matrix_mad_k<...>" device function, bitcasting the
/// operands and result to the vector types that call uses.
318class MMAToOCLPattern :
public OpConversionPattern<xevm::MMAOp> {
319 using OpConversionPattern::OpConversionPattern;
321 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
322 ConversionPatternRewriter &rewriter)
const override {
// The guarding condition is on a line not shown here; per the message, the
// match is rejected when there is no C operand.
324 return rewriter.notifyMatchFailure(op,
"OCL requires C operand");
326 auto precisionA = op.getTypes().getA();
327 auto precisionB = op.getTypes().getB();
328 auto precisionC = op.getTypes().getC();
329 auto precisionD = op.getTypes().getD();
// Element-type validation: C and D must agree and be one of S32/F32/F16/BF16;
// A and B may not be S32 or F32.
330 if (precisionC != precisionD) {
331 return rewriter.notifyMatchFailure(op,
"type of C and D need to match");
333 if (precisionC != xevm::ElemType::S32 &&
334 precisionC != xevm::ElemType::F32 &&
335 precisionC != xevm::ElemType::F16 &&
336 precisionC != xevm::ElemType::BF16) {
337 return rewriter.notifyMatchFailure(
338 op,
"type of C and D must be S32, F32, F16 or BF16");
340 if (precisionA == xevm::ElemType::S32 ||
341 precisionA == xevm::ElemType::F32) {
342 return rewriter.notifyMatchFailure(op,
"type of A cannot be S32 or F32");
344 if (precisionB == xevm::ElemType::S32 ||
345 precisionB == xevm::ElemType::F32) {
346 return rewriter.notifyMatchFailure(op,
"type of B cannot be S32 or F32");
// Bit widths of the integer element types A and B are packed into.
348 constexpr uint32_t bitWidthPackedA{16};
349 constexpr uint32_t bitWidthPackedB{32};
350 auto loc = op.getLoc();
// Computes a vector of `packedType` spanning the same number of bits as
// `val` and bitcasts `val` to it (any guard around the bitcast is on a line
// not shown here).
352 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
353 VectorType origTy = cast<VectorType>(val.
getType());
354 const uint32_t vecBitSize =
355 origTy.getNumElements() *
356 origTy.getElementType().getIntOrFloatBitWidth();
357 VectorType newTy = VectorType::get(
358 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
360 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
// A is packed as f32 when its element type is TF32, otherwise as i16.
365 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
366 ? cast<Type>(rewriter.getF32Type())
367 : rewriter.getIntegerType(bitWidthPackedA);
368 a = castIfNeeded(a, packedAType);
// B is packed as f32 when its element type is TF32, otherwise as i32.
371 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
372 ? cast<Type>(rewriter.getF32Type())
373 : rewriter.getIntegerType(bitWidthPackedB);
374 b = castIfNeeded(
b, packedBType);
377 VectorType cOrigTy = cast<VectorType>(c.
getType());
378 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
// A bf16 accumulator is passed as an i16 vector of the same shape; the
// alternative branch is on a line not shown here.
382 cOrigTy.getElementType().isBF16()
383 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
385 VectorType resTy = cTy;
387 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
389 constexpr int32_t systolicDepth{8};
// Function base name built from the A and B element-type names and a "k"
// value whose computation is only partly visible here.
391 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392 stringifyElemType(op.getTypes().getA()).str(),
393 stringifyElemType(op.getTypes().getB()).str(),
395 getNumOperandsPerDword(op.getTypes().getA()))
397 SmallVector<Type> argTypes{a.
getType(),
b.getType(), cTy};
398 fnName = mangle(fnName, argTypes);
399 SmallVector<Value> args{a,
b, c};
// Every memory-effects slot is NoModRef.
401 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
402 LLVM::ModRefInfo::NoModRef,
403 LLVM::ModRefInfo::NoModRef,
404 LLVM::ModRefInfo::NoModRef,
405 LLVM::ModRefInfo::NoModRef,
406 LLVM::ModRefInfo::NoModRef,
407 LLVM::ModRefInfo::NoModRef);
// Start from the convergent/nounwind/willreturn preset and add memAttr.
408 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
409 funcAttrs.memEffectsAttr = memAttr;
411 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
412 funcAttrs, op.getOperation())
// Cast the call result back when the call's result type differs from the
// op's original result type.
415 if (resOrigTy != resTy)
416 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy,
result);
418 rewriter.replaceOp(op,
result);
/// Maps an element type to an unsigned count used in the "k" part of the
/// function name. TF32, {BF16, F16} and {U8, S8} form three groups whose
/// return values are on lines not shown here; any other type is unreachable.
423 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
425 case xevm::ElemType::TF32:
427 case xevm::ElemType::BF16:
428 case xevm::ElemType::F16:
430 case xevm::ElemType::U8:
431 case xevm::ElemType::S8:
434 llvm_unreachable(
"unsupported xevm::ElemType");
/// Conversion pattern that replaces PrefetchOp with a call to the device
/// function "_Z8prefetchPU3AS1Kcm", passing the op's pointer and an i64
/// constant 1, and then erases the op.
439class PrefetchToOCLPattern :
public OpConversionPattern<PrefetchOp> {
440 using OpConversionPattern::OpConversionPattern;
442 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
443 ConversionPatternRewriter &rewriter)
const override {
444 auto loc = op.getLoc();
// Already-mangled callee name; no call to mangle() is needed here.
445 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
447 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
448 SmallVector<Value> args{op.getPtr(), one};
// Argument types are taken directly from the argument values.
449 SmallVector<Type> argTypes;
450 for (
auto arg : args)
451 argTypes.push_back(arg.getType());
452 auto funcAttr = noUnwindAttrs;
// Memory effects: Ref in the second slot, NoModRef in all others.
453 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
454 LLVM::ModRefInfo::NoModRef,
455 LLVM::ModRefInfo::Ref,
456 LLVM::ModRefInfo::NoModRef,
457 LLVM::ModRefInfo::NoModRef,
458 LLVM::ModRefInfo::NoModRef,
459 LLVM::ModRefInfo::NoModRef);
460 funcAttr.memEffectsAttr = memAttr;
462 LLVM::CallOp call = createDeviceFunctionCall(
463 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
464 argTypes, args, {}, funcAttr, op.getOperation());
// Attach cache-control metadata to the call only when the op has it.
465 if (std::optional<ArrayAttr> optCacheControls =
466 getCacheControlMetadata(rewriter, op))
467 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
468 rewriter.eraseOp(op);
/// Conversion pattern that replaces MemfenceOp with a call to the mangled
/// "atomic_work_item_fence" device function taking three i32 arguments, and
/// then erases the op.
473class MemfenceToOCLPattern :
public OpConversionPattern<MemfenceOp> {
474 using OpConversionPattern::OpConversionPattern;
476 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
477 ConversionPatternRewriter &rewriter)
const override {
478 auto loc = op.getLoc();
479 const std::string fnName{
"atomic_work_item_fence"};
// Both are assigned inside the switches below; the assigned values are on
// lines not shown in this excerpt.
480 int memScope, addrSpace;
// Only SHARED and GLOBAL address spaces are handled; anything else fails
// the match.
481 switch (op.getAddrspace()) {
482 case xevm::AddrSpace::SHARED:
485 case xevm::AddrSpace::GLOBAL:
490 return rewriter.notifyMatchFailure(
491 op,
"Fence only supports global and shared address spaces.");
// Only WORKGROUP and DEVICE scopes are handled; anything else fails the
// match.
493 switch (op.getScope()) {
494 case xevm::MemScope::WORKGROUP:
497 case xevm::MemScope::DEVICE:
502 return rewriter.notifyMatchFailure(
503 op,
"Fence only supports workgroup and device memory scopes.");
505 Type i32Type = rewriter.getI32Type();
// NOTE(review): 4 is presumably the acquire-release memory-order value
// expected by the callee — confirm.
506 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
507 Value memScopeConst =
508 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
509 Value addrSpaceConst =
510 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
// Argument order: address space, memory order, memory scope.
511 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
512 SmallVector<Type> argTypes{3, i32Type};
513 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
514 LLVM::LLVMVoidType::get(rewriter.getContext()),
515 argTypes, args, {}, noUnwindAttrs,
517 rewriter.eraseOp(op);
/// Conversion pattern shared by the 2D block load, store and prefetch ops.
/// It builds a call to an "intel_sub_group_2d_block_*" device function whose
/// name encodes element size, tile height, tile width and v-block count.
/// For loads and stores, data is passed through a stack slot created with an
/// alloca; a load op is replaced by an LLVM load from that slot, and the
/// other ops are erased.
521template <
typename OpType>
522class LoadStorePrefetchToOCLPattern :
public OpConversionPattern<OpType> {
523 using OpConversionPattern<OpType>::OpConversionPattern;
525 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
526 ConversionPatternRewriter &rewriter)
const override {
// Compile-time classification; an OpType that is neither load nor prefetch
// takes the store paths below.
527 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
528 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
530 auto loc = op.getLoc();
532 bool packReg =
false;
533 bool transpose =
false;
// vecType is the result type for loads and the stored value's type for
// stores; prefetch has neither.
534 if constexpr (isLoad) {
535 vecType = op.getRes().getType();
536 packReg = op.getPackRegister();
537 transpose = op.getTranspose();
538 }
else if constexpr (!isPrefetch) {
539 vecType = op.getStoredVal().getType();
542 auto i32Type = rewriter.getI32Type();
// Build a 2 x i32 coordinate vector {x, y} starting from an undef value.
544 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
545 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
546 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
547 byteCoord = LLVM::InsertElementOp::create(
548 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
549 byteCoord = LLVM::InsertElementOp::create(
550 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
// Arguments common to all three ops: pointer, base width, base height, base
// pitch and the coordinate vector.
551 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
552 op.getBasePitch(), byteCoord};
553 SmallVector<Type> retTypes;
555 std::string funcName{
"intel_sub_group_2d_block_"};
556 std::string bitWidthId;
557 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
558 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
// Prefetch: pointer argument is nonnull; attributes switch to the
// nounwind-only preset with Ref in the second memory-effects slot.
559 if constexpr (isPrefetch) {
560 funcName +=
"prefetch";
561 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
562 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
563 LLVM::ModRefInfo::NoModRef,
564 LLVM::ModRefInfo::Ref,
565 LLVM::ModRefInfo::NoModRef,
566 LLVM::ModRefInfo::NoModRef,
567 LLVM::ModRefInfo::NoModRef,
568 LLVM::ModRefInfo::NoModRef);
569 funcAttr = noUnwindAttrs;
570 funcAttr.memEffectsAttr = memAttr;
// Load/store: allocate a stack slot holding vecType's elements and pass it
// as an extra pointer argument.
572 auto vecElemType = vecType.getElementType();
573 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
574 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
575 vecType.getNumElements());
576 auto dstOrSrcPtr = LLVM::AllocaOp::create(
577 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
578 vecElemType, numElems);
579 args.push_back(dstOrSrcPtr);
580 if constexpr (isLoad) {
// The element type is mangled as unsigned for the load callee.
582 bitWidthId = getTypeMangling(vecElemType,
true);
// The conditions selecting these suffixes are on lines not shown here.
584 funcName +=
"_transform";
586 funcName +=
"_transpose";
587 spvLoadDstPtr = dstOrSrcPtr;
588 retTypes.push_back(vecType);
// Load: argument 0 is nonnull + readonly, argument 5 is nonnull +
// writeonly.
590 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
591 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
592 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
593 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
// Store: pick the type letter from the element bit width (the 32-bit
// choice is on a line not shown here; 16 -> "t", otherwise "h").
597 bitWidthId = (vecElemBitWidth == 32)
599 : ((vecElemBitWidth == 16) ?
"t" :
"h");
// Spill the value to be stored into the stack slot before the call.
600 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
// Store: argument 0 is nonnull + writeonly, argument 5 is nonnull +
// readonly.
602 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
603 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
604 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
605 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
// Append "<elemBits>b_<tileHeight>r<tileWidth>x<vBlocks>c" to the name.
611 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
612 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
614 std::string prefetchCode(
"");
// Hand-built mangled name: "_Z<len><name>" followed by a fixed parameter
// encoding plus prefetchCode and bitWidthId.
617 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
618 funcName, prefetchCode, bitWidthId)
620 SmallVector<Type> argTypes;
621 for (
auto arg : args) {
622 argTypes.push_back(arg.getType());
624 LLVM::CallOp call = createDeviceFunctionCall(
625 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
626 argTypes, args, paramAttrs, funcAttr, op.getOperation());
// Attach cache-control metadata to the call only when the op has it.
627 if (std::optional<ArrayAttr> optCacheControls =
628 getCacheControlMetadata(rewriter, op)) {
629 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
// Loads are replaced by a load from the stack slot; other ops are erased.
631 if constexpr (isLoad)
633 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
635 rewriter.eraseOp(op);
/// Conversion pattern shared by the 1D block load and store ops. It calls an
/// "intel_sub_group_block_read_u*" / "intel_sub_group_block_write_u*" device
/// function whose name is assembled and mangled by hand below. A store op is
/// erased after the call; a load op is replaced by the call's result.
640template <
typename OpType>
641class BlockLoadStore1DToOCLPattern :
public OpConversionPattern<OpType> {
642 using OpConversionPattern<OpType>::OpConversionPattern;
644 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
645 ConversionPatternRewriter &rewriter)
const override {
646 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
650 std::string funcName{
"intel_sub_group_block_"};
// valOrResTy is the stored value's type for stores and the result type
// otherwise.
653 if constexpr (isStore) {
654 funcName +=
"write_u";
655 valOrResTy = op.getVal().getType();
657 funcName +=
"read_u";
658 valOrResTy = op.getType();
// Scalars are allowed: vecTy is null when valOrResTy is not a vector.
661 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
662 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
663 funcName += getTypeMangling(elemType);
// NOTE(review): vecTy may be null here; presumably guarded by a condition on
// a line not shown in this excerpt — confirm.
665 funcName += std::to_string(vecTy.getNumElements());
666 SmallVector<Type, 2> argTypes{};
670 SmallVector<bool, 2> isUnsigned{};
674 SmallVector<Value, 2> args{};
// First argument is always the pointer.
675 args.push_back(op.getPtr());
676 argTypes.push_back(op.getPtr().getType());
677 isUnsigned.push_back(
true);
// Stores additionally pass the value and return void; loads return
// valOrResTy.
679 if constexpr (isStore) {
680 args.push_back(op.getVal());
681 argTypes.push_back(op.getVal().getType());
682 isUnsigned.push_back(
true);
683 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
685 retType = valOrResTy;
// Hand-built mangled name: "_Z<len><name>", a pointer encoding that
// includes the pointer's address space, then the unsigned element type and,
// for stores, the unsigned value type.
687 funcName = std::string(
"_Z") + std::to_string(funcName.size()) + funcName +
689 std::to_string(op.getPtr().getType().getAddressSpace());
690 funcName += getTypeMangling(elemType,
true);
691 if constexpr (isStore)
692 funcName += getTypeMangling(valOrResTy,
true);
693 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
696 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
697 {}, funcAttr, op.getOperation());
// Attach cache-control metadata to the call only when the op has it.
698 if (std::optional<ArrayAttr> optCacheControls =
699 getCacheControlMetadata(rewriter, op)) {
700 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
702 if constexpr (isStore)
703 rewriter.eraseOp(op);
705 rewriter.replaceOp(op, call->getResult(0));
/// Conversion pattern for LLVM load/store ops that carry a "cache_control"
/// attribute: it sets the XeVM cache-controls attribute on the op from the
/// generated metadata and removes "cache_control". The op itself is kept.
710template <
typename OpType>
711class LLVMLoadStoreToOCLPattern :
public OpConversionPattern<OpType> {
712 using OpConversionPattern<OpType>::OpConversionPattern;
714 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
715 ConversionPatternRewriter &rewriter)
const override {
// Ops without the attribute are not handled by this pattern (the statement
// taken in that case is on a line not shown here).
716 if (!op->hasAttr(
"cache_control"))
718 std::optional<ArrayAttr> optCacheControls =
719 getCacheControlMetadata(rewriter, op);
// NOTE(review): the optional is dereferenced without a check. If
// "cache_control" is present but getCacheControl(op) yields no value, this
// dereference would be invalid — confirm that cannot happen.
720 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
721 op->removeAttr(
"cache_control");
// Overload set mapping each launch-configuration op type to a pair of
// (unmangled builtin name, dimension index), where X/Y/Z map to 0/1/2.
// The op argument is used only for overload selection.

/// Workitem id, X dimension.
753static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
754 return {
"get_local_id", 0};
/// Workitem id, Y dimension.
756static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
757 return {
"get_local_id", 1};
/// Workitem id, Z dimension.
759static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
760 return {
"get_local_id", 2};
/// Workgroup size, X dimension.
762static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
763 return {
"get_local_size", 0};
/// Workgroup size, Y dimension.
765static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
766 return {
"get_local_size", 1};
/// Workgroup size, Z dimension.
768static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
769 return {
"get_local_size", 2};
/// Workgroup id, X dimension.
771static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
772 return {
"get_group_id", 0};
/// Workgroup id, Y dimension.
774static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
775 return {
"get_group_id", 1};
/// Workgroup id, Z dimension.
777static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
778 return {
"get_group_id", 2};
/// Number of workgroups, X dimension.
780static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
781 return {
"get_num_groups", 0};
/// Number of workgroups, Y dimension.
783static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
784 return {
"get_num_groups", 1};
/// Number of workgroups, Z dimension.
786static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
787 return {
"get_num_groups", 2};
/// Conversion pattern for launch-configuration ops: replaces the op with a
/// call to the mangled builtin from getConfig(op), passing the dimension
/// index as an i32 constant and returning the op's result type.
791template <
typename OpType>
792class LaunchConfigOpToOCLPattern :
public OpConversionPattern<OpType> {
793 using OpConversionPattern<OpType>::OpConversionPattern;
795 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
796 ConversionPatternRewriter &rewriter)
const override {
797 Location loc = op->getLoc();
798 auto [baseName, dim] = getConfig(op);
799 Type dimTy = rewriter.getI32Type();
800 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
801 static_cast<int64_t
>(dim));
// The single i32 parameter is mangled as unsigned.
802 std::string func = mangle(baseName, {dimTy}, {
true});
803 Type resTy = op.getType();
805 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
806 noUnwindWillReturnAttrs, op.getOperation());
// The memory-effects attribute is set on the call after it is created; its
// arguments are on lines not shown here.
807 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
808 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
814 call.setMemoryEffectsAttr(memAttr);
815 rewriter.replaceOp(op, call);
// Overload set mapping each subgroup/lane query op type to the unmangled name
// of the builtin it is lowered to. The op argument is used only for overload
// selection.

/// Builtin name for LaneIdOp.
832static StringRef getConfig(xevm::LaneIdOp) {
return "get_sub_group_local_id"; }
/// Builtin name for SubgroupIdOp.
833static StringRef getConfig(xevm::SubgroupIdOp) {
return "get_sub_group_id"; }
/// Builtin name for SubgroupSizeOp.
834static StringRef getConfig(xevm::SubgroupSizeOp) {
835 return "get_sub_group_size";
/// Conversion pattern for subgroup/lane query ops: replaces the op with a
/// call to the mangled, argument-less builtin named by getConfig(op),
/// returning the op's result type.
837template <
typename OpType>
838class SubgroupOpWorkitemOpToOCLPattern :
public OpConversionPattern<OpType> {
839 using OpConversionPattern<OpType>::OpConversionPattern;
841 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
842 ConversionPatternRewriter &rewriter)
const override {
// Mangled with an empty parameter list.
843 std::string func = mangle(getConfig(op).str(), {});
844 Type resTy = op.getType();
846 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
847 noUnwindWillReturnAttrs, op.getOperation());
// The memory-effects attribute is set on the call after it is created; its
// arguments are on lines not shown here.
848 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
849 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
855 call.setMemoryEffectsAttr(memAttr);
856 rewriter.replaceOp(op, call);
/// Pass struct for the XeVM-to-LLVM conversion (its base class and remaining
/// members are on lines not shown in this excerpt).
865struct ConvertXeVMToLLVMPass
/// Registers the LLVM and XeVM dialects with `registry`.
// The parameter is a reference named `registry`; the text had been corrupted
// into a registered-trademark sign by HTML entity decoding of "&reg".
869 void getDependentDialects(DialectRegistry &registry)
const override {
870 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
/// Pass entry point: applies a partial conversion to the pass's root
/// operation using `target`. Setup of the target and patterns, and the
/// handling of a failed conversion, are on lines not shown in this excerpt.
873 void runOnOperation()
override {
877 if (
failed(applyPartialConversion(getOperation(),
target,
/// Loads the LLVM dialect into `context`.
892 void loadDependentDialects(MLIRContext *context)
const final {
893 context->loadDialect<LLVM::LLVMDialect>();
/// Interface hook receiving the conversion target, the LLVM type converter
/// and the pattern set to populate. Only the signature is visible in this
/// excerpt; the body is on lines not shown.
898 void populateConvertToLLVMConversionPatterns(
899 ConversionTarget &
target, LLVMTypeConverter &typeConverter,
900 RewritePatternSet &
patterns)
const final {
// Legality setup and pattern registration (the enclosing function's signature
// is on a line not shown in this excerpt).
// LLVM dialect ops are legal unless they still carry a "cache_control"
// attribute, which makes them candidates for LLVMLoadStoreToOCLPattern.
912 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
913 [](
Operation *op) {
return !op->hasAttr(
"cache_control"); });
// Every XeVM dialect op must be converted.
914 target.addIllegalDialect<XeVMDialect>();
// One pattern instantiation per supported op type.
915 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
916 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
917 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
918 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
919 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
920 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
921 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
922 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
923 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
924 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
925 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
926 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
927 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
928 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
929 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
930 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
931 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
932 LaunchConfigOpToOCLPattern<GridDimXOp>,
933 LaunchConfigOpToOCLPattern<GridDimYOp>,
934 LaunchConfigOpToOCLPattern<GridDimZOp>,
935 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
936 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
937 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
// Attaches the XeVMToLLVMDialectInterface to `dialect` (the enclosing
// registration function is on lines not shown in this excerpt).
943 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Location getLoc()
The source location the operation was defined or derived from.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch