MLIR 22.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/Support/FormatVariadic.h"
19
21#include "mlir/IR/Types.h"
22
23#include "llvm/ADT/TypeSwitch.h"
24
25namespace mlir {
26#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27#include "mlir/Conversion/Passes.h.inc"
28} // namespace mlir
29
30using namespace mlir;
31using namespace xevm;
32
33namespace {
34
35struct LLVMFuncAttributeOptions {
36 bool isConvergent = false;
37 bool isNoUnwind = false;
38 bool isWillReturn = false;
39 LLVM::MemoryEffectsAttr memEffectsAttr{};
40};
41static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42 false, true, false, {}};
43static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44 false, true, true, {}};
45static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46 true, true, true, {}};
47
// Returns the Itanium C++ ABI mangling fragment for `ty`, used to build
// OpenCL builtin names:
//   - vector<N x T> -> "Dv<N>_" followed by the mangling of T (the
//     `isUnsigned` flag is forwarded to the element type),
//   - f16 / f32 / f64 -> "Dh" / "f" / "d",
//   - 8/16/32/64-bit integers -> "c"/"s"/"i"/"l", or "h"/"t"/"j"/"m" when
//     `isUnsigned` is true.
// Any other integer width, and any other type (e.g. bf16 or pointers), is
// unreachable.
48std::string getTypeMangling(Type ty, bool isUnsigned = false) {
50 .Case([isUnsigned](VectorType ty) -> std::string {
51 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
52 getTypeMangling(ty.getElementType(), isUnsigned);
53 })
54 .Case([](Float16Type) -> std::string { return "Dh"; })
55 .Case([](Float32Type) -> std::string { return "f"; })
56 .Case([](Float64Type) -> std::string { return "d"; })
57 .Case([isUnsigned](IntegerType ty) -> std::string {
58 switch (ty.getWidth()) {
59 case 8:
60 return isUnsigned ? "h" : "c";
61 case 16:
62 return isUnsigned ? "t" : "s";
63 case 32:
64 return isUnsigned ? "j" : "i";
65 case 64:
66 return isUnsigned ? "m" : "l";
67 default:
68 llvm_unreachable("unhandled integer type");
69 }
70 })
71 .DefaultUnreachable("unhandled type for mangling");
72}
73
74std::string mangle(StringRef baseName, ArrayRef<Type> types,
75 ArrayRef<bool> isUnsigned = {}) {
76 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77 "Signedness info doesn't match");
78 std::string s;
79 llvm::raw_string_ostream os(s);
80 llvm::SmallDenseMap<Type, unsigned> substitutions;
81 os << "_Z" << baseName.size() << baseName;
82 for (auto [idx, type] : llvm::enumerate(types)) {
83 auto it = substitutions.find(type);
84 if (it != substitutions.end()) {
85 os << "S";
86 // First substitution is `S_`, second is `S0_`, and so on.
87 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
88 os << firstIdx - 1;
89 os << "_";
90 } else {
91 if (!type.isIntOrFloat())
92 substitutions[type] = substitutions.size();
93 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
94 }
95 }
96 return os.str();
97}
98
99static int32_t getL1CacheControl(LoadCacheControl cc) {
100 int32_t control = 0;
101 switch (cc) {
102 case LoadCacheControl::L1UC_L2UC_L3UC:
103 case LoadCacheControl::L1UC_L2UC_L3C:
104 case LoadCacheControl::L1UC_L2C_L3UC:
105 case LoadCacheControl::L1UC_L2C_L3C:
106 control = 1;
107 break;
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
112 control = 2;
113 break;
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
118 control = 3;
119 break;
120 case LoadCacheControl::INVALIDATE_READ:
121 control = 4;
122 break;
123 }
124 return control;
125}
126
127static int32_t getL1CacheControl(StoreCacheControl cc) {
128 int32_t control = 0;
129 switch (cc) {
130 case StoreCacheControl::L1UC_L2UC_L3UC:
131 case StoreCacheControl::L1UC_L2UC_L3WB:
132 case StoreCacheControl::L1UC_L2WB_L3UC:
133 case StoreCacheControl::L1UC_L2WB_L3WB:
134 control = 1;
135 break;
136 case StoreCacheControl::L1WT_L2UC_L3UC:
137 case StoreCacheControl::L1WT_L2UC_L3WB:
138 case StoreCacheControl::L1WT_L2WB_L3UC:
139 case StoreCacheControl::L1WT_L2WB_L3WB:
140 control = 2;
141 break;
142 case StoreCacheControl::L1S_L2UC_L3UC:
143 case StoreCacheControl::L1S_L2UC_L3WB:
144 case StoreCacheControl::L1S_L2WB_L3UC:
145 case StoreCacheControl::L1S_L2WB_L3WB:
146 control = 3;
147 break;
148 case StoreCacheControl::L1WB_L2UC_L3UC:
149 case StoreCacheControl::L1WB_L2WB_L3UC:
150 case StoreCacheControl::L1WB_L2UC_L3WB:
151 control = 4;
152 break;
153 }
154 return control;
155}
156
157static int32_t getL3CacheControl(LoadCacheControl cc) {
158 int32_t control = 0;
159 switch (cc) {
160 case LoadCacheControl::L1UC_L2UC_L3UC:
161 case LoadCacheControl::L1UC_L2C_L3UC:
162 case LoadCacheControl::L1C_L2UC_L3UC:
163 case LoadCacheControl::L1C_L2C_L3UC:
164 case LoadCacheControl::L1S_L2UC_L3UC:
165 case LoadCacheControl::L1S_L2C_L3UC:
166 control = 1;
167 break;
168 case LoadCacheControl::L1UC_L2UC_L3C:
169 case LoadCacheControl::L1UC_L2C_L3C:
170 case LoadCacheControl::L1C_L2UC_L3C:
171 case LoadCacheControl::L1C_L2C_L3C:
172 case LoadCacheControl::L1S_L2UC_L3C:
173 case LoadCacheControl::L1S_L2C_L3C:
174 control = 2;
175 break;
176 case LoadCacheControl::INVALIDATE_READ:
177 control = 4;
178 break;
179 }
180 return control;
181}
182
183static int32_t getL3CacheControl(StoreCacheControl cc) {
184 int32_t control = 0;
185 switch (cc) {
186 case StoreCacheControl::L1UC_L2UC_L3UC:
187 case StoreCacheControl::L1UC_L2WB_L3UC:
188 case StoreCacheControl::L1WT_L2UC_L3UC:
189 case StoreCacheControl::L1WT_L2WB_L3UC:
190 case StoreCacheControl::L1S_L2UC_L3UC:
191 case StoreCacheControl::L1S_L2WB_L3UC:
192 case StoreCacheControl::L1WB_L2UC_L3UC:
193 case StoreCacheControl::L1WB_L2WB_L3UC:
194 control = 1;
195 break;
196 case StoreCacheControl::L1UC_L2UC_L3WB:
197 case StoreCacheControl::L1UC_L2WB_L3WB:
198 case StoreCacheControl::L1WT_L2UC_L3WB:
199 case StoreCacheControl::L1WT_L2WB_L3WB:
200 case StoreCacheControl::L1S_L2UC_L3WB:
201 case StoreCacheControl::L1S_L2WB_L3WB:
202 case StoreCacheControl::L1WB_L2UC_L3WB:
203 control = 2;
204 break;
205 }
206 return control;
207}
208
209static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210 return op.getCacheControl();
211}
212
213static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214 return op.getCacheControl();
215}
216
217static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218 return op.getCacheControl();
219}
220
221static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222 return op.getCacheControl();
223}
224
225static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226 return op.getCacheControl();
227}
228
229static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230 return op.getCacheControl();
231}
232
233static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234 if (op->hasAttr("cache_control")) {
235 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
236 if (!attr)
237 return std::nullopt;
238 return std::optional<LoadCacheControl>(attr.getValue());
239 }
240 return std::nullopt;
241}
242
243static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244 if (op->hasAttr("cache_control")) {
245 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
246 if (!attr)
247 return std::nullopt;
248 return std::optional<StoreCacheControl>(attr.getValue());
249 }
250 return std::nullopt;
251}
252
/// L1 control value of an op's cache-control policy.
/// Precondition: getCacheControl(op) holds a value (it is dereferenced
/// unchecked).
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  auto policy = getCacheControl(op);
  return getL1CacheControl(*policy);
}

/// L3 control value of an op's cache-control policy.
/// Precondition: getCacheControl(op) holds a value (it is dereferenced
/// unchecked).
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  auto policy = getCacheControl(op);
  return getL3CacheControl(*policy);
}
262
// Builds the cache-controls attribute for a memory op. Returns an empty
// optional when getCacheControl(op) has no value.
// Otherwise the result is an ArrayAttr holding two i32 arrays, one per cache
// level:
//   [controlKey, 0, <L1 control>, 0] and [controlKey, 1, <L3 control>, 0].
// controlKey is 6442 for load-like ops (BlockLoad2dOp, BlockPrefetch2dOp,
// LLVM::LoadOp, BlockLoadOp, PrefetchOp) and 6443 for every other op
// (stores).
// NOTE(review): the meaning of the trailing 0 in each entry is not visible
// here (presumably an operand index); confirm against the consumer.
263template <typename OpType>
264static std::optional<ArrayAttr>
265getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266 if (!getCacheControl(op))
267 return {};
268 constexpr int32_t decorationCacheControlArity{4};
269 constexpr int32_t loadCacheControlKey{6442};
270 constexpr int32_t storeCacheControlKey{6443};
271 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272 std::is_same_v<OpType, BlockPrefetch2dOp> ||
273 std::is_same_v<OpType, LLVM::LoadOp> ||
274 std::is_same_v<OpType, BlockLoadOp> ||
275 std::is_same_v<OpType, PrefetchOp>;
276 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278 controlKey, 0, getL1CacheControl<OpType>(op), 0};
280 controlKey, 1, getL3CacheControl<OpType>(op), 0};
281 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
282 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
283
284 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
285 return rewriter.getArrayAttr(combinedAttrs);
286}
287
288static LLVM::CallOp createDeviceFunctionCall(
289 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
290 ArrayRef<Type> argTypes, ArrayRef<Value> args,
291 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
292 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
293 auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
294 assert(moduleOp && "Expecting module");
295 Location loc = op->getLoc();
296
297 auto funcOpRes =
298 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
299 assert(!failed(funcOpRes));
300 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302 funcOp.setConvergent(funcAttributeOptions.isConvergent);
303 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
305
306 if (funcAttributeOptions.memEffectsAttr)
307 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
308
309 for (auto [idx, attrName] : paramAttrs)
310 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
311
312 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
313 callOp->setAttrs(funcOp->getAttrs());
314
315 return callOp;
316}
317
318class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
319 using OpConversionPattern::OpConversionPattern;
320 LogicalResult
321 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
322 ConversionPatternRewriter &rewriter) const override {
323 if (!op.getC()) {
324 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
325 }
326 auto precisionA = op.getTypes().getA();
327 auto precisionB = op.getTypes().getB();
328 auto precisionC = op.getTypes().getC();
329 auto precisionD = op.getTypes().getD();
330 if (precisionC != precisionD) {
331 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
332 }
333 if (precisionC != xevm::ElemType::S32 &&
334 precisionC != xevm::ElemType::F32 &&
335 precisionC != xevm::ElemType::F16 &&
336 precisionC != xevm::ElemType::BF16) {
337 return rewriter.notifyMatchFailure(
338 op, "type of C and D must be S32, F32, F16 or BF16");
339 }
340 if (precisionA == xevm::ElemType::S32 ||
341 precisionA == xevm::ElemType::F32) {
342 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
343 }
344 if (precisionB == xevm::ElemType::S32 ||
345 precisionB == xevm::ElemType::F32) {
346 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
347 }
348 constexpr uint32_t bitWidthPackedA{16};
349 constexpr uint32_t bitWidthPackedB{32};
350 auto loc = op.getLoc();
351
352 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
353 VectorType origTy = cast<VectorType>(val.getType());
354 const uint32_t vecBitSize =
355 origTy.getNumElements() *
356 origTy.getElementType().getIntOrFloatBitWidth();
357 VectorType newTy = VectorType::get(
358 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
359 if (origTy != newTy)
360 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
361 return val;
362 };
363
364 Value a = op.getA();
365 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
366 ? cast<Type>(rewriter.getF32Type())
367 : rewriter.getIntegerType(bitWidthPackedA);
368 a = castIfNeeded(a, packedAType);
369
370 Value b = op.getB();
371 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
372 ? cast<Type>(rewriter.getF32Type())
373 : rewriter.getIntegerType(bitWidthPackedB);
374 b = castIfNeeded(b, packedBType);
375
376 Value c = op.getC();
377 VectorType cOrigTy = cast<VectorType>(c.getType());
378 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
380 // OCL builtins encode bfloat16 as int16
381 VectorType cTy =
382 cOrigTy.getElementType().isBF16()
383 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
384 : cOrigTy;
385 VectorType resTy = cTy;
386 if (cOrigTy != cTy)
387 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
388
389 constexpr int32_t systolicDepth{8};
390 std::string fnName =
391 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392 stringifyElemType(op.getTypes().getA()).str(),
393 stringifyElemType(op.getTypes().getB()).str(),
394 systolicDepth *
395 getNumOperandsPerDword(op.getTypes().getA()))
396 .str();
397 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
398 fnName = mangle(fnName, argTypes);
399 SmallVector<Value> args{a, b, c};
400
401 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
402 /*other=*/LLVM::ModRefInfo::NoModRef,
403 /*argMem=*/LLVM::ModRefInfo::NoModRef,
404 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
405 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
406 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
407 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
408 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
409 funcAttrs.memEffectsAttr = memAttr;
410 Value result =
411 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
412 funcAttrs, op.getOperation())
413 ->getResult(0);
414
415 if (resOrigTy != resTy)
416 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
417
418 rewriter.replaceOp(op, result);
419 return success();
420 }
421
422private:
423 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
424 switch (pTy) {
425 case xevm::ElemType::TF32:
426 return 1;
427 case xevm::ElemType::BF16:
428 case xevm::ElemType::F16:
429 return 2;
430 case xevm::ElemType::U8:
431 case xevm::ElemType::S8:
432 return 4;
433 default:
434 llvm_unreachable("unsupported xevm::ElemType");
435 }
436 }
437};
438
439class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
440 using OpConversionPattern::OpConversionPattern;
441 LogicalResult
442 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
443 ConversionPatternRewriter &rewriter) const override {
444 auto loc = op.getLoc();
445 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
446 Value one =
447 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
448 SmallVector<Value> args{op.getPtr(), one};
449 SmallVector<Type> argTypes;
450 for (auto arg : args)
451 argTypes.push_back(arg.getType());
452 auto funcAttr = noUnwindAttrs;
453 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
454 /*other=*/LLVM::ModRefInfo::NoModRef,
455 /*argMem=*/LLVM::ModRefInfo::Ref,
456 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
457 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
458 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
459 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
460 funcAttr.memEffectsAttr = memAttr;
461
462 LLVM::CallOp call = createDeviceFunctionCall(
463 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
464 argTypes, args, {}, funcAttr, op.getOperation());
465 if (std::optional<ArrayAttr> optCacheControls =
466 getCacheControlMetadata(rewriter, op))
467 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
468 rewriter.eraseOp(op);
469 return success();
470 }
471};
472
473class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
474 using OpConversionPattern::OpConversionPattern;
475 LogicalResult
476 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
477 ConversionPatternRewriter &rewriter) const override {
478 auto loc = op.getLoc();
479 const std::string fnName{"atomic_work_item_fence"};
480 int memScope, addrSpace;
481 switch (op.getAddrspace()) {
482 case xevm::AddrSpace::SHARED:
483 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
484 break;
485 case xevm::AddrSpace::GLOBAL:
486 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
487 break;
488 default:
489 // GENERIC is not supported in OpenCL
490 return rewriter.notifyMatchFailure(
491 op, "Fence only supports global and shared address spaces.");
492 }
493 switch (op.getScope()) {
494 case xevm::MemScope::WORKGROUP:
495 memScope = 1;
496 break;
497 case xevm::MemScope::DEVICE:
498 memScope = 2;
499 break;
500 default:
501 // CLUSTER and SYSTEM are not supported in OpenCL
502 return rewriter.notifyMatchFailure(
503 op, "Fence only supports workgroup and device memory scopes.");
504 }
505 Type i32Type = rewriter.getI32Type();
506 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
507 Value memScopeConst =
508 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
509 Value addrSpaceConst =
510 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
511 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
512 SmallVector<Type> argTypes{3, i32Type};
513 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
514 LLVM::LLVMVoidType::get(rewriter.getContext()),
515 argTypes, args, {}, noUnwindAttrs,
516 op.getOperation());
517 rewriter.eraseOp(op);
518 return success();
519 }
520};
521template <typename OpType>
522class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
523 using OpConversionPattern<OpType>::OpConversionPattern;
524 LogicalResult
525 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
526 ConversionPatternRewriter &rewriter) const override {
527 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
528 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
529
530 auto loc = op.getLoc();
531 VectorType vecType;
532 bool packReg = false;
533 bool transpose = false;
534 if constexpr (isLoad) {
535 vecType = op.getRes().getType();
536 packReg = op.getPackRegister();
537 transpose = op.getTranspose();
538 } else if constexpr (!isPrefetch) {
539 vecType = op.getStoredVal().getType();
540 }
541
542 auto i32Type = rewriter.getI32Type();
543 Value byteCoord =
544 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
545 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
546 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
547 byteCoord = LLVM::InsertElementOp::create(
548 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
549 byteCoord = LLVM::InsertElementOp::create(
550 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
551 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
552 op.getBasePitch(), byteCoord};
553 SmallVector<Type> retTypes;
554 Value spvLoadDstPtr;
555 std::string funcName{"intel_sub_group_2d_block_"};
556 std::string bitWidthId;
557 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
558 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
559 if constexpr (isPrefetch) { // Prefetch
560 funcName += "prefetch";
561 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
562 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
563 /*other=*/LLVM::ModRefInfo::NoModRef,
564 /*argMem=*/LLVM::ModRefInfo::Ref,
565 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
566 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
567 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
568 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
569 funcAttr = noUnwindAttrs;
570 funcAttr.memEffectsAttr = memAttr;
571 } else {
572 auto vecElemType = vecType.getElementType();
573 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
574 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
575 vecType.getNumElements());
576 auto dstOrSrcPtr = LLVM::AllocaOp::create(
577 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
578 vecElemType, numElems);
579 args.push_back(dstOrSrcPtr);
580 if constexpr (isLoad) { // Load
581 funcName += "read";
582 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
583 if (packReg)
584 funcName += "_transform";
585 else if (transpose)
586 funcName += "_transpose";
587 spvLoadDstPtr = dstOrSrcPtr;
588 retTypes.push_back(vecType);
589 paramAttrs = {
590 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
591 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
592 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
593 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
594 };
595 } else { // Store
596 funcName += "write";
597 bitWidthId = (vecElemBitWidth == 32)
598 ? "j"
599 : ((vecElemBitWidth == 16) ? "t" : "h");
600 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
601 paramAttrs = {
602 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
603 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
604 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
605 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
606 };
607 }
608 }
609
610 funcName =
611 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
612 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
613 .str();
614 std::string prefetchCode("");
615 if (!isPrefetch)
616 prefetchCode += "P";
617 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
618 funcName, prefetchCode, bitWidthId)
619 .str();
620 SmallVector<Type> argTypes;
621 for (auto arg : args) {
622 argTypes.push_back(arg.getType());
623 }
624 LLVM::CallOp call = createDeviceFunctionCall(
625 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
626 argTypes, args, paramAttrs, funcAttr, op.getOperation());
627 if (std::optional<ArrayAttr> optCacheControls =
628 getCacheControlMetadata(rewriter, op)) {
629 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
630 }
631 if constexpr (isLoad)
632 rewriter.replaceOp(
633 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
634 else
635 rewriter.eraseOp(op);
636 return success();
637 }
638};
639
640template <typename OpType>
641class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
642 using OpConversionPattern<OpType>::OpConversionPattern;
643 LogicalResult
644 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
645 ConversionPatternRewriter &rewriter) const override {
646 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
647 // Get OpenCL function name
648 // https://registry.khronos.org/OpenCL/extensions/
649 // intel/cl_intel_subgroup_local_block_io.html
650 std::string funcName{"intel_sub_group_block_"};
651 // Value or Result type can be vector or scalar
652 Type valOrResTy;
653 if constexpr (isStore) {
654 funcName += "write_u";
655 valOrResTy = op.getVal().getType();
656 } else {
657 funcName += "read_u";
658 valOrResTy = op.getType();
659 }
660 // Get element type of the vector/scalar
661 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
662 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
663 funcName += getTypeMangling(elemType);
664 if (vecTy)
665 funcName += std::to_string(vecTy.getNumElements());
666 SmallVector<Type, 2> argTypes{};
667 // XeVM BlockLoad/StoreOp always use signless integer types
668 // but OpenCL builtins expect unsigned types
669 // use unsigned types for mangling
670 SmallVector<bool, 2> isUnsigned{};
671 // arg0: pointer to the src/dst address
672 // arg1 - only if store : vector to store
673 // Prepare arguments
674 SmallVector<Value, 2> args{};
675 args.push_back(op.getPtr());
676 argTypes.push_back(op.getPtr().getType());
677 isUnsigned.push_back(true);
678 Type retType;
679 if constexpr (isStore) {
680 args.push_back(op.getVal());
681 argTypes.push_back(op.getVal().getType());
682 isUnsigned.push_back(true);
683 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
684 } else {
685 retType = valOrResTy;
686 }
687 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
688 "PU3AS" +
689 std::to_string(op.getPtr().getType().getAddressSpace());
690 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
691 if constexpr (isStore)
692 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
693 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
694
695 LLVM::CallOp call =
696 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
697 {}, funcAttr, op.getOperation());
698 if (std::optional<ArrayAttr> optCacheControls =
699 getCacheControlMetadata(rewriter, op)) {
700 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
701 }
702 if constexpr (isStore)
703 rewriter.eraseOp(op);
704 else
705 rewriter.replaceOp(op, call->getResult(0));
706 return success();
707 }
708};
709
710template <typename OpType>
711class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
712 using OpConversionPattern<OpType>::OpConversionPattern;
713 LogicalResult
714 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
715 ConversionPatternRewriter &rewriter) const override {
716 if (!op->hasAttr("cache_control"))
717 return failure();
718 std::optional<ArrayAttr> optCacheControls =
719 getCacheControlMetadata(rewriter, op);
720 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
721 op->removeAttr("cache_control");
722 return success();
723 }
724};
725
726//===----------------------------------------------------------------------===//
727// GPU index id operations
728//===----------------------------------------------------------------------===//
729/*
730// Launch Config ops
731// dimidx - x, y, z - is fixed to i32
732// return type is set by XeVM type converter
733// get_local_id
734xevm::WorkitemIdXOp;
735xevm::WorkitemIdYOp;
736xevm::WorkitemIdZOp;
737// get_local_size
738xevm::WorkgroupDimXOp;
739xevm::WorkgroupDimYOp;
740xevm::WorkgroupDimZOp;
741// get_group_id
742xevm::WorkgroupIdXOp;
743xevm::WorkgroupIdYOp;
744xevm::WorkgroupIdZOp;
745// get_num_groups
746xevm::GridDimXOp;
747xevm::GridDimYOp;
748xevm::GridDimZOp;
749// get_global_id : to be added if needed
750*/
751
752// Helpers to get the OpenCL function name and dimension argument for each op.
753static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
754 return {"get_local_id", 0};
755}
756static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
757 return {"get_local_id", 1};
758}
759static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
760 return {"get_local_id", 2};
761}
762static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
763 return {"get_local_size", 0};
764}
765static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
766 return {"get_local_size", 1};
767}
768static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
769 return {"get_local_size", 2};
770}
771static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
772 return {"get_group_id", 0};
773}
774static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
775 return {"get_group_id", 1};
776}
777static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
778 return {"get_group_id", 2};
779}
780static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
781 return {"get_num_groups", 0};
782}
783static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
784 return {"get_num_groups", 1};
785}
786static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
787 return {"get_num_groups", 2};
788}
789/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
790/// a constant argument for the dimension - x, y or z.
791template <typename OpType>
792class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
793 using OpConversionPattern<OpType>::OpConversionPattern;
794 LogicalResult
795 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
796 ConversionPatternRewriter &rewriter) const override {
797 Location loc = op->getLoc();
798 auto [baseName, dim] = getConfig(op);
799 Type dimTy = rewriter.getI32Type();
800 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
801 static_cast<int64_t>(dim));
802 std::string func = mangle(baseName, {dimTy}, {true});
803 Type resTy = op.getType();
804 auto call =
805 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
806 noUnwindWillReturnAttrs, op.getOperation());
807 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
808 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
809 /*other=*/noModRef,
810 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
811 /*errnoMem=*/noModRef,
812 /*targetMem0=*/noModRef,
813 /*targetMem1=*/noModRef);
814 call.setMemoryEffectsAttr(memAttr);
815 rewriter.replaceOp(op, call);
816 return success();
817 }
818};
819
820/*
821// Subgroup ops
822// get_sub_group_local_id
823xevm::LaneIdOp;
824// get_sub_group_id
825xevm::SubgroupIdOp;
826// get_sub_group_size
827xevm::SubgroupSizeOp;
828// get_num_sub_groups : to be added if needed
829*/
830
831// Helpers to get the OpenCL function name for each op.
832static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
833static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
834static StringRef getConfig(xevm::SubgroupSizeOp) {
835 return "get_sub_group_size";
836}
837template <typename OpType>
838class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
839 using OpConversionPattern<OpType>::OpConversionPattern;
840 LogicalResult
841 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
842 ConversionPatternRewriter &rewriter) const override {
843 std::string func = mangle(getConfig(op).str(), {});
844 Type resTy = op.getType();
845 auto call =
846 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
847 noUnwindWillReturnAttrs, op.getOperation());
848 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
849 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
850 /*other=*/noModRef,
851 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
852 /*errnoMem=*/noModRef,
853 /*targetMem0=*/noModRef,
854 /*targetMem1=*/noModRef);
855 call.setMemoryEffectsAttr(memAttr);
856 rewriter.replaceOp(op, call);
857 return success();
858 }
859};
860
861//===----------------------------------------------------------------------===//
862// Pass Definition
863//===----------------------------------------------------------------------===//
864
// Pass that lowers XeVM dialect ops by running a partial dialect conversion
// on the root operation.
865struct ConvertXeVMToLLVMPass
866 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
867 using Base::Base;
868
// The lowering creates LLVM dialect ops and operates on XeVM ops, so both
// dialects must be registered.
869 void getDependentDialects(DialectRegistry &registry) const override {
870 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
871 }
872
// Applies a partial conversion of the root op with `target` and `patterns`;
// signals pass failure if the conversion fails.
873 void runOnOperation() override {
874 ConversionTarget target(getContext());
875 RewritePatternSet patterns(&getContext());
877 if (failed(applyPartialConversion(getOperation(), target,
878 std::move(patterns))))
879 signalPassFailure();
880 }
881};
882} // namespace
883
884//===----------------------------------------------------------------------===//
885// ConvertToLLVMPatternInterface implementation
886//===----------------------------------------------------------------------===//
887
888namespace {
889/// Implement the interface to convert XeVM to LLVM.
890struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
/// Loads the LLVM dialect, since the XeVM lowering emits LLVM dialect ops.
892 void loadDependentDialects(MLIRContext *context) const final {
893 context->loadDialect<LLVM::LLVMDialect>();
894 }
895
896 /// Hook for derived dialect interface to provide conversion patterns
897 /// and mark dialect legal for the conversion target.
898 void populateConvertToLLVMConversionPatterns(
899 ConversionTarget &target, LLVMTypeConverter &typeConverter,
900 RewritePatternSet &patterns) const final {
902 }
903};
904} // namespace
905
906//===----------------------------------------------------------------------===//
907// Pattern Population
908//===----------------------------------------------------------------------===//
909
// Configures `target` and fills `patterns` for lowering XeVM to LLVM dialect
// calls of OpenCL builtins:
//   - every XeVM op is illegal;
//   - LLVM dialect ops are legal unless they still carry a `cache_control`
//     attribute (llvm.load / llvm.store with that attribute are rewritten by
//     LLVMLoadStoreToOCLPattern);
//   - one pattern is registered per XeVM op kind (2D and 1D block memory ops,
//     MMA, memfence, prefetch, launch-configuration and subgroup queries).
910void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
912 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
913 [](Operation *op) { return !op->hasAttr("cache_control"); });
914 target.addIllegalDialect<XeVMDialect>();
915 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
916 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
917 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
918 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
919 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
920 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
921 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
922 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
923 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
924 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
925 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
926 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
927 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
928 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
929 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
930 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
931 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
932 LaunchConfigOpToOCLPattern<GridDimXOp>,
933 LaunchConfigOpToOCLPattern<GridDimYOp>,
934 LaunchConfigOpToOCLPattern<GridDimZOp>,
935 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
936 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
937 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
938 patterns.getContext());
939}
940
941void ::mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
942 registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
943 dialect->addInterfaces<XeVMToLLVMDialectInterface>();
944 });
945}
return success()
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:248
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
const FrozenRewritePatternSet & patterns
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144