MLIR 23.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
20
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Types.h"
25
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace xevm;
35
36namespace {
37
/// Bundle of function-level attributes that createDeviceFunctionCall applies
/// to the device function declaration it looks up or creates.
38struct LLVMFuncAttributeOptions {
// Forwarded to LLVMFuncOp::setConvergent.
39 bool isConvergent = false;
// Forwarded to LLVMFuncOp::setNoUnwind.
40 bool isNoUnwind = false;
// Forwarded to LLVMFuncOp::setWillReturn.
41 bool isWillReturn = false;
// Optional memory effects; only applied when the attribute is non-null.
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
43};
// Attribute presets. The initializers are positional, in field order:
// {isConvergent, isNoUnwind, isWillReturn, memEffectsAttr}.
// nounwind only.
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false, true, false, {}};
// nounwind + willreturn.
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false, true, true, {}};
// convergent + nounwind + willreturn.
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true, true, true, {}};
50
51std::string getTypeMangling(Type ty, bool isUnsigned = false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
56 })
57 .Case([](Float16Type) -> std::string { return "Dh"; })
58 .Case([](Float32Type) -> std::string { return "f"; })
59 .Case([](Float64Type) -> std::string { return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
62 case 8:
63 return isUnsigned ? "h" : "c";
64 case 16:
65 return isUnsigned ? "t" : "s";
66 case 32:
67 return isUnsigned ? "j" : "i";
68 case 64:
69 return isUnsigned ? "m" : "l";
70 default:
71 llvm_unreachable("unhandled integer type");
72 }
73 })
74 .DefaultUnreachable("unhandled type for mangling");
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102static int32_t getL1CacheControl(LoadCacheControl cc) {
103 int32_t control = 0;
104 switch (cc) {
105 case LoadCacheControl::L1C_L2UC_L3UC:
106 case LoadCacheControl::L1C_L2UC_L3C:
107 case LoadCacheControl::L1C_L2C_L3UC:
108 case LoadCacheControl::L1C_L2C_L3C:
109 control = 1;
110 break;
111 case LoadCacheControl::L1S_L2UC_L3UC:
112 case LoadCacheControl::L1S_L2UC_L3C:
113 case LoadCacheControl::L1S_L2C_L3UC:
114 case LoadCacheControl::L1S_L2C_L3C:
115 control = 2;
116 break;
117 case LoadCacheControl::INVALIDATE_READ:
118 control = 3;
119 break;
120 default:
121 break;
122 }
123 return control;
124}
125
126static int32_t getL1CacheControl(StoreCacheControl cc) {
127 int32_t control = 0;
128 switch (cc) {
129 case StoreCacheControl::L1WT_L2UC_L3UC:
130 case StoreCacheControl::L1WT_L2UC_L3WB:
131 case StoreCacheControl::L1WT_L2WB_L3UC:
132 case StoreCacheControl::L1WT_L2WB_L3WB:
133 control = 1;
134 break;
135 case StoreCacheControl::L1WB_L2UC_L3UC:
136 case StoreCacheControl::L1WB_L2WB_L3UC:
137 case StoreCacheControl::L1WB_L2UC_L3WB:
138 control = 2;
139 break;
140 case StoreCacheControl::L1S_L2UC_L3UC:
141 case StoreCacheControl::L1S_L2UC_L3WB:
142 case StoreCacheControl::L1S_L2WB_L3UC:
143 case StoreCacheControl::L1S_L2WB_L3WB:
144 control = 3;
145 break;
146 default:
147 break;
148 }
149 return control;
150}
151
152static int32_t getL3CacheControl(LoadCacheControl cc) {
153 int32_t control = 0;
154 switch (cc) {
155 case LoadCacheControl::L1UC_L2UC_L3C:
156 case LoadCacheControl::L1UC_L2C_L3C:
157 case LoadCacheControl::L1C_L2UC_L3C:
158 case LoadCacheControl::L1C_L2C_L3C:
159 case LoadCacheControl::L1S_L2UC_L3C:
160 case LoadCacheControl::L1S_L2C_L3C:
161 control = 1;
162 break;
163 case LoadCacheControl::INVALIDATE_READ:
164 control = 3;
165 break;
166 default:
167 break;
168 }
169 return control;
170}
171
172static int32_t getL3CacheControl(StoreCacheControl cc) {
173 int32_t control = 0;
174 switch (cc) {
175 case StoreCacheControl::L1UC_L2UC_L3WB:
176 case StoreCacheControl::L1UC_L2WB_L3WB:
177 case StoreCacheControl::L1WT_L2UC_L3WB:
178 case StoreCacheControl::L1WT_L2WB_L3WB:
179 case StoreCacheControl::L1S_L2UC_L3WB:
180 case StoreCacheControl::L1S_L2WB_L3WB:
181 case StoreCacheControl::L1WB_L2UC_L3WB:
182 control = 2;
183 break;
184 default:
185 break;
186 }
187 return control;
188}
189
// Uniform accessors for the optional cache-control setting of each XeVM memory
// op. They give the templated helpers in this file a single overloaded entry
// point; each one forwards to the op's own getCacheControl() accessor.
// Load-like ops (prefetch, block load, 2D block load/prefetch) yield a
// LoadCacheControl; store-like ops yield a StoreCacheControl.
190static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
191 return op.getCacheControl();
192}
193
194static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
195 return op.getCacheControl();
196}
197
198static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
199 return op.getCacheControl();
200}
201
202static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
203 return op.getCacheControl();
204}
205
206static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
207 return op.getCacheControl();
208}
209
210static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
211 return op.getCacheControl();
212}
213
214static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
215 if (op->hasAttr("cache_control")) {
216 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
217 if (!attr)
218 return std::nullopt;
219 return std::optional<LoadCacheControl>(attr.getValue());
220 }
221 return std::nullopt;
222}
223
224static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
225 if (op->hasAttr("cache_control")) {
226 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
227 if (!attr)
228 return std::nullopt;
229 return std::optional<StoreCacheControl>(attr.getValue());
230 }
231 return std::nullopt;
232}
233
/// Returns the L1 control code for the cache-control policy attached to `op`.
/// The policy is dereferenced unchecked, so the caller must first make sure
/// that getCacheControl(op) holds a value.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  auto policy = getCacheControl(op);
  return getL1CacheControl(*policy);
}
238
/// Returns the L3 control code for the cache-control policy attached to `op`.
/// The policy is dereferenced unchecked, so the caller must first make sure
/// that getCacheControl(op) holds a value.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  auto policy = getCacheControl(op);
  return getL3CacheControl(*policy);
}
243
244template <typename OpType>
245static std::optional<ArrayAttr>
246getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
247 if (!getCacheControl(op))
248 return {};
249 constexpr int32_t decorationCacheControlArity{3};
250 constexpr int32_t loadCacheControlKey{6442};
251 constexpr int32_t storeCacheControlKey{6443};
252 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
253 std::is_same_v<OpType, BlockPrefetch2dOp> ||
254 std::is_same_v<OpType, LLVM::LoadOp> ||
255 std::is_same_v<OpType, BlockLoadOp> ||
256 std::is_same_v<OpType, PrefetchOp>;
257 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
259 controlKey, 0, getL1CacheControl<OpType>(op)};
261 controlKey, 1, getL3CacheControl<OpType>(op)};
262 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
263 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
264
265 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
266 return rewriter.getArrayAttr(combinedAttrs);
267}
268
269static LLVM::CallOp createDeviceFunctionCall(
270 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
271 ArrayRef<Type> argTypes, ArrayRef<Value> args,
272 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
273 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
274 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
275 assert(moduleOp && "Expecting module");
276 Location loc = op->getLoc();
277
278 auto funcOpRes =
279 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
280 assert(!failed(funcOpRes));
281 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
282 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
283 funcOp.setConvergent(funcAttributeOptions.isConvergent);
284 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
285 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
286
287 if (funcAttributeOptions.memEffectsAttr)
288 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
289
290 for (auto [idx, attrName] : paramAttrs)
291 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
292
293 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
294 callOp->setAttrs(funcOp->getAttrs());
295
296 return callOp;
297}
298
/// Lowers `xevm.mma` to a call to a mangled
/// `intel_sub_group_<A>_<B>_matrix_mad_k<K>` device function, where <A> and
/// <B> are the stringified A/B element types and <K> is the systolic depth (8)
/// times the number of A operands per dword.
///
/// The match fails when the C operand is absent, when the C and D element
/// types differ or are not one of S32/F32/F16/BF16, or when the A or B element
/// type is S32 or F32.
299class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
300 using OpConversionPattern::OpConversionPattern;
301 LogicalResult
302 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
303 ConversionPatternRewriter &rewriter) const override {
304 if (!op.getC()) {
305 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
306 }
307 auto precisionA = op.getTypes().getA();
308 auto precisionB = op.getTypes().getB();
309 auto precisionC = op.getTypes().getC();
310 auto precisionD = op.getTypes().getD();
311 if (precisionC != precisionD) {
312 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
313 }
314 if (precisionC != xevm::ElemType::S32 &&
315 precisionC != xevm::ElemType::F32 &&
316 precisionC != xevm::ElemType::F16 &&
317 precisionC != xevm::ElemType::BF16) {
318 return rewriter.notifyMatchFailure(
319 op, "type of C and D must be S32, F32, F16 or BF16");
320 }
321 if (precisionA == xevm::ElemType::S32 ||
322 precisionA == xevm::ElemType::F32) {
323 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
324 }
325 if (precisionB == xevm::ElemType::S32 ||
326 precisionB == xevm::ElemType::F32) {
327 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
328 }
// Element bit widths of the packed vectors passed for A and B.
329 constexpr uint32_t bitWidthPackedA{16};
330 constexpr uint32_t bitWidthPackedB{32};
331 auto loc = op.getLoc();
332
// Bitcasts `val` to a vector of `packedType` with the same total bit size;
// returns `val` unchanged when it already has that type.
333 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
334 VectorType origTy = cast<VectorType>(val.getType());
335 const uint32_t vecBitSize =
336 origTy.getNumElements() *
337 origTy.getElementType().getIntOrFloatBitWidth();
338 VectorType newTy = VectorType::get(
339 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
340 if (origTy != newTy)
341 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
342 return val;
343 };
344
// A is passed as a vector of f32 for TF32 and of 16-bit integers otherwise.
345 Value a = op.getA();
346 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
347 ? cast<Type>(rewriter.getF32Type())
348 : rewriter.getIntegerType(bitWidthPackedA);
349 a = castIfNeeded(a, packedAType);
350
// B is passed as a vector of f32 for TF32 and of 32-bit integers otherwise.
351 Value b = op.getB();
352 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
353 ? cast<Type>(rewriter.getF32Type())
354 : rewriter.getIntegerType(bitWidthPackedB);
355 b = castIfNeeded(b, packedBType);
356
357 Value c = op.getC();
358 VectorType cOrigTy = cast<VectorType>(c.getType());
359 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
360 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
361 // OCL builtins encode bfloat16 as int16
362 VectorType cTy =
363 cOrigTy.getElementType().isBF16()
364 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
365 : cOrigTy;
366 VectorType resTy = cTy;
367 if (cOrigTy != cTy)
368 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
369
370 constexpr int32_t systolicDepth{8};
371 std::string fnName =
372 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
373 stringifyElemType(op.getTypes().getA()).str(),
374 stringifyElemType(op.getTypes().getB()).str(),
375 systolicDepth *
376 getNumOperandsPerDword(op.getTypes().getA()))
377 .str();
378 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
379 fnName = mangle(fnName, argTypes);
380 SmallVector<Value> args{a, b, c};
381
// The callee is declared with no memory effects of any kind.
382 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
383 /*other=*/LLVM::ModRefInfo::NoModRef,
384 /*argMem=*/LLVM::ModRefInfo::NoModRef,
385 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
386 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
387 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
388 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
389 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
390 funcAttrs.memEffectsAttr = memAttr;
391 Value result =
392 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
393 funcAttrs, op.getOperation())
394 ->getResult(0);
395
// Undo the bf16 -> i16 re-typing so the replacement has the original type.
396 if (resOrigTy != resTy)
397 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
398
399 rewriter.replaceOp(op, result);
400 return success();
401 }
402
403private:
// Number of elements of type `pTy` that fit in one 32-bit dword:
// 1 for TF32, 2 for BF16/F16, 4 for U8/S8. Other types are unreachable.
404 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
405 switch (pTy) {
406 case xevm::ElemType::TF32:
407 return 1;
408 case xevm::ElemType::BF16:
409 case xevm::ElemType::F16:
410 return 2;
411 case xevm::ElemType::U8:
412 case xevm::ElemType::S8:
413 return 4;
414 default:
415 llvm_unreachable("unsupported xevm::ElemType");
416 }
417 }
418};
419
420class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
421 using OpConversionPattern::OpConversionPattern;
422 LogicalResult
423 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
424 ConversionPatternRewriter &rewriter) const override {
425 auto loc = op.getLoc();
426 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
427 Value one =
428 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
429 SmallVector<Value> args{op.getPtr(), one};
430 SmallVector<Type> argTypes;
431 for (auto arg : args)
432 argTypes.push_back(arg.getType());
433 auto funcAttr = noUnwindAttrs;
434 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
435 /*other=*/LLVM::ModRefInfo::NoModRef,
436 /*argMem=*/LLVM::ModRefInfo::Ref,
437 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
438 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
439 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
440 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
441 funcAttr.memEffectsAttr = memAttr;
442
443 LLVM::CallOp call = createDeviceFunctionCall(
444 rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
445 argTypes, args, {}, funcAttr, op.getOperation());
446 if (std::optional<ArrayAttr> optCacheControls =
447 getCacheControlMetadata(rewriter, op))
448 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
449 rewriter.eraseOp(op);
450 return success();
451 }
452};
453
454class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
455 using OpConversionPattern::OpConversionPattern;
456 LogicalResult
457 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
458 ConversionPatternRewriter &rewriter) const override {
459 auto loc = op.getLoc();
460 const std::string fnName{"atomic_work_item_fence"};
461 int memScope, addrSpace;
462 switch (op.getAddrspace()) {
463 case xevm::AddrSpace::SHARED:
464 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
465 break;
466 case xevm::AddrSpace::GLOBAL:
467 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
468 break;
469 default:
470 // GENERIC is not supported in OpenCL
471 return rewriter.notifyMatchFailure(
472 op, "Fence only supports global and shared address spaces.");
473 }
474 switch (op.getScope()) {
475 case xevm::MemScope::WORKGROUP:
476 memScope = 1;
477 break;
478 case xevm::MemScope::DEVICE:
479 memScope = 2;
480 break;
481 default:
482 // CLUSTER and SYSTEM are not supported in OpenCL
483 return rewriter.notifyMatchFailure(
484 op, "Fence only supports workgroup and device memory scopes.");
485 }
486 Type i32Type = rewriter.getI32Type();
487 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
488 Value memScopeConst =
489 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
490 Value addrSpaceConst =
491 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
492 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
493 SmallVector<Type> argTypes{3, i32Type};
494 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
495 LLVM::LLVMVoidType::get(rewriter.getContext()),
496 argTypes, args, {}, noUnwindAttrs,
497 op.getOperation());
498 rewriter.eraseOp(op);
499 return success();
500 }
501};
/// Lowers the 2D block ops to calls to mangled
/// `intel_sub_group_2d_block_{read,write,prefetch}` device functions.
///
/// `OpType` selects the path: BlockLoad2dOp is a load, BlockPrefetch2dOp is a
/// prefetch, and any other op type is handled as a store (it must provide
/// getStoredVal()). Loads and stores exchange data with the callee through a
/// stack slot created with `llvm.alloca`; for loads the result is read back
/// from that slot after the call. Cache-control metadata, when present, is
/// attached to the call.
502template <typename OpType>
503class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
504 using OpConversionPattern<OpType>::OpConversionPattern;
505 LogicalResult
506 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
507 ConversionPatternRewriter &rewriter) const override {
508 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
509 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
510
511 auto loc = op.getLoc();
// Vector type of the loaded/stored data; stays null for prefetch.
512 VectorType vecType;
513 bool packReg = false;
514 bool transpose = false;
515 if constexpr (isLoad) {
516 vecType = op.getRes().getType();
517 packReg = op.getPackRegister();
518 transpose = op.getTranspose();
519 } else if constexpr (!isPrefetch) {
520 vecType = op.getStoredVal().getType();
521 }
522
// Pack the (x, y) coordinates into a single 2 x i32 vector argument.
523 auto i32Type = rewriter.getI32Type();
524 Value byteCoord =
525 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
526 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
527 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
528 byteCoord = LLVM::InsertElementOp::create(
529 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
530 byteCoord = LLVM::InsertElementOp::create(
531 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
// Arguments common to all three callees; loads and stores append a sixth
// (the stack-slot pointer) below.
532 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
533 op.getBasePitch(), byteCoord};
534 SmallVector<Type> retTypes;
535 Value spvLoadDstPtr;
536 std::string funcName{"intel_sub_group_2d_block_"};
// Mangled code of the element type behind the data pointer; empty for
// prefetch, which has no data pointer.
537 std::string bitWidthId;
538 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
539 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
540 if constexpr (isPrefetch) { // Prefetch
541 funcName += "prefetch";
542 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
543 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
544 /*other=*/LLVM::ModRefInfo::NoModRef,
545 /*argMem=*/LLVM::ModRefInfo::Ref,
546 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
547 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
548 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
549 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
// Prefetch drops willreturn and is marked as only reading argument memory.
550 funcAttr = noUnwindAttrs;
551 funcAttr.memEffectsAttr = memAttr;
552 } else {
553 auto vecElemType = vecType.getElementType();
554 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
555 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
556 vecType.getNumElements());
// Stack slot through which the data is passed to/from the callee.
557 auto dstOrSrcPtr = LLVM::AllocaOp::create(
558 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
559 vecElemType, numElems);
560 args.push_back(dstOrSrcPtr);
561 if constexpr (isLoad) { // Load
562 funcName += "read";
563 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
// pack_register takes precedence over transpose when both are set.
564 if (packReg)
565 funcName += "_transform";
566 else if (transpose)
567 funcName += "_transpose";
568 spvLoadDstPtr = dstOrSrcPtr;
569 retTypes.push_back(vecType);
// Parameter 0 is the source pointer, parameter 5 the destination slot.
570 paramAttrs = {
571 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
572 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
573 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
574 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
575 };
576 } else { // Store
577 funcName += "write";
// Unsigned code chosen by element width: 32 -> "j", 16 -> "t", else "h".
578 bitWidthId = (vecElemBitWidth == 32)
579 ? "j"
580 : ((vecElemBitWidth == 16) ? "t" : "h");
// Spill the value to the stack slot so the callee can read it.
581 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
582 paramAttrs = {
583 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
584 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
585 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
586 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
587 };
588 }
589 }
590
// Append the "<elemSize>b_<tileHeight>r<tileWidth>x<vBlocks>c" suffix.
591 funcName =
592 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
593 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
594 .str();
// Loads and stores take an extra pointer parameter, mangled as "P" followed
// by bitWidthId; for prefetch both parts are empty.
595 std::string prefetchCode("");
596 if (!isPrefetch)
597 prefetchCode += "P";
598 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
599 funcName, prefetchCode, bitWidthId)
600 .str();
601 SmallVector<Type> argTypes;
602 for (auto arg : args) {
603 argTypes.push_back(arg.getType());
604 }
605 LLVM::CallOp call = createDeviceFunctionCall(
606 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
607 argTypes, args, paramAttrs, funcAttr, op.getOperation());
608 if (std::optional<ArrayAttr> optCacheControls =
609 getCacheControlMetadata(rewriter, op)) {
610 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
611 }
// A load is replaced by reading the stack slot back; the others are erased.
612 if constexpr (isLoad)
613 rewriter.replaceOp(
614 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
615 else
616 rewriter.eraseOp(op);
617 return success();
618 }
619};
620
/// Lowers the 1D block load/store ops to calls to mangled
/// `intel_sub_group_block_{read,write}_u<elem>[<N>]` device functions.
///
/// `OpType` is handled as a store when it is xevm::BlockStoreOp and as a load
/// otherwise. The value (or result) may be a scalar or a vector. Cache-control
/// metadata, when present, is attached to the call.
621template <typename OpType>
622class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
623 using OpConversionPattern<OpType>::OpConversionPattern;
624 LogicalResult
625 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
626 ConversionPatternRewriter &rewriter) const override {
627 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
628 // Get OpenCL function name
629 // https://registry.khronos.org/OpenCL/extensions/
630 // intel/cl_intel_subgroup_local_block_io.html
631 std::string funcName{"intel_sub_group_block_"};
632 // Value or Result type can be vector or scalar
633 Type valOrResTy;
634 if constexpr (isStore) {
635 funcName += "write_u";
636 valOrResTy = op.getVal().getType();
637 } else {
638 funcName += "read_u";
639 valOrResTy = op.getType();
640 }
641 // Get element type of the vector/scalar
642 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
643 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
// The base name ends with the signed element code and, for vectors, the
// element count.
644 funcName += getTypeMangling(elemType);
645 if (vecTy)
646 funcName += std::to_string(vecTy.getNumElements());
647 SmallVector<Type, 2> argTypes{};
648 // XeVM BlockLoad/StoreOp always use signless integer types
649 // but OpenCL builtins expect unsigned types
650 // use unsigned types for mangling
// NOTE(review): isUnsigned is filled in below but never read in this block;
// the mangled suffix is assembled by hand instead.
651 SmallVector<bool, 2> isUnsigned{};
652 // arg0: pointer to the src/dst address
653 // arg1 - only if store : vector to store
654 // Prepare arguments
655 SmallVector<Value, 2> args{};
656 args.push_back(op.getPtr());
657 argTypes.push_back(op.getPtr().getType());
658 isUnsigned.push_back(true);
659 Type retType;
660 if constexpr (isStore) {
661 args.push_back(op.getVal());
662 argTypes.push_back(op.getVal().getType());
663 isUnsigned.push_back(true);
664 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
665 } else {
666 retType = valOrResTy;
667 }
// Mangle by hand: "_Z<len><name>PU3AS<addrspace><unsigned elem code>", plus
// the unsigned code of the stored value type for stores.
668 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
669 "PU3AS" +
670 std::to_string(op.getPtr().getType().getAddressSpace());
671 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
672 if constexpr (isStore)
673 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
674 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
675
676 LLVM::CallOp call =
677 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
678 {}, funcAttr, op.getOperation());
679 if (std::optional<ArrayAttr> optCacheControls =
680 getCacheControlMetadata(rewriter, op)) {
681 call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
682 }
// A store has no result to forward; a load is replaced by the call result.
683 if constexpr (isStore)
684 rewriter.eraseOp(op);
685 else
686 rewriter.replaceOp(op, call->getResult(0));
687 return success();
688 }
689};
690
691template <typename OpType>
692class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
693 using OpConversionPattern<OpType>::OpConversionPattern;
694 LogicalResult
695 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
696 ConversionPatternRewriter &rewriter) const override {
697 if (!op->hasAttr("cache_control"))
698 return failure();
699 std::optional<ArrayAttr> optCacheControls =
700 getCacheControlMetadata(rewriter, op);
701 op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
702 op->removeAttr("cache_control");
703 return success();
704 }
705};
706
707//===----------------------------------------------------------------------===//
708// GPU index id operations
709//===----------------------------------------------------------------------===//
710/*
711// Launch Config ops
712// dimidx - x, y, z - is fixed to i32
713// return type is set by XeVM type converter
714// get_local_id
715xevm::WorkitemIdXOp;
716xevm::WorkitemIdYOp;
717xevm::WorkitemIdZOp;
718// get_local_size
719xevm::WorkgroupDimXOp;
720xevm::WorkgroupDimYOp;
721xevm::WorkgroupDimZOp;
722// get_group_id
723xevm::WorkgroupIdXOp;
724xevm::WorkgroupIdYOp;
725xevm::WorkgroupIdZOp;
726// get_num_groups
727xevm::GridDimXOp;
728xevm::GridDimYOp;
729xevm::GridDimZOp;
730// get_global_id : to be added if needed
731*/
732
733// Helpers to get the OpenCL function name and dimension argument for each op.
// Each overload returns {OpenCL builtin base name, dimension index} for one
// launch-config op; the X/Y/Z variants map to dimension 0/1/2. The op argument
// is used only for overload selection.
734static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
735 return {"get_local_id", 0};
736}
737static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
738 return {"get_local_id", 1};
739}
740static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
741 return {"get_local_id", 2};
742}
743static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
744 return {"get_local_size", 0};
745}
746static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
747 return {"get_local_size", 1};
748}
749static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
750 return {"get_local_size", 2};
751}
752static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
753 return {"get_group_id", 0};
754}
755static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
756 return {"get_group_id", 1};
757}
758static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
759 return {"get_group_id", 2};
760}
761static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
762 return {"get_num_groups", 0};
763}
764static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
765 return {"get_num_groups", 1};
766}
767static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
768 return {"get_num_groups", 2};
769}
770/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
771/// a constant argument for the dimension - x, y or z.
772template <typename OpType>
773class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
774 using OpConversionPattern<OpType>::OpConversionPattern;
775 LogicalResult
776 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
777 ConversionPatternRewriter &rewriter) const override {
778 Location loc = op->getLoc();
779 auto [baseName, dim] = getConfig(op);
780 Type dimTy = rewriter.getI32Type();
781 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
782 static_cast<int64_t>(dim));
783 std::string func = mangle(baseName, {dimTy}, {true});
784 Type resTy = op.getType();
785 auto call =
786 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
787 noUnwindWillReturnAttrs, op.getOperation());
788 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
789 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
790 /*other=*/noModRef,
791 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
792 /*errnoMem=*/noModRef,
793 /*targetMem0=*/noModRef,
794 /*targetMem1=*/noModRef);
795 call.setMemoryEffectsAttr(memAttr);
796 rewriter.replaceOp(op, call);
797 return success();
798 }
799};
800
801/*
802// Subgroup ops
803// get_sub_group_local_id
804xevm::LaneIdOp;
805// get_sub_group_id
806xevm::SubgroupIdOp;
807// get_sub_group_size
808xevm::SubgroupSizeOp;
809// get_num_sub_groups : to be added if needed
810*/
811
812// Helpers to get the OpenCL function name for each op.
// Each overload returns the OpenCL builtin base name for one subgroup op; the
// op argument is used only for overload selection.
813static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
814static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
815static StringRef getConfig(xevm::SubgroupSizeOp) {
816 return "get_sub_group_size";
817}
818template <typename OpType>
819class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
820 using OpConversionPattern<OpType>::OpConversionPattern;
821 LogicalResult
822 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
823 ConversionPatternRewriter &rewriter) const override {
824 std::string func = mangle(getConfig(op).str(), {});
825 Type resTy = op.getType();
826 auto call =
827 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
828 noUnwindWillReturnAttrs, op.getOperation());
829 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
830 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
831 /*other=*/noModRef,
832 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
833 /*errnoMem=*/noModRef,
834 /*targetMem0=*/noModRef,
835 /*targetMem1=*/noModRef);
836 call.setMemoryEffectsAttr(memAttr);
837 rewriter.replaceOp(op, call);
838 return success();
839 }
840};
841
842class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
843 using OpConversionPattern::OpConversionPattern;
844 LogicalResult
845 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
846 ConversionPatternRewriter &rewriter) const override {
847 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
848 auto addrSpace = ptrType.getAddressSpace();
849 if (addrSpace != 3)
850 return failure();
851 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
852 if (!symTable)
853 return failure();
854 Block *moduleBody;
855 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
856 moduleBody = mod.getBody();
857 } else if (gpu::GPUModuleOp gpuMod =
858 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
859 moduleBody = gpuMod.getBody();
860 } else {
861 return failure();
862 }
863 auto val = op.getArraySize();
864 APInt cst;
865 if (!matchPattern(val, m_ConstantInt(&cst)))
866 return failure();
867 auto loc = op.getLoc();
868 auto globalType = LLVM::LLVMArrayType::get(
869 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
870 LLVM::GlobalOp globalVar;
871 {
872 OpBuilder::InsertionGuard guard(rewriter);
873 rewriter.setInsertionPointToStart(moduleBody);
874 auto alignment = op.getAlignment();
875 globalVar = LLVM::GlobalOp::create(
876 rewriter, loc, globalType, /*isConstant=*/false,
877 /*linkage=*/LLVM::Linkage::Internal,
878 /*name=*/std::string("__global_alloca_") +
879 std::to_string(getNextGlobalIdx()),
880 /*value=*/Attribute(),
881 /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
882 }
883 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
884 return success();
885 }
886
887private:
888 static unsigned getNextGlobalIdx() {
889 static unsigned globalIdx = 0;
890 return globalIdx++;
891 }
892};
893
894static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
895 if (op.getV1() != op.getV2())
896 return false;
897 auto maskAttr = op.getMask();
898 int64_t firstIndex = maskAttr[0];
899 for (int64_t i = 1; i < static_cast<int64_t>(maskAttr.size()); ++i) {
900 int64_t index = maskAttr[i];
901 if (index != firstIndex + i)
902 return false;
903 }
904 return true;
905}
906
907// Input vector of a shuffle vector op extracting a contiguous slice is an
908// illegal vector in SPIRV kernel if the vector size is > 16 elements.
909// To legalize this case, keep applying the following transformations until no
910// more match:
911// 1. keep hoisting the shuffle vector op past unary element-wise operations
912// start with fpext, fptrunc and bitcast for now.
913// 2. merge with another shuffle vector op
914// 3. merge with load as a smaller load
915class HandleVectorExtractPattern
916 : public OpRewritePattern<LLVM::ShuffleVectorOp> {
917 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
918
919 void initialize() { setHasBoundedRewriteRecursion(); }
920
921 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
922 PatternRewriter &rewriter) const override {
923
924 if (!isExtractingContiguousSlice(op))
925 return failure();
926
927 auto mask = op.getMask();
928 auto loc = op.getLoc();
929 auto ty = op.getType();
930 // Check source operand to determine rewrite pattern.
931 auto src = op.getV1();
932 // 1. Hoist past unary element-wise operations
933 if (auto srcOp = src.getDefiningOp()) {
934 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
935 Value srcInput = srcOp->getOperand(0);
936 // Create new shuffle vector op with unary input as source.
937 auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
938 auto newShuffleVecTy =
939 VectorType::get(mask.size(), srcVecTy.getElementType());
940 auto newShuffle = LLVM::ShuffleVectorOp::create(
941 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
942 // Create new unary op with new shuffle as input.
943 Value newUnaryOp;
944 if (isa<LLVM::FPExtOp>(srcOp)) {
945 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
946 } else {
947 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
948 }
949 rewriter.replaceOp(op, newUnaryOp);
950 } else if (isa<LLVM::BitcastOp>(srcOp)) {
951 Value srcInput = srcOp->getOperand(0);
952 // Create new shuffle vector op with unary input as source.
953 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
954 auto srcInputSize = srcInputVecTy.getNumElements();
955 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
956 auto srcResSize = srcResVecTy.getNumElements();
957 auto maskSize = static_cast<int32_t>(mask.size());
958 if (srcInputSize > srcResSize) {
959 return failure();
960 }
961 if (srcResSize % srcInputSize != 0) {
962 return failure();
963 }
964 auto maskScale = srcResSize / srcInputSize;
965 if (maskScale != 1) {
966 if (mask[0] % maskScale != 0) {
967 return failure();
968 }
969 // Create a new mask that maps to the source vector
970 SmallVector<int32_t> newMask;
971 int32_t newMaskSize = maskSize / maskScale;
972 int32_t maskStart = mask[0] / maskScale;
973 for (int32_t i = 0; i < newMaskSize; ++i) {
974 newMask.push_back(maskStart + i);
975 }
976 mask = newMask;
977 }
978 auto newShuffleVecTy =
979 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
980 auto newShuffle = LLVM::ShuffleVectorOp::create(
981 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
982 // Create new unary op with new shuffle as input.
983 auto newBitcast =
984 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
985 rewriter.replaceOp(op, newBitcast);
986 } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
987 // 2. Merge with another shuffle vector op
988 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
989 auto srcMask = srcShuffle.getMask();
990 SmallVector<int32_t> combinedMask;
991 for (auto index : mask) {
992 combinedMask.push_back(srcMask[index]);
993 }
994 auto newShuffle = LLVM::ShuffleVectorOp::create(
995 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
996 DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
997 rewriter.replaceOp(op, newShuffle);
998 } else if (isa<LLVM::LoadOp>(srcOp)) {
999 // 3. Merge with load as a smaller load
1000 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1001 auto loadPtr = loadOp.getAddr();
1002 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1003 auto elemTy = loadTy.getElementType();
1004 auto firstIndex = mask[0];
1005 auto newVecTy = VectorType::get(mask.size(), elemTy);
1006 // GEPOp is needed if first index is not zero
1007 if (firstIndex) {
1008 auto newPtr = LLVM::GEPOp::create(
1009 rewriter, loc,
1010 LLVM::LLVMPointerType::get(rewriter.getContext(),
1011 loadPtr.getType().getAddressSpace()),
1012 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1013 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1014 rewriter.replaceOp(op, newLoad);
1015 } else {
1016 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1017 rewriter.replaceOp(op, newLoad);
1018 }
1019 } else {
1020 return failure();
1021 }
1022 }
1023 return success();
1024 }
1025};
1026
1027//===----------------------------------------------------------------------===//
1028// Pass Definition
1029//===----------------------------------------------------------------------===//
1030
1031struct ConvertXeVMToLLVMPass
1032 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1033 using Base::Base;
1034
1035 void getDependentDialects(DialectRegistry &registry) const override {
1036 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
1037 }
1038
1039 void runOnOperation() override {
1040 ConversionTarget target(getContext());
1041 RewritePatternSet patterns(&getContext());
1043 if (failed(applyPartialConversion(getOperation(), target,
1044 std::move(patterns))))
1045 signalPassFailure();
1046
1047 // Apply in-dialect lowerings to handle illegal vectors
1048 {
1049 RewritePatternSet vectorPatterns(&getContext());
1050 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
1051 GreedyRewriteConfig config{};
1052 // folding can remove ops with temporary attributes used to
1053 // represent LLVM metadata, so disable it here.
1054 // Effectively just this single pattern is applied without any
1055 // op folding patterns from dialects.
1056 config.enableFolding(false);
1057 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1058 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1059 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1060 config);
1061 }
1062 }
1063};
1064} // namespace
1065
1066//===----------------------------------------------------------------------===//
1067// Pattern Population
1068//===----------------------------------------------------------------------===//
1069
1070void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
1071 RewritePatternSet &patterns) {
1072 // some LLVM operations need to be converted.
1073 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
1074 // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
1075 // properly by SPIRV backend. It needs to be rewritten as a sequence with
1076 // llvm global.
1077 if (isa<LLVM::AllocaOp>(op)) {
1078 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1079 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1080 auto addrSpace = pTy.getAddressSpace();
1081 return addrSpace != 3;
1082 }
1083 // cache_control attribute should be converted.
1084 return !op->hasAttr("cache_control");
1085 });
1086 target.addIllegalDialect<XeVMDialect>();
1087 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1088 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1089 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1090 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1091 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1092 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1093 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1094 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1095 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1096 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1097 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1098 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1099 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1100 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1101 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1102 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1103 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1104 LaunchConfigOpToOCLPattern<GridDimXOp>,
1105 LaunchConfigOpToOCLPattern<GridDimYOp>,
1106 LaunchConfigOpToOCLPattern<GridDimZOp>,
1107 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1108 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1109 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1110 AllocaToGlobalPattern>(patterns.getContext());
1111}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
b getContext())
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:277
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...