18 #include "llvm/Support/FormatVariadic.h"
23 #include "llvm/ADT/TypeSwitch.h"
26 #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27 #include "mlir/Conversion/Passes.h.inc"
// Bundle of function-level attributes applied to a device function
// declaration. createDeviceFunctionCall forwards each field to the matching
// LLVM::LLVMFuncOp setter (setConvergent / setNoUnwind / setWillReturn /
// setMemoryEffectsAttr). All flags default to false.
35 struct LLVMFuncAttributeOptions {
// Forwarded to funcOp.setConvergent().
36 bool isConvergent =
false;
// Forwarded to funcOp.setNoUnwind().
37 bool isNoUnwind =
false;
// Forwarded to funcOp.setWillReturn().
38 bool isWillReturn =
false;
// Optional memory-effects attribute; createDeviceFunctionCall only calls
// setMemoryEffectsAttr() when this attribute tests true.
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
// Preset with only isNoUnwind set. Initializers are positional, in field
// order: isConvergent, isNoUnwind, isWillReturn, memEffectsAttr (left
// value-initialized; callers in this file assign it afterwards when needed).
41 static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false,
true,
false, {}};
// Preset with isNoUnwind and isWillReturn set, isConvergent cleared.
// Initializers are positional, in field order: isConvergent, isNoUnwind,
// isWillReturn, memEffectsAttr (left value-initialized).
43 static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false,
true,
true, {}};
// Preset with all three boolean flags set (isConvergent, isNoUnwind,
// isWillReturn). memEffectsAttr is left value-initialized; the MMA lowering
// in this file copies this preset and then assigns its own memEffectsAttr.
45 static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true,
true,
true, {}};
48 std::string getTypeMangling(
Type ty,
bool isUnsigned =
false) {
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) +
"_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
54 .Case([](Float16Type) -> std::string {
return "Dh"; })
55 .Case([](Float32Type) -> std::string {
return "f"; })
56 .Case([](Float64Type) -> std::string {
return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
60 return isUnsigned ?
"h" :
"c";
62 return isUnsigned ?
"t" :
"s";
64 return isUnsigned ?
"j" :
"i";
66 return isUnsigned ?
"m" :
"l";
68 llvm_unreachable(
"unhandled integer type");
71 .Default([](
Type) -> std::string {
72 llvm_unreachable(
"unhandled type for mangling");
78 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
79 "Signedness info doesn't match");
81 llvm::raw_string_ostream os(s);
82 llvm::SmallDenseMap<Type, unsigned> substitutions;
83 os <<
"_Z" << baseName.size() << baseName;
85 auto it = substitutions.find(type);
86 if (it != substitutions.end()) {
89 if (
unsigned firstIdx = it->getSecond(); firstIdx > 0)
93 if (!type.isIntOrFloat())
94 substitutions[type] = substitutions.size();
95 os << getTypeMangling(type, isUnsigned.empty() ?
false : isUnsigned[idx]);
101 template <
bool isLoad,
typename OpType>
102 int32_t getL1CacheControl(OpType op) {
104 if constexpr (isLoad) {
105 switch (*op.getCacheControl()) {
106 case LoadCacheControl::L1UC_L2UC_L3UC:
107 case LoadCacheControl::L1UC_L2UC_L3C:
108 case LoadCacheControl::L1UC_L2C_L3UC:
109 case LoadCacheControl::L1UC_L2C_L3C:
112 case LoadCacheControl::L1C_L2UC_L3UC:
113 case LoadCacheControl::L1C_L2UC_L3C:
114 case LoadCacheControl::L1C_L2C_L3UC:
115 case LoadCacheControl::L1C_L2C_L3C:
118 case LoadCacheControl::L1S_L2UC_L3UC:
119 case LoadCacheControl::L1S_L2UC_L3C:
120 case LoadCacheControl::L1S_L2C_L3UC:
121 case LoadCacheControl::L1S_L2C_L3C:
124 case LoadCacheControl::INVALIDATE_READ:
129 switch (*op.getCacheControl()) {
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
158 template <
bool isLoad,
typename OpType>
159 int32_t getL3CacheControl(OpType op) {
161 if constexpr (isLoad) {
162 switch (*op.getCacheControl()) {
163 case LoadCacheControl::L1UC_L2UC_L3UC:
164 case LoadCacheControl::L1UC_L2C_L3UC:
165 case LoadCacheControl::L1C_L2UC_L3UC:
166 case LoadCacheControl::L1C_L2C_L3UC:
167 case LoadCacheControl::L1S_L2UC_L3UC:
168 case LoadCacheControl::L1S_L2C_L3UC:
171 case LoadCacheControl::L1UC_L2UC_L3C:
172 case LoadCacheControl::L1UC_L2C_L3C:
173 case LoadCacheControl::L1C_L2UC_L3C:
174 case LoadCacheControl::L1C_L2C_L3C:
175 case LoadCacheControl::L1S_L2UC_L3C:
176 case LoadCacheControl::L1S_L2C_L3C:
179 case LoadCacheControl::INVALIDATE_READ:
184 switch (*op.getCacheControl()) {
185 case StoreCacheControl::L1UC_L2UC_L3UC:
186 case StoreCacheControl::L1UC_L2WB_L3UC:
187 case StoreCacheControl::L1WT_L2UC_L3UC:
188 case StoreCacheControl::L1WT_L2WB_L3UC:
189 case StoreCacheControl::L1S_L2UC_L3UC:
190 case StoreCacheControl::L1S_L2WB_L3UC:
191 case StoreCacheControl::L1WB_L2UC_L3UC:
192 case StoreCacheControl::L1WB_L2WB_L3UC:
195 case StoreCacheControl::L1UC_L2UC_L3WB:
196 case StoreCacheControl::L1UC_L2WB_L3WB:
197 case StoreCacheControl::L1WT_L2UC_L3WB:
198 case StoreCacheControl::L1WT_L2WB_L3WB:
199 case StoreCacheControl::L1S_L2UC_L3WB:
200 case StoreCacheControl::L1S_L2WB_L3WB:
201 case StoreCacheControl::L1WB_L2UC_L3WB:
209 template <
bool isLoad,
typename OpType>
210 static std::optional<ArrayAttr>
212 if (!op.getCacheControl())
214 constexpr int32_t decorationCacheControlArity{4};
215 constexpr int32_t loadCacheControlKey{6442};
216 constexpr int32_t storeCacheControlKey{6443};
217 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
219 controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
221 controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
229 static LLVM::CallOp createDeviceFunctionCall(
233 LLVMFuncAttributeOptions funcAttributeOptions,
Operation *op) {
235 assert(moduleOp &&
"Expecting module");
240 assert(!
failed(funcOpRes));
241 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
242 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
243 funcOp.setConvergent(funcAttributeOptions.isConvergent);
244 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
245 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
247 if (funcAttributeOptions.memEffectsAttr)
248 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
250 for (
auto [idx, attrName] : paramAttrs)
251 funcOp.setArgAttr(idx, attrName, rewriter.
getUnitAttr());
253 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
254 callOp->setAttrs(funcOp->getAttrs());
262 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
267 auto precisionA = op.getTypes().getA();
268 auto precisionB = op.getTypes().getB();
269 auto precisionC = op.getTypes().getC();
270 auto precisionD = op.getTypes().getD();
271 if (precisionC != precisionD) {
274 if (precisionC != xevm::ElemType::S32 &&
275 precisionC != xevm::ElemType::F32 &&
276 precisionC != xevm::ElemType::F16 &&
277 precisionC != xevm::ElemType::BF16) {
279 op,
"type of C and D must be S32, F32, F16 or BF16");
281 if (precisionA == xevm::ElemType::S32 ||
282 precisionA == xevm::ElemType::F32) {
285 if (precisionB == xevm::ElemType::S32 ||
286 precisionB == xevm::ElemType::F32) {
289 constexpr uint32_t bitWidthPackedA{16};
290 constexpr uint32_t bitWidthPackedB{32};
291 auto loc = op.getLoc();
294 VectorType origTy = cast<VectorType>(val.
getType());
295 const uint32_t vecBitSize =
296 origTy.getNumElements() *
297 origTy.getElementType().getIntOrFloatBitWidth();
299 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
301 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
306 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
309 a = castIfNeeded(a, packedAType);
312 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
315 b = castIfNeeded(b, packedBType);
318 VectorType cOrigTy = cast<VectorType>(c.
getType());
319 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
320 assert(cOrigTy == resOrigTy &&
"Accumulator and result type mismatch");
323 cOrigTy.getElementType().isBF16()
326 VectorType resTy = cTy;
328 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
330 constexpr int32_t systolicDepth{8};
332 llvm::formatv(
"intel_sub_group_{0}_{1}_matrix_mad_k{2}",
333 stringifyElemType(op.getTypes().getA()).str(),
334 stringifyElemType(op.getTypes().getB()).str(),
336 getNumOperandsPerDword(op.getTypes().getA()))
339 fnName = mangle(fnName, argTypes);
342 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
343 LLVM::ModRefInfo::NoModRef,
344 LLVM::ModRefInfo::NoModRef,
345 LLVM::ModRefInfo::NoModRef);
346 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
347 funcAttrs.memEffectsAttr = memAttr;
349 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
350 funcAttrs, op.getOperation())
353 if (resOrigTy != resTy)
354 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
361 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
363 case xevm::ElemType::TF32:
365 case xevm::ElemType::BF16:
366 case xevm::ElemType::F16:
368 case xevm::ElemType::U8:
369 case xevm::ElemType::S8:
372 llvm_unreachable(
"unsupported xevm::ElemType");
380 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
382 auto loc = op.getLoc();
383 const std::string fnName{
"_Z8prefetchPU3AS1Kcm"};
385 LLVM::ConstantOp::create(rewriter, loc, rewriter.
getI64Type(), 1);
388 for (
auto arg : args)
389 argTypes.push_back(arg.getType());
390 auto funcAttr = noUnwindAttrs;
391 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
392 LLVM::ModRefInfo::NoModRef,
393 LLVM::ModRefInfo::Ref,
394 LLVM::ModRefInfo::NoModRef);
395 funcAttr.memEffectsAttr = memAttr;
397 LLVM::CallOp call = createDeviceFunctionCall(
399 argTypes, args, {}, funcAttr, op.getOperation());
400 if (std::optional<ArrayAttr> optCacheControls =
401 getCacheControlMetadata<true>(rewriter, op))
402 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
411 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
413 auto loc = op.getLoc();
414 const std::string fnName{
"atomic_work_item_fence"};
415 int memScope, addrSpace;
416 switch (op.getAddrspace()) {
417 case xevm::AddrSpace::SHARED:
420 case xevm::AddrSpace::GLOBAL:
426 op,
"Fence only supports global and shared address spaces.");
428 switch (op.getScope()) {
429 case xevm::MemScope::WORKGROUP:
432 case xevm::MemScope::DEVICE:
438 op,
"Fence only supports workgroup and device memory scopes.");
441 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
442 Value memScopeConst =
443 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
444 Value addrSpaceConst =
445 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
448 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
450 argTypes, args, {}, noUnwindAttrs,
456 template <
typename OpType>
460 matchAndRewrite(OpType op,
typename OpType::Adaptor adaptor,
462 constexpr
bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
463 constexpr
bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
465 auto loc = op.getLoc();
467 bool packReg =
false;
468 bool transpose =
false;
469 if constexpr (isLoad) {
470 vecType = op.getRes().getType();
471 packReg = op.getPackRegister();
472 transpose = op.getTranspose();
473 }
else if constexpr (!isPrefetch) {
474 vecType = op.getStoredVal().getType();
480 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
481 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
482 byteCoord = LLVM::InsertElementOp::create(
483 rewriter, loc,
VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
484 byteCoord = LLVM::InsertElementOp::create(
485 rewriter, loc,
VectorType::get(2, i32Type), byteCoord, op.getY(), one);
487 op.getBasePitch(), byteCoord};
490 std::string funcName{
"intel_sub_group_2d_block_"};
491 std::string bitWidthId;
492 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
494 if constexpr (isPrefetch) {
495 funcName +=
"prefetch";
496 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
497 auto memAttr = rewriter.
getAttr<LLVM::MemoryEffectsAttr>(
498 LLVM::ModRefInfo::NoModRef,
499 LLVM::ModRefInfo::Ref,
500 LLVM::ModRefInfo::NoModRef);
501 funcAttr = noUnwindAttrs;
502 funcAttr.memEffectsAttr = memAttr;
504 auto vecElemType = vecType.getElementType();
505 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
506 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
507 vecType.getNumElements());
508 auto dstOrSrcPtr = LLVM::AllocaOp::create(
510 vecElemType, numElems);
511 args.push_back(dstOrSrcPtr);
512 if constexpr (isLoad) {
514 bitWidthId = getTypeMangling(vecElemType,
true);
516 funcName +=
"_transform";
518 funcName +=
"_transpose";
519 spvLoadDstPtr = dstOrSrcPtr;
520 retTypes.push_back(vecType);
522 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
523 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
524 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
525 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
529 bitWidthId = (vecElemBitWidth == 32)
531 : ((vecElemBitWidth == 16) ?
"t" :
"h");
532 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
534 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
535 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
536 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
537 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
543 llvm::formatv(
"{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
544 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
546 std::string prefetchCode(
"");
549 funcName = llvm::formatv(
"_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
550 funcName, prefetchCode, bitWidthId)
553 for (
auto arg : args) {
554 argTypes.push_back(arg.getType());
556 LLVM::CallOp call = createDeviceFunctionCall(
558 argTypes, args, paramAttrs, funcAttr, op.getOperation());
559 if (std::optional<ArrayAttr> optCacheControls =
560 getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
561 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
563 if constexpr (isLoad)
565 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
576 struct ConvertXeVMToLLVMPass
577 :
public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
581 registry.
insert<LLVM::LLVMDialect, XeVMDialect>();
584 void runOnOperation()
override {
586 target.addLegalDialect<LLVM::LLVMDialect>();
587 target.addIllegalDialect<XeVMDialect>();
605 void loadDependentDialects(
MLIRContext *context)
const final {
606 context->loadDialect<LLVM::LLVMDialect>();
611 void populateConvertToLLVMConversionPatterns(
624 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
625 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
626 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
627 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
633 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
static MLIRContext * getContext(OpFoldResult val)
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry)