MLIR 22.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/Support/FormatVariadic.h"
19
21#include "mlir/IR/Types.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace xevm;
32
33namespace {
34
35struct LLVMFuncAttributeOptions {
36 bool isConvergent = false;
37 bool isNoUnwind = false;
38 bool isWillReturn = false;
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
40};
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false, true, false, {}};
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false, true, true, {}};
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true, true, true, {}};
47
48std::string getTypeMangling(Type ty, bool isUnsigned = false) {
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
53 })
54 .Case([](Float16Type) -> std::string { return "Dh"; })
55 .Case([](Float32Type) -> std::string { return "f"; })
56 .Case([](Float64Type) -> std::string { return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
59 case 8:
60 return isUnsigned ? "h" : "c";
61 case 16:
62 return isUnsigned ? "t" : "s";
63 case 32:
64 return isUnsigned ? "j" : "i";
65 case 64:
66 return isUnsigned ? "m" : "l";
67 default:
68 llvm_unreachable("unhandled integer type");
69 }
70 })
71 .DefaultUnreachable("unhandled type for mangling");
72}
73
74std::string mangle(StringRef baseName, ArrayRef<Type> types,
75 ArrayRef<bool> isUnsigned = {}) {
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
78 std::string s;
79 llvm::raw_string_ostream os(s);
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os << "_Z" << baseName.size() << baseName;
82 for (auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
85 os << "S";
86 // First substitution is `S_`, second is `S0_`, and so on.
87 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
88 os << firstIdx - 1;
89 os << "_";
90 } else {
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
93 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
94 }
95 }
96 return os.str();
97}
98
99static int32_t getL1CacheControl(LoadCacheControl cc) {
100 int32_t control = 0;
101 switch (cc) {
102 case LoadCacheControl::L1UC_L2UC_L3UC:
103 case LoadCacheControl::L1UC_L2UC_L3C:
104 case LoadCacheControl::L1UC_L2C_L3UC:
105 case LoadCacheControl::L1UC_L2C_L3C:
106 control = 1;
107 break;
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
112 control = 2;
113 break;
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
118 control = 3;
119 break;
120 case LoadCacheControl::INVALIDATE_READ:
121 control = 4;
122 break;
123 }
124 return control;
125}
126
127static int32_t getL1CacheControl(StoreCacheControl cc) {
128 int32_t control = 0;
129 switch (cc) {
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
134 control = 1;
135 break;
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
140 control = 2;
141 break;
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
146 control = 3;
147 break;
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
151 control = 4;
152 break;
153 }
154 return control;
155}
156
157static int32_t getL3CacheControl(LoadCacheControl cc) {
158 int32_t control = 0;
159 switch (cc) {
160 case LoadCacheControl::L1UC_L2UC_L3UC:
161 case LoadCacheControl::L1UC_L2C_L3UC:
162 case LoadCacheControl::L1C_L2UC_L3UC:
163 case LoadCacheControl::L1C_L2C_L3UC:
164 case LoadCacheControl::L1S_L2UC_L3UC:
165 case LoadCacheControl::L1S_L2C_L3UC:
166 control = 1;
167 break;
168 case LoadCacheControl::L1UC_L2UC_L3C:
169 case LoadCacheControl::L1UC_L2C_L3C:
170 case LoadCacheControl::L1C_L2UC_L3C:
171 case LoadCacheControl::L1C_L2C_L3C:
172 case LoadCacheControl::L1S_L2UC_L3C:
173 case LoadCacheControl::L1S_L2C_L3C:
174 control = 2;
175 break;
176 case LoadCacheControl::INVALIDATE_READ:
177 control = 4;
178 break;
179 }
180 return control;
181}
182
183static int32_t getL3CacheControl(StoreCacheControl cc) {
184 int32_t control = 0;
185 switch (cc) {
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
194 control = 1;
195 break;
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
203 control = 2;
204 break;
205 }
206 return control;
207}
208
// Uniform accessors for the optional cache-control hint of each XeVM memory
// op. Each overload simply forwards to the op's own getCacheControl()
// accessor so that templated code can call getCacheControl(op) regardless of
// the op type. Prefetch and load ops yield a LoadCacheControl; store ops
// yield a StoreCacheControl.
209static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210 return op.getCacheControl();
211}
212
213static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214 return op.getCacheControl();
215}
216
217static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218 return op.getCacheControl();
219}
220
221static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222 return op.getCacheControl();
223}
224
// Store-side overloads.
225static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226 return op.getCacheControl();
227}
228
229static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230 return op.getCacheControl();
231}
232
233static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234 if (op->hasAttr("cache_control")) {
235 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
236 if (!attr)
237 return std::nullopt;
238 return std::optional<LoadCacheControl>(attr.getValue());
239 }
240 return std::nullopt;
241}
242
243static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244 if (op->hasAttr("cache_control")) {
245 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
246 if (!attr)
247 return std::nullopt;
248 return std::optional<StoreCacheControl>(attr.getValue());
249 }
250 return std::nullopt;
251}
252
/// Returns the L1 cache-control code for `op`.
///
/// Precondition: `getCacheControl(op)` holds a value. Callers are expected to
/// check this first; the assert makes a violation fail loudly in debug builds
/// instead of silently dereferencing an empty optional.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "expected op to carry a cache control");
  return getL1CacheControl(*cacheControl);
}
257
/// Returns the L3 cache-control code for `op`.
///
/// Precondition: `getCacheControl(op)` holds a value. Callers are expected to
/// check this first; the assert makes a violation fail loudly in debug builds
/// instead of silently dereferencing an empty optional.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "expected op to carry a cache control");
  return getL3CacheControl(*cacheControl);
}
262
263template <typename OpType>
264static std::optional<ArrayAttr>
265getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266 if (!getCacheControl(op))
267 return {};
268 constexpr int32_t decorationCacheControlArity{4};
269 constexpr int32_t loadCacheControlKey{6442};
270 constexpr int32_t storeCacheControlKey{6443};
271 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272 std::is_same_v<OpType, BlockPrefetch2dOp> ||
273 std::is_same_v<OpType, LLVM::LoadOp> ||
274 std::is_same_v<OpType, BlockLoadOp> ||
275 std::is_same_v<OpType, PrefetchOp>;
276 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278 controlKey, 0, getL1CacheControl<OpType>(op), 0};
280 controlKey, 1, getL3CacheControl<OpType>(op), 0};
281 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
282 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
283
284 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
285 return rewriter.getArrayAttr(combinedAttrs);
286}
287
288static LLVM::CallOp createDeviceFunctionCall(
289 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
290 ArrayRef<Type> argTypes, ArrayRef<Value> args,
291 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
292 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
293 auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
294 assert(moduleOp && "Expecting module");
295 Location loc = op->getLoc();
296
297 auto funcOpRes =
298 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
299 assert(!failed(funcOpRes));
300 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302 funcOp.setConvergent(funcAttributeOptions.isConvergent);
303 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
305
306 if (funcAttributeOptions.memEffectsAttr)
307 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
308
309 for (auto [idx, attrName] : paramAttrs)
310 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
311
312 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
313 callOp->setAttrs(funcOp->getAttrs());
314
315 return callOp;
316}
317
/// Rewrites `xevm.mma` into a call to a mangled
/// `intel_sub_group_<A>_<B>_matrix_mad_k<N>` device function, bitcasting the
/// operands to the packed vector types used for the call and bitcasting the
/// result back when needed.
318class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
319 using OpConversionPattern::OpConversionPattern;
320 LogicalResult
321 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
// The accumulator operand is mandatory for this lowering.
323 if (!op.getC()) {
324 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
325 }
326 auto precisionA = op.getTypes().getA();
327 auto precisionB = op.getTypes().getB();
328 auto precisionC = op.getTypes().getC();
329 auto precisionD = op.getTypes().getD();
// Reject element-type combinations this lowering does not handle.
330 if (precisionC != precisionD) {
331 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
332 }
333 if (precisionC != xevm::ElemType::S32 &&
334 precisionC != xevm::ElemType::F32 &&
335 precisionC != xevm::ElemType::F16 &&
336 precisionC != xevm::ElemType::BF16) {
337 return rewriter.notifyMatchFailure(
338 op, "type of C and D must be S32, F32, F16 or BF16");
339 }
340 if (precisionA == xevm::ElemType::S32 ||
341 precisionA == xevm::ElemType::F32) {
342 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
343 }
344 if (precisionB == xevm::ElemType::S32 ||
345 precisionB == xevm::ElemType::F32) {
346 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
347 }
// Element bit widths of the packed integer vectors used for A and B.
348 constexpr uint32_t bitWidthPackedA{16};
349 constexpr uint32_t bitWidthPackedB{32};
350 auto loc = op.getLoc();
351
// Bitcasts `val` to a vector of `packedType` elements with the same total
// bit size; returns `val` unchanged if it already has that type.
352 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
353 VectorType origTy = cast<VectorType>(val.getType());
354 const uint32_t vecBitSize =
355 origTy.getNumElements() *
356 origTy.getElementType().getIntOrFloatBitWidth();
357 VectorType newTy = VectorType::get(
358 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
359 if (origTy != newTy)
360 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
361 return val;
362 };
363
// A is passed as f32 elements for TF32, otherwise as i16 elements.
364 Value a = op.getA();
365 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
366 ? cast<Type>(rewriter.getF32Type())
367 : rewriter.getIntegerType(bitWidthPackedA);
368 a = castIfNeeded(a, packedAType);
369
// B is passed as f32 elements for TF32, otherwise as i32 elements.
370 Value b = op.getB();
371 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
372 ? cast<Type>(rewriter.getF32Type())
373 : rewriter.getIntegerType(bitWidthPackedB);
374 b = castIfNeeded(b, packedBType);
375
376 Value c = op.getC();
377 VectorType cOrigTy = cast<VectorType>(c.getType());
378 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
380 // OCL builtins encode bfloat16 as int16
381 VectorType cTy =
382 cOrigTy.getElementType().isBF16()
383 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
384 : cOrigTy;
385 VectorType resTy = cTy;
386 if (cOrigTy != cTy)
387 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
388
// The k suffix of the function name is systolicDepth multiplied by the
// number of A elements per dword.
389 constexpr int32_t systolicDepth{8};
390 std::string fnName =
391 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392 stringifyElemType(op.getTypes().getA()).str(),
393 stringifyElemType(op.getTypes().getB()).str(),
394 systolicDepth *
395 getNumOperandsPerDword(op.getTypes().getA()))
396 .str();
397 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
398 fnName = mangle(fnName, argTypes);
399 SmallVector<Value> args{a, b, c};
400
// Mark the callee as neither reading nor writing any memory.
401 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
402 /*other=*/LLVM::ModRefInfo::NoModRef,
403 /*argMem=*/LLVM::ModRefInfo::NoModRef,
404 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
405 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
406 funcAttrs.memEffectsAttr = memAttr;
407 Value result =
408 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
409 funcAttrs, op.getOperation())
410 ->getResult(0);
411
// Cast back to the original result type if the accumulator was retyped
// for the call.
412 if (resOrigTy != resTy)
413 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
414
415 rewriter.replaceOp(op, result);
416 return success();
417 }
418
419private:
// Returns 1 for TF32, 2 for BF16/F16 and 4 for U8/S8; any other element
// type is treated as unreachable.
420 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
421 switch (pTy) {
422 case xevm::ElemType::TF32:
423 return 1;
424 case xevm::ElemType::BF16:
425 case xevm::ElemType::F16:
426 return 2;
427 case xevm::ElemType::U8:
428 case xevm::ElemType::S8:
429 return 4;
430 default:
431 llvm_unreachable("unsupported xevm::ElemType");
432 }
433 }
434};
435
436class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
437 using OpConversionPattern::OpConversionPattern;
438 LogicalResult
439 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
440 ConversionPatternRewriter &rewriter) const override {
441 auto loc = op.getLoc();
442 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
443 Value one =
444 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
445 SmallVector<Value> args{op.getPtr(), one};
446 SmallVector<Type> argTypes;
447 for (auto arg : args)
448 argTypes.push_back(arg.getType());
449 auto funcAttr = noUnwindAttrs;
450 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
451 /*other=*/LLVM::ModRefInfo::NoModRef,
452 /*argMem=*/LLVM::ModRefInfo::Ref,
453 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
454 funcAttr.memEffectsAttr = memAttr;
455
456 LLVM::CallOp call = createDeviceFunctionCall(
457 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
458 argTypes, args, {}, funcAttr, op.getOperation());
459 if (std::optional<ArrayAttr> optCacheControls =
460 getCacheControlMetadata(rewriter, op))
461 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
462 rewriter.eraseOp(op);
463 return success();
464 }
465};
466
467class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
468 using OpConversionPattern::OpConversionPattern;
469 LogicalResult
470 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
471 ConversionPatternRewriter &rewriter) const override {
472 auto loc = op.getLoc();
473 const std::string fnName{"atomic_work_item_fence"};
474 int memScope, addrSpace;
475 switch (op.getAddrspace()) {
476 case xevm::AddrSpace::SHARED:
477 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
478 break;
479 case xevm::AddrSpace::GLOBAL:
480 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
481 break;
482 default:
483 // GENERIC is not supported in OpenCL
484 return rewriter.notifyMatchFailure(
485 op, "Fence only supports global and shared address spaces.");
486 }
487 switch (op.getScope()) {
488 case xevm::MemScope::WORKGROUP:
489 memScope = 1;
490 break;
491 case xevm::MemScope::DEVICE:
492 memScope = 2;
493 break;
494 default:
495 // CLUSTER and SYSTEM are not supported in OpenCL
496 return rewriter.notifyMatchFailure(
497 op, "Fence only supports workgroup and device memory scopes.");
498 }
499 Type i32Type = rewriter.getI32Type();
500 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
501 Value memScopeConst =
502 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
503 Value addrSpaceConst =
504 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
505 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
506 SmallVector<Type> argTypes{3, i32Type};
507 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
508 LLVM::LLVMVoidType::get(rewriter.getContext()),
509 argTypes, args, {}, noUnwindAttrs,
510 op.getOperation());
511 rewriter.eraseOp(op);
512 return success();
513 }
514};
/// Rewrites the 2D block ops (BlockLoad2dOp, BlockStore2dOp,
/// BlockPrefetch2dOp) into calls to hand-mangled
/// `intel_sub_group_2d_block_{read,write,prefetch}...` device functions.
/// Loads and stores pass their data through a stack slot created with
/// `llvm.alloca`; prefetches take no data pointer.
515template <typename OpType>
516class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
517 using OpConversionPattern<OpType>::OpConversionPattern;
518 LogicalResult
519 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
520 ConversionPatternRewriter &rewriter) const override {
// Any OpType that is neither the load nor the prefetch is handled as a
// store.
521 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
522 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
523
524 auto loc = op.getLoc();
// Vector type of the loaded/stored data; stays null for prefetch.
525 VectorType vecType;
526 bool packReg = false;
527 bool transpose = false;
528 if constexpr (isLoad) {
529 vecType = op.getRes().getType();
530 packReg = op.getPackRegister();
531 transpose = op.getTranspose();
532 } else if constexpr (!isPrefetch) {
533 vecType = op.getStoredVal().getType();
534 }
535
// Pack the X and Y operands into a 2 x i32 vector argument.
536 auto i32Type = rewriter.getI32Type();
537 Value byteCoord =
538 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
539 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
540 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
541 byteCoord = LLVM::InsertElementOp::create(
542 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
543 byteCoord = LLVM::InsertElementOp::create(
544 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
// Arguments common to all three ops; loads/stores append a data pointer.
545 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
546 op.getBasePitch(), byteCoord};
547 SmallVector<Type> retTypes;
548 Value spvLoadDstPtr;
549 std::string funcName{"intel_sub_group_2d_block_"};
// Mangling code of the data pointer's element type; empty for prefetch.
550 std::string bitWidthId;
551 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
552 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
553 if constexpr (isPrefetch) { // Prefetch
554 funcName += "prefetch";
555 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
556 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
557 /*other=*/LLVM::ModRefInfo::NoModRef,
558 /*argMem=*/LLVM::ModRefInfo::Ref,
559 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
560 funcAttr = noUnwindAttrs;
561 funcAttr.memEffectsAttr = memAttr;
562 } else {
563 auto vecElemType = vecType.getElementType();
564 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
// Stack slot holding vecType.getNumElements() elements; it is passed as
// the sixth call argument (index 5).
565 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
566 vecType.getNumElements());
567 auto dstOrSrcPtr = LLVM::AllocaOp::create(
568 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
569 vecElemType, numElems);
570 args.push_back(dstOrSrcPtr);
571 if constexpr (isLoad) { // Load
572 funcName += "read";
573 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
// packReg takes precedence over transpose when choosing the suffix.
574 if (packReg)
575 funcName += "_transform";
576 else if (transpose)
577 funcName += "_transpose";
578 spvLoadDstPtr = dstOrSrcPtr;
579 retTypes.push_back(vecType);
580 paramAttrs = {
581 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
582 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
583 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
584 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
585 };
586 } else { // Store
587 funcName += "write";
588 bitWidthId = (vecElemBitWidth == 32)
589 ? "j"
590 : ((vecElemBitWidth == 16) ? "t" : "h");
// Spill the value to the stack slot so the callee can read it.
591 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
592 paramAttrs = {
593 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
594 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
595 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
596 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
597 };
598 }
599 }
600
// Append element size, tile height, tile width and v_blocks to the
// function name.
601 funcName =
602 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
603 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
604 .str();
// Hand-built mangled name. The trailing "P" + bitWidthId encodes the
// extra data-pointer parameter, which prefetch does not have.
605 std::string prefetchCode("");
606 if (!isPrefetch)
607 prefetchCode += "P";
608 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
609 funcName, prefetchCode, bitWidthId)
610 .str();
611 SmallVector<Type> argTypes;
612 for (auto arg : args) {
613 argTypes.push_back(arg.getType());
614 }
615 LLVM::CallOp call = createDeviceFunctionCall(
616 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
617 argTypes, args, paramAttrs, funcAttr, op.getOperation());
618 if (std::optional<ArrayAttr> optCacheControls =
619 getCacheControlMetadata(rewriter, op)) {
620 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
621 }
// A load is replaced by reading the result back from the stack slot;
// stores and prefetches produce no value and are simply erased.
622 if constexpr (isLoad)
623 rewriter.replaceOp(
624 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
625 else
626 rewriter.eraseOp(op);
627 return success();
628 }
629};
630
631template <typename OpType>
632class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
633 using OpConversionPattern<OpType>::OpConversionPattern;
634 LogicalResult
635 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
636 ConversionPatternRewriter &rewriter) const override {
637 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
638 // Get OpenCL function name
639 // https://registry.khronos.org/OpenCL/extensions/
640 // intel/cl_intel_subgroup_local_block_io.html
641 std::string funcName{"intel_sub_group_block_"};
642 // Value or Result type can be vector or scalar
643 Type valOrResTy;
644 if constexpr (isStore) {
645 funcName += "write_u";
646 valOrResTy = op.getVal().getType();
647 } else {
648 funcName += "read_u";
649 valOrResTy = op.getType();
650 }
651 // Get element type of the vector/scalar
652 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
653 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
654 funcName += getTypeMangling(elemType);
655 if (vecTy)
656 funcName += std::to_string(vecTy.getNumElements());
657 SmallVector<Type, 2> argTypes{};
658 // XeVM BlockLoad/StoreOp always use signless integer types
659 // but OpenCL builtins expect unsigned types
660 // use unsigned types for mangling
661 SmallVector<bool, 2> isUnsigned{};
662 // arg0: pointer to the src/dst address
663 // arg1 - only if store : vector to store
664 // Prepare arguments
665 SmallVector<Value, 2> args{};
666 args.push_back(op.getPtr());
667 argTypes.push_back(op.getPtr().getType());
668 isUnsigned.push_back(true);
669 Type retType;
670 if constexpr (isStore) {
671 args.push_back(op.getVal());
672 argTypes.push_back(op.getVal().getType());
673 isUnsigned.push_back(true);
674 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
675 } else {
676 retType = valOrResTy;
677 }
678 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
679 "PU3AS" +
680 std::to_string(op.getPtr().getType().getAddressSpace());
681 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
682 if constexpr (isStore)
683 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
684 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
685
686 LLVM::CallOp call =
687 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
688 {}, funcAttr, op.getOperation());
689 if (std::optional<ArrayAttr> optCacheControls =
690 getCacheControlMetadata(rewriter, op)) {
691 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
692 }
693 if constexpr (isStore)
694 rewriter.eraseOp(op);
695 else
696 rewriter.replaceOp(op, call->getResult(0));
697 return success();
698 }
699};
700
701template <typename OpType>
702class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
703 using OpConversionPattern<OpType>::OpConversionPattern;
704 LogicalResult
705 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
706 ConversionPatternRewriter &rewriter) const override {
707 if (!op->hasAttr("cache_control"))
708 return failure();
709 std::optional<ArrayAttr> optCacheControls =
710 getCacheControlMetadata(rewriter, op);
711 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
712 op->removeAttr("cache_control");
713 return success();
714 }
715};
716
717//===----------------------------------------------------------------------===//
718// GPU index id operations
719//===----------------------------------------------------------------------===//
720/*
721// Launch Config ops
722// dimidx - x, y, z - is fixed to i32
723// return type is set by XeVM type converter
724// get_local_id
725xevm::WorkitemIdXOp;
726xevm::WorkitemIdYOp;
727xevm::WorkitemIdZOp;
728// get_local_size
729xevm::WorkgroupDimXOp;
730xevm::WorkgroupDimYOp;
731xevm::WorkgroupDimZOp;
732// get_group_id
733xevm::WorkgroupIdXOp;
734xevm::WorkgroupIdYOp;
735xevm::WorkgroupIdZOp;
736// get_num_groups
737xevm::GridDimXOp;
738xevm::GridDimYOp;
739xevm::GridDimZOp;
740// get_global_id : to be added if needed
741*/
742
743// Helpers to get the OpenCL function name and dimension argument for each op.
// Each overload returns {function base name, dimension index}, where the
// index is 0, 1 or 2 for the X, Y or Z variant of the op. The op parameter is
// unnamed: it is used only to select the overload.

// Workitem id ops -> "get_local_id".
744static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
745 return {"get_local_id", 0};
746}
747static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
748 return {"get_local_id", 1};
749}
750static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
751 return {"get_local_id", 2};
752}
// Workgroup dimension ops -> "get_local_size".
753static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
754 return {"get_local_size", 0};
755}
756static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
757 return {"get_local_size", 1};
758}
759static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
760 return {"get_local_size", 2};
761}
// Workgroup id ops -> "get_group_id".
762static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
763 return {"get_group_id", 0};
764}
765static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
766 return {"get_group_id", 1};
767}
768static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
769 return {"get_group_id", 2};
770}
// Grid dimension ops -> "get_num_groups".
771static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
772 return {"get_num_groups", 0};
773}
774static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
775 return {"get_num_groups", 1};
776}
777static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
778 return {"get_num_groups", 2};
779}
780/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
781/// a constant argument for the dimension - x, y or z.
782template <typename OpType>
783class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
784 using OpConversionPattern<OpType>::OpConversionPattern;
785 LogicalResult
786 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
787 ConversionPatternRewriter &rewriter) const override {
788 Location loc = op->getLoc();
789 auto [baseName, dim] = getConfig(op);
790 Type dimTy = rewriter.getI32Type();
791 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
792 static_cast<int64_t>(dim));
793 std::string func = mangle(baseName, {dimTy}, {true});
794 Type resTy = op.getType();
795 auto call =
796 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
797 noUnwindWillReturnAttrs, op.getOperation());
798 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
799 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
800 /*other=*/noModRef,
801 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
802 call.setMemoryEffectsAttr(memAttr);
803 rewriter.replaceOp(op, call);
804 return success();
805 }
806};
807
808/*
809// Subgroup ops
810// get_sub_group_local_id
811xevm::LaneIdOp;
812// get_sub_group_id
813xevm::SubgroupIdOp;
814// get_sub_group_size
815xevm::SubgroupSizeOp;
816// get_num_sub_groups : to be added if needed
817*/
818
819// Helpers to get the OpenCL function name for each op.
// Each overload returns the function base name for one subgroup op. The op
// parameter is unnamed: it is used only to select the overload.
820static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
821static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
822static StringRef getConfig(xevm::SubgroupSizeOp) {
823 return "get_sub_group_size";
824}
825template <typename OpType>
826class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
827 using OpConversionPattern<OpType>::OpConversionPattern;
828 LogicalResult
829 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
830 ConversionPatternRewriter &rewriter) const override {
831 std::string func = mangle(getConfig(op).str(), {});
832 Type resTy = op.getType();
833 auto call =
834 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
835 noUnwindWillReturnAttrs, op.getOperation());
836 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
837 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
838 /*other=*/noModRef,
839 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
840 call.setMemoryEffectsAttr(memAttr);
841 rewriter.replaceOp(op, call);
842 return success();
843 }
844};
845
846//===----------------------------------------------------------------------===//
847// Pass Definition
848//===----------------------------------------------------------------------===//
849
850struct ConvertXeVMToLLVMPass
851 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
852 using Base::Base;
853
854 void getDependentDialects(DialectRegistry &registry) const override {
855 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
856 }
857
858 void runOnOperation() override {
859 ConversionTarget target(getContext());
860 RewritePatternSet patterns(&getContext());
862 if (failed(applyPartialConversion(getOperation(), target,
863 std::move(patterns))))
864 signalPassFailure();
865 }
866};
867} // namespace
868
869//===----------------------------------------------------------------------===//
870// ConvertToLLVMPatternInterface implementation
871//===----------------------------------------------------------------------===//
872
873namespace {
874/// Implement the interface to convert XeVM to LLVM.
875struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
877 void loadDependentDialects(MLIRContext *context) const final {
878 context->loadDialect<LLVM::LLVMDialect>();
879 }
880
881 /// Hook for derived dialect interface to provide conversion patterns
882 /// and mark dialect legal for the conversion target.
883 void populateConvertToLLVMConversionPatterns(
884 ConversionTarget &target, LLVMTypeConverter &typeConverter,
885 RewritePatternSet &patterns) const final {
887 }
888};
889} // namespace
890
891//===----------------------------------------------------------------------===//
892// Pattern Population
893//===----------------------------------------------------------------------===//
894
895void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
897 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
898 [](Operation *op) { return !op->hasAttr("cache_control"); });
899 target.addIllegalDialect<XeVMDialect>();
900 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
901 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
902 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
903 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
904 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
905 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
906 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
907 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917 LaunchConfigOpToOCLPattern<GridDimXOp>,
918 LaunchConfigOpToOCLPattern<GridDimYOp>,
919 LaunchConfigOpToOCLPattern<GridDimZOp>,
920 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
923 patterns.getContext());
924}
925
/// Adds an extension to `registry` that attaches
/// `XeVMToLLVMDialectInterface` to the XeVM dialect.
926void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
// The unary `+` converts the capture-less lambda to a plain function pointer.
927 registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
928 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
929 });
930}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144