MLIR 23.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
20
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Types.h"
25
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace xevm;
35
36namespace {
37
38struct LLVMFuncAttributeOptions {
39 bool isConvergent = false;
40 bool isNoUnwind = false;
41 bool isWillReturn = false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
43};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false, true, false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false, true, true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true, true, true, {}};
50
51std::string getTypeMangling(Type ty, bool isUnsigned = false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
56 })
57 .Case([](Float16Type) -> std::string { return "Dh"; })
58 .Case([](Float32Type) -> std::string { return "f"; })
59 .Case([](Float64Type) -> std::string { return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
62 case 8:
63 return isUnsigned ? "h" : "c";
64 case 16:
65 return isUnsigned ? "t" : "s";
66 case 32:
67 return isUnsigned ? "j" : "i";
68 case 64:
69 return isUnsigned ? "m" : "l";
70 default:
71 llvm_unreachable("unhandled integer type");
72 }
73 })
74 .DefaultUnreachable("unhandled type for mangling");
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102static int32_t getL1CacheControl(LoadCacheControl cc) {
103 int32_t control = 0;
104 switch (cc) {
105 case LoadCacheControl::L1C_L2UC_L3UC:
106 case LoadCacheControl::L1C_L2UC_L3C:
107 case LoadCacheControl::L1C_L2C_L3UC:
108 case LoadCacheControl::L1C_L2C_L3C:
109 control = 1;
110 break;
111 case LoadCacheControl::L1S_L2UC_L3UC:
112 case LoadCacheControl::L1S_L2UC_L3C:
113 case LoadCacheControl::L1S_L2C_L3UC:
114 case LoadCacheControl::L1S_L2C_L3C:
115 control = 2;
116 break;
117 case LoadCacheControl::INVALIDATE_READ:
118 control = 3;
119 break;
120 default:
121 break;
122 }
123 return control;
124}
125
126static int32_t getL1CacheControl(StoreCacheControl cc) {
127 int32_t control = 0;
128 switch (cc) {
129 case StoreCacheControl::L1WT_L2UC_L3UC:
130 case StoreCacheControl::L1WT_L2UC_L3WB:
131 case StoreCacheControl::L1WT_L2WB_L3UC:
132 case StoreCacheControl::L1WT_L2WB_L3WB:
133 control = 1;
134 break;
135 case StoreCacheControl::L1WB_L2UC_L3UC:
136 case StoreCacheControl::L1WB_L2WB_L3UC:
137 case StoreCacheControl::L1WB_L2UC_L3WB:
138 control = 2;
139 break;
140 case StoreCacheControl::L1S_L2UC_L3UC:
141 case StoreCacheControl::L1S_L2UC_L3WB:
142 case StoreCacheControl::L1S_L2WB_L3UC:
143 case StoreCacheControl::L1S_L2WB_L3WB:
144 control = 3;
145 break;
146 default:
147 break;
148 }
149 return control;
150}
151
152static int32_t getL3CacheControl(LoadCacheControl cc) {
153 int32_t control = 0;
154 switch (cc) {
155 case LoadCacheControl::L1UC_L2UC_L3C:
156 case LoadCacheControl::L1UC_L2C_L3C:
157 case LoadCacheControl::L1C_L2UC_L3C:
158 case LoadCacheControl::L1C_L2C_L3C:
159 case LoadCacheControl::L1S_L2UC_L3C:
160 case LoadCacheControl::L1S_L2C_L3C:
161 control = 1;
162 break;
163 case LoadCacheControl::INVALIDATE_READ:
164 control = 3;
165 break;
166 default:
167 break;
168 }
169 return control;
170}
171
172static int32_t getL3CacheControl(StoreCacheControl cc) {
173 int32_t control = 0;
174 switch (cc) {
175 case StoreCacheControl::L1UC_L2UC_L3WB:
176 case StoreCacheControl::L1UC_L2WB_L3WB:
177 case StoreCacheControl::L1WT_L2UC_L3WB:
178 case StoreCacheControl::L1WT_L2WB_L3WB:
179 case StoreCacheControl::L1S_L2UC_L3WB:
180 case StoreCacheControl::L1S_L2WB_L3WB:
181 case StoreCacheControl::L1WB_L2UC_L3WB:
182 control = 2;
183 break;
184 default:
185 break;
186 }
187 return control;
188}
189
190static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
191 return op.getCacheControl();
192}
193
194static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
195 return op.getCacheControl();
196}
197
198static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
199 return op.getCacheControl();
200}
201
202static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
203 return op.getCacheControl();
204}
205
206static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
207 return op.getCacheControl();
208}
209
210static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
211 return op.getCacheControl();
212}
213
214static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
215 if (op->hasAttr("cache_control")) {
216 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
217 if (!attr)
218 return std::nullopt;
219 return std::optional<LoadCacheControl>(attr.getValue());
220 }
221 return std::nullopt;
222}
223
224static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
225 if (op->hasAttr("cache_control")) {
226 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
227 if (!attr)
228 return std::nullopt;
229 return std::optional<StoreCacheControl>(attr.getValue());
230 }
231 return std::nullopt;
232}
233
/// L1 control value of `op`'s cache-control setting.
/// Precondition: `op` carries a cache control. The optional returned by
/// getCacheControl is dereferenced without a check.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL1CacheControl(*cacheControl);
}

/// "L3" (cache level 1) control value of `op`'s cache-control setting.
/// Precondition: `op` carries a cache control. The optional returned by
/// getCacheControl is dereferenced without a check.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL3CacheControl(*cacheControl);
}
243
244template <typename OpType>
245static std::optional<ArrayAttr>
246getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
247 if (!getCacheControl(op))
248 return {};
249 constexpr int32_t decorationCacheControlArity{3};
250 constexpr int32_t loadCacheControlKey{6442};
251 constexpr int32_t storeCacheControlKey{6443};
252 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
253 std::is_same_v<OpType, BlockPrefetch2dOp> ||
254 std::is_same_v<OpType, LLVM::LoadOp> ||
255 std::is_same_v<OpType, BlockLoadOp> ||
256 std::is_same_v<OpType, PrefetchOp>;
257 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
259 controlKey, 0, getL1CacheControl<OpType>(op)};
261 controlKey, 1, getL3CacheControl<OpType>(op)};
262 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
263 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
264
265 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
266 return rewriter.getArrayAttr(combinedAttrs);
267}
268
269//===----------------------------------------------------------------------===//
270// Cache control annotation utilities
271//
272// Instead of attaching cache control as MLIR attributes and handling them
273// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
274// MLIR.
275//===----------------------------------------------------------------------===//
276
277/// Build one cache-control payload string per attribute.
278///
279/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
280/// [SPIR-V decoration token, cache level, cache control value]
281///
282/// A single entry produces a string like: {6442:"0,1"}
283/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
284/// form.
286buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
288 llvm::StringMap<bool> seen;
289
290 for (Attribute a : attrs) {
291 auto arr = dyn_cast<ArrayAttr>(a);
292 if (!arr)
293 continue;
294
295 auto vals = arr.getValue();
296 assert(vals.size() == 3 &&
297 "Expected exactly 3 integer values (Token, CacheLevel, "
298 "ControlValue) in cache control attribute.");
299
300 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
301 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
302 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
303
304 if (!tokenAttr || !secondAttr || !thirdAttr)
305 continue;
306
307 // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
308 // The quote char (0x22) is embedded literally; LLVM IR prints it as \22.
309 std::string entry = llvm::formatv("'{'{0}:\"{1},{2}\"'}'",
310 tokenAttr.getValue().getZExtValue(),
311 secondAttr.getValue().getZExtValue(),
312 thirdAttr.getValue().getZExtValue());
313
314 // Deduplicate identical annotations.
315 if (!seen.insert({entry, true}).second)
316 continue;
317
318 payloads.push_back(std::move(entry));
319 }
320 return payloads;
321}
/// Process-wide counter supplying the numeric suffix of generated string
/// globals (see createMetadataStringPtr). It is incremented atomically, so
/// concurrent users never receive the same value.
/// NOTE(review): the counter is shared across modules and never reset, so
/// generated names depend on how many globals were created earlier in the
/// process.
static std::atomic<uint64_t> globalNameCounter(0);
324
325/// Get or create a global metadata string and return a !llvm.ptr<1> value
326/// pointing to it. The AddressOfOp is created at the current rewriter
327/// insertion point; the GlobalOp is created at the module start.
328static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
329 Operation *moduleOp, Location loc,
330 StringRef value, StringRef nameHint) {
331 // Build null-terminated string.
332 std::string strWithNull = value.str();
333 strWithNull.push_back('\0');
334 StringRef strRef(strWithNull.data(), strWithNull.size());
335
336 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
337
338 // Search for an existing global with the same content.
339 for (auto &op : moduleOp->getRegion(0).front()) {
340 if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
341 if (!existingGlobal.getSection() ||
342 *existingGlobal.getSection() != "llvm.metadata")
343 continue;
344 if (auto strAttr =
345 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
346 if (strAttr.getValue() == strRef) {
347 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
348 existingGlobal.getSymName());
349 }
350 }
351 }
352 }
353
354 // Create new global at module start.
355 auto i8Type = rewriter.getI8Type();
356 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
357 std::string globalName =
358 llvm::formatv("{0}.{1}", nameHint,
359 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
360 .str();
361
362 {
363 OpBuilder::InsertionGuard guard(rewriter);
364 rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
365
366 auto globalOp =
367 LLVM::GlobalOp::create(rewriter, loc, arrayType,
368 /*isConstant=*/true, LLVM::Linkage::Private,
369 globalName, rewriter.getStringAttr(strRef));
370 globalOp.setSection(StringRef("llvm.metadata"));
371 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
372 globalOp.setAlignment(1);
373 globalOp.setAddrSpace(1);
374 }
375 // InsertionGuard restores the original insertion point here.
376
377 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
378}
379
380/// Annotate a pointer value with cache control metadata by emitting chained
381/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
382///
383/// This is the MLIR-level equivalent of handleDecorationCacheControl() from
384/// the LLVM translation layer. For each cache control attribute, it emits:
385///
386/// %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
387/// @".str.file.N", 0, null : !llvm.ptr<AS>
388///
389/// Multiple annotations are chained: the result of each annotation op is
390/// fed as the pointer input to the next one.
391///
392/// \param rewriter The pattern rewriter.
393/// \param loc Source location for created ops.
394/// \param ptr The pointer value to annotate.
395/// \param cacheControls The cache control ArrayAttr (from
396/// getCacheControlMetadata).
397/// \param moduleOp The enclosing module (for creating globals).
398/// \returns The annotated pointer value (or the original ptr if no
399/// annotations).
400static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
401 Location loc, Value ptr,
402 ArrayAttr cacheControls,
403 Operation *moduleOp) {
404 SmallVector<std::string> payloads =
405 buildCacheControlPayloads(cacheControls.getValue());
406 if (payloads.empty())
407 return ptr;
408
409 auto ptrType = cast<LLVM::LLVMPointerType>(ptr.getType());
410 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
411 auto i32Ty = rewriter.getI32Type();
412
413 // Create shared constants for all annotations on this pointer.
414 Value fileStr =
415 createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
416 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
417 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
418
419 // Chain: each annotation takes the result of the previous one as its
420 // pointer operand.
421 Value curPtr = ptr;
422 for (const std::string &payload : payloads) {
423 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
424 ".str.cachecontrol");
425 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
426 annStr, fileStr, lineVal, nullAS1);
427 curPtr = annOp.getResult();
428 }
429
430 return curPtr;
431}
432
433/// Helper to apply cache control annotation on a pointer operand of a call.
434/// Replaces the pointer argument of the call with an annotated version.
435///
436/// For operations that produce a call (like block load/store/prefetch), the
437/// pointer is typically the first argument. This function:
438/// 1. Builds the annotation chain on the pointer.
439/// 2. Replaces the pointer operand in the provided args list.
440///
441/// \param rewriter The pattern rewriter.
442/// \param loc Source location.
443/// \param ptr The original pointer value (first arg to the call).
444/// \param cacheControls The cache control metadata.
445/// \param moduleOp The enclosing module.
446/// \param args The argument list (modified in place: args[ptrIdx] is
447/// replaced).
448/// \param ptrIdx Index of the pointer in the args list (default 0).
449template <typename OpType>
450static void
451applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
452 OpType op, SmallVectorImpl<Value> &args,
453 Operation *moduleOp, unsigned ptrIdx = 0) {
454 std::optional<ArrayAttr> optCacheControls =
455 getCacheControlMetadata(rewriter, op);
456 if (!optCacheControls)
457 return;
458
459 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
460 *optCacheControls, moduleOp);
461 args[ptrIdx] = annotatedPtr;
462}
463
464//===----------------------------------------------------------------------===//
465// End cache control annotation utilities
466//===----------------------------------------------------------------------===//
467
468static LLVM::CallOp createDeviceFunctionCall(
469 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
470 ArrayRef<Type> argTypes, ArrayRef<Value> args,
471 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
472 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
473 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
474 assert(moduleOp && "Expecting module");
475 Location loc = op->getLoc();
476
477 auto funcOpRes =
478 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
479 assert(!failed(funcOpRes));
480 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
481 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
482 funcOp.setConvergent(funcAttributeOptions.isConvergent);
483 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
484 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
485
486 if (funcAttributeOptions.memEffectsAttr)
487 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
488
489 for (auto [idx, attrName] : paramAttrs)
490 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
491
492 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
493 callOp->setAttrs(funcOp->getAttrs());
494
495 return callOp;
496}
497
498class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
499 using OpConversionPattern::OpConversionPattern;
500 LogicalResult
501 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
502 ConversionPatternRewriter &rewriter) const override {
503 if (!op.getC()) {
504 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
505 }
506 auto precisionA = op.getTypes().getA();
507 auto precisionB = op.getTypes().getB();
508 auto precisionC = op.getTypes().getC();
509 auto precisionD = op.getTypes().getD();
510 if (precisionC != precisionD) {
511 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
512 }
513 if (precisionC != xevm::ElemType::S32 &&
514 precisionC != xevm::ElemType::F32 &&
515 precisionC != xevm::ElemType::F16 &&
516 precisionC != xevm::ElemType::BF16) {
517 return rewriter.notifyMatchFailure(
518 op, "type of C and D must be S32, F32, F16 or BF16");
519 }
520 if (precisionA == xevm::ElemType::S32 ||
521 precisionA == xevm::ElemType::F32) {
522 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
523 }
524 if (precisionB == xevm::ElemType::S32 ||
525 precisionB == xevm::ElemType::F32) {
526 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
527 }
528 constexpr uint32_t bitWidthPackedA{16};
529 constexpr uint32_t bitWidthPackedB{32};
530 auto loc = op.getLoc();
531
532 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
533 VectorType origTy = cast<VectorType>(val.getType());
534 const uint32_t vecBitSize =
535 origTy.getNumElements() *
536 origTy.getElementType().getIntOrFloatBitWidth();
537 VectorType newTy = VectorType::get(
538 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
539 if (origTy != newTy)
540 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
541 return val;
542 };
543
544 Value a = op.getA();
545 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
546 ? cast<Type>(rewriter.getF32Type())
547 : rewriter.getIntegerType(bitWidthPackedA);
548 a = castIfNeeded(a, packedAType);
549
550 Value b = op.getB();
551 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
552 ? cast<Type>(rewriter.getF32Type())
553 : rewriter.getIntegerType(bitWidthPackedB);
554 b = castIfNeeded(b, packedBType);
555
556 Value c = op.getC();
557 VectorType cOrigTy = cast<VectorType>(c.getType());
558 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
559 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
560 // OCL builtins encode bfloat16 as int16
561 VectorType cTy =
562 cOrigTy.getElementType().isBF16()
563 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
564 : cOrigTy;
565 VectorType resTy = cTy;
566 if (cOrigTy != cTy)
567 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
568
569 constexpr int32_t systolicDepth{8};
570 std::string fnName =
571 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
572 stringifyElemType(op.getTypes().getA()).str(),
573 stringifyElemType(op.getTypes().getB()).str(),
574 systolicDepth *
575 getNumOperandsPerDword(op.getTypes().getA()))
576 .str();
577 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
578 fnName = mangle(fnName, argTypes);
579 SmallVector<Value> args{a, b, c};
580
581 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
582 /*other=*/LLVM::ModRefInfo::NoModRef,
583 /*argMem=*/LLVM::ModRefInfo::NoModRef,
584 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
585 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
586 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
587 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
588 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
589 funcAttrs.memEffectsAttr = memAttr;
590 Value result =
591 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
592 funcAttrs, op.getOperation())
593 ->getResult(0);
594
595 if (resOrigTy != resTy)
596 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
597
598 rewriter.replaceOp(op, result);
599 return success();
600 }
601
602private:
603 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
604 switch (pTy) {
605 case xevm::ElemType::TF32:
606 return 1;
607 case xevm::ElemType::BF16:
608 case xevm::ElemType::F16:
609 return 2;
610 case xevm::ElemType::U8:
611 case xevm::ElemType::S8:
612 return 4;
613 default:
614 llvm_unreachable("unsupported xevm::ElemType");
615 }
616 }
617};
618
619class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
620 using OpConversionPattern::OpConversionPattern;
621 LogicalResult
622 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
623 ConversionPatternRewriter &rewriter) const override {
624 auto loc = op.getLoc();
625 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
626
627 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
628 Value one =
629 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
630 SmallVector<Value> args{op.getPtr(), one};
631
632 // Annotate pointer with cache control before passing to the call.
633 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
634 /*ptrIdx=*/0);
635
636 SmallVector<Type> argTypes;
637 for (auto arg : args)
638 argTypes.push_back(arg.getType());
639 auto funcAttr = noUnwindAttrs;
640 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
641 /*other=*/LLVM::ModRefInfo::NoModRef,
642 /*argMem=*/LLVM::ModRefInfo::Ref,
643 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
644 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
645 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
646 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
647 funcAttr.memEffectsAttr = memAttr;
648
649 createDeviceFunctionCall(rewriter, fnName,
650 LLVM::LLVMVoidType::get(rewriter.getContext()),
651 argTypes, args, {}, funcAttr, op.getOperation());
652 rewriter.eraseOp(op);
653 return success();
654 }
655};
656
657class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
658 using OpConversionPattern::OpConversionPattern;
659 LogicalResult
660 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
661 ConversionPatternRewriter &rewriter) const override {
662 auto loc = op.getLoc();
663 const std::string fnName{"atomic_work_item_fence"};
664 int memScope, addrSpace;
665 switch (op.getAddrspace()) {
666 case xevm::AddrSpace::SHARED:
667 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
668 break;
669 case xevm::AddrSpace::GLOBAL:
670 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
671 break;
672 default:
673 // GENERIC is not supported in OpenCL
674 return rewriter.notifyMatchFailure(
675 op, "Fence only supports global and shared address spaces.");
676 }
677 switch (op.getScope()) {
678 case xevm::MemScope::WORKGROUP:
679 memScope = 1;
680 break;
681 case xevm::MemScope::DEVICE:
682 memScope = 2;
683 break;
684 default:
685 // CLUSTER and SYSTEM are not supported in OpenCL
686 return rewriter.notifyMatchFailure(
687 op, "Fence only supports workgroup and device memory scopes.");
688 }
689 Type i32Type = rewriter.getI32Type();
690 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
691 Value memScopeConst =
692 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
693 Value addrSpaceConst =
694 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
695 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
696 SmallVector<Type> argTypes{3, i32Type};
697 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
698 LLVM::LLVMVoidType::get(rewriter.getContext()),
699 argTypes, args, {}, noUnwindAttrs,
700 op.getOperation());
701 rewriter.eraseOp(op);
702 return success();
703 }
704};
705template <typename OpType>
706class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
707 using OpConversionPattern<OpType>::OpConversionPattern;
708 LogicalResult
709 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
712 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
713
714 auto loc = op.getLoc();
715 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
716 VectorType vecType;
717 bool packReg = false;
718 bool transpose = false;
719 if constexpr (isLoad) {
720 vecType = op.getRes().getType();
721 packReg = op.getPackRegister();
722 transpose = op.getTranspose();
723 } else if constexpr (!isPrefetch) {
724 vecType = op.getStoredVal().getType();
725 }
726
727 auto i32Type = rewriter.getI32Type();
728 Value byteCoord =
729 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
730 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
731 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
732 byteCoord = LLVM::InsertElementOp::create(
733 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
734 byteCoord = LLVM::InsertElementOp::create(
735 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
736 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
737 op.getBasePitch(), byteCoord};
738
739 // Annotate pointer (args[0]) with cache control before the call.
740 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
741 /*ptrIdx=*/0);
742
743 SmallVector<Type> retTypes;
744 Value spvLoadDstPtr;
745 std::string funcName{"intel_sub_group_2d_block_"};
746 std::string bitWidthId;
747 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
748 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
749 if constexpr (isPrefetch) { // Prefetch
750 funcName += "prefetch";
751 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
752 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
753 /*other=*/LLVM::ModRefInfo::NoModRef,
754 /*argMem=*/LLVM::ModRefInfo::Ref,
755 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
756 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
757 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
758 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
759 funcAttr = noUnwindAttrs;
760 funcAttr.memEffectsAttr = memAttr;
761 } else {
762 auto vecElemType = vecType.getElementType();
763 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
764 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
765 vecType.getNumElements());
766 auto dstOrSrcPtr = LLVM::AllocaOp::create(
767 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
768 vecElemType, numElems);
769 args.push_back(dstOrSrcPtr);
770 if constexpr (isLoad) { // Load
771 funcName += "read";
772 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
773 if (packReg)
774 funcName += "_transform";
775 else if (transpose)
776 funcName += "_transpose";
777 spvLoadDstPtr = dstOrSrcPtr;
778 retTypes.push_back(vecType);
779 paramAttrs = {
780 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
781 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
782 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
783 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
784 };
785 } else { // Store
786 funcName += "write";
787 bitWidthId = (vecElemBitWidth == 32)
788 ? "j"
789 : ((vecElemBitWidth == 16) ? "t" : "h");
790 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
791 paramAttrs = {
792 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
793 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
794 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
795 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
796 };
797 }
798 }
799
800 funcName =
801 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
802 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
803 .str();
804 std::string prefetchCode("");
805 if (!isPrefetch)
806 prefetchCode += "P";
807 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
808 funcName, prefetchCode, bitWidthId)
809 .str();
810 SmallVector<Type> argTypes;
811 for (auto arg : args) {
812 argTypes.push_back(arg.getType());
813 }
814 createDeviceFunctionCall(
815 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
816 argTypes, args, paramAttrs, funcAttr, op.getOperation());
817
818 if constexpr (isLoad)
819 rewriter.replaceOp(
820 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
821 else
822 rewriter.eraseOp(op);
823 return success();
824 }
825};
826
827template <typename OpType>
828class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
829 using OpConversionPattern<OpType>::OpConversionPattern;
830 LogicalResult
831 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
832 ConversionPatternRewriter &rewriter) const override {
833 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
834 auto loc = op.getLoc();
835 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
836
837 // Get OpenCL function name
838 // https://registry.khronos.org/OpenCL/extensions/
839 // intel/cl_intel_subgroup_local_block_io.html
840 std::string funcName{"intel_sub_group_block_"};
841 // Value or Result type can be vector or scalar
842 Type valOrResTy;
843 if constexpr (isStore) {
844 funcName += "write_u";
845 valOrResTy = op.getVal().getType();
846 } else {
847 funcName += "read_u";
848 valOrResTy = op.getType();
849 }
850 // Get element type of the vector/scalar
851 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
852 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
853 funcName += getTypeMangling(elemType);
854 if (vecTy)
855 funcName += std::to_string(vecTy.getNumElements());
856 SmallVector<Type, 2> argTypes{};
857 // XeVM BlockLoad/StoreOp always use signless integer types
858 // but OpenCL builtins expect unsigned types
859 // use unsigned types for mangling
860 SmallVector<bool, 2> isUnsigned{};
861 // arg0: pointer to the src/dst address
862 // arg1 - only if store : vector to store
863 // Prepare arguments
864 SmallVector<Value, 2> args{};
865 args.push_back(op.getPtr());
866 argTypes.push_back(op.getPtr().getType());
867 isUnsigned.push_back(true);
868
869 // Annotate pointer (args[0]) with cache control.
870 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
871 /*ptrIdx=*/0);
872 // Update argTypes[0] in case the pointer type changed (it shouldn't
873 // change type, but the value is now the annotated pointer).
874 argTypes[0] = args[0].getType();
875
876 Type retType;
877 if constexpr (isStore) {
878 args.push_back(op.getVal());
879 argTypes.push_back(op.getVal().getType());
880 isUnsigned.push_back(true);
881 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
882 } else {
883 retType = valOrResTy;
884 }
885 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
886 "PU3AS" +
887 std::to_string(op.getPtr().getType().getAddressSpace());
888 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
889 if constexpr (isStore)
890 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
891 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
892
893 LLVM::CallOp call =
894 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
895 {}, funcAttr, op.getOperation());
896
897 if constexpr (isStore)
898 rewriter.eraseOp(op);
899 else
900 rewriter.replaceOp(op, call->getResult(0));
901 return success();
902 }
903};
904
905template <typename OpType>
906class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
907 using OpConversionPattern<OpType>::OpConversionPattern;
908 LogicalResult
909 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
910 ConversionPatternRewriter &rewriter) const override {
911 if (!op->hasAttr("cache_control"))
912 return failure();
913
914 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
915 std::optional<ArrayAttr> optCacheControls =
916 getCacheControlMetadata(rewriter, op);
917 if (!optCacheControls) {
918 op->removeAttr("cache_control");
919 return success();
920 }
921
922 // Determine which operand is the pointer.
923 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
924 unsigned ptrIdx = isStore ? 1 : 0;
925 Value ptr = op->getOperand(ptrIdx);
926
927 // Emit annotation intrinsic calls on the pointer.
928 Value annotatedPtr = annotatePtrWithCacheControl(
929 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
930
931 // Replace the pointer operand with the annotated one.
932 op->setOperand(ptrIdx, annotatedPtr);
933 op->removeAttr("cache_control");
934 return success();
935 }
936};
937
938//===----------------------------------------------------------------------===//
939// GPU index id operations
940//===----------------------------------------------------------------------===//
941/*
942// Launch Config ops
943// dimidx - x, y, z - is fixed to i32
944// return type is set by XeVM type converter
945// get_local_id
946xevm::WorkitemIdXOp;
947xevm::WorkitemIdYOp;
948xevm::WorkitemIdZOp;
949// get_local_size
950xevm::WorkgroupDimXOp;
951xevm::WorkgroupDimYOp;
952xevm::WorkgroupDimZOp;
953// get_group_id
954xevm::WorkgroupIdXOp;
955xevm::WorkgroupIdYOp;
956xevm::WorkgroupIdZOp;
957// get_num_groups
958xevm::GridDimXOp;
959xevm::GridDimYOp;
960xevm::GridDimZOp;
961// get_global_id : to be added if needed
962*/
963
964// Helpers to get the OpenCL function name and dimension argument for each op.
965static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
966 return {"get_local_id", 0};
967}
968static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
969 return {"get_local_id", 1};
970}
971static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
972 return {"get_local_id", 2};
973}
974static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
975 return {"get_local_size", 0};
976}
977static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
978 return {"get_local_size", 1};
979}
980static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
981 return {"get_local_size", 2};
982}
983static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
984 return {"get_group_id", 0};
985}
986static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
987 return {"get_group_id", 1};
988}
989static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
990 return {"get_group_id", 2};
991}
992static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
993 return {"get_num_groups", 0};
994}
995static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
996 return {"get_num_groups", 1};
997}
998static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
999 return {"get_num_groups", 2};
1000}
1001/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
1002/// a constant argument for the dimension - x, y or z.
1003template <typename OpType>
1004class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
1005 using OpConversionPattern<OpType>::OpConversionPattern;
1006 LogicalResult
1007 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1008 ConversionPatternRewriter &rewriter) const override {
1009 Location loc = op->getLoc();
1010 auto [baseName, dim] = getConfig(op);
1011 Type dimTy = rewriter.getI32Type();
1012 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1013 static_cast<int64_t>(dim));
1014 std::string func = mangle(baseName, {dimTy}, {true});
1015 Type resTy = op.getType();
1016 auto call =
1017 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1018 noUnwindWillReturnAttrs, op.getOperation());
1019 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1020 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1021 /*other=*/noModRef,
1022 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1023 /*errnoMem=*/noModRef,
1024 /*targetMem0=*/noModRef,
1025 /*targetMem1=*/noModRef);
1026 call.setMemoryEffectsAttr(memAttr);
1027 rewriter.replaceOp(op, call);
1028 return success();
1029 }
1030};
1031
1032/*
1033// Subgroup ops
1034// get_sub_group_local_id
1035xevm::LaneIdOp;
1036// get_sub_group_id
1037xevm::SubgroupIdOp;
1038// get_sub_group_size
1039xevm::SubgroupSizeOp;
1040// get_num_sub_groups : to be added if needed
1041*/
1042
1043// Helpers to get the OpenCL function name for each op.
1044static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
1045static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
1046static StringRef getConfig(xevm::SubgroupSizeOp) {
1047 return "get_sub_group_size";
1048}
1049template <typename OpType>
1050class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
1051 using OpConversionPattern<OpType>::OpConversionPattern;
1052 LogicalResult
1053 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1054 ConversionPatternRewriter &rewriter) const override {
1055 std::string func = mangle(getConfig(op).str(), {});
1056 Type resTy = op.getType();
1057 auto call =
1058 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1059 noUnwindWillReturnAttrs, op.getOperation());
1060 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1061 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1062 /*other=*/noModRef,
1063 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1064 /*errnoMem=*/noModRef,
1065 /*targetMem0=*/noModRef,
1066 /*targetMem1=*/noModRef);
1067 call.setMemoryEffectsAttr(memAttr);
1068 rewriter.replaceOp(op, call);
1069 return success();
1070 }
1071};
1072
1073class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
1074 using OpConversionPattern::OpConversionPattern;
1075 LogicalResult
1076 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1077 ConversionPatternRewriter &rewriter) const override {
1078 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1079 auto addrSpace = ptrType.getAddressSpace();
1080 if (addrSpace != 3)
1081 return failure();
1082 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1083 if (!symTable)
1084 return failure();
1085 Block *moduleBody;
1086 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1087 moduleBody = mod.getBody();
1088 } else if (gpu::GPUModuleOp gpuMod =
1089 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1090 moduleBody = gpuMod.getBody();
1091 } else {
1092 return failure();
1093 }
1094 auto val = op.getArraySize();
1095 APInt cst;
1096 if (!matchPattern(val, m_ConstantInt(&cst)))
1097 return failure();
1098 auto loc = op.getLoc();
1099 auto globalType = LLVM::LLVMArrayType::get(
1100 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1101 LLVM::GlobalOp globalVar;
1102 {
1103 OpBuilder::InsertionGuard guard(rewriter);
1104 rewriter.setInsertionPointToStart(moduleBody);
1105 auto alignment = op.getAlignment();
1106 globalVar = LLVM::GlobalOp::create(
1107 rewriter, loc, globalType, /*isConstant=*/false,
1108 /*linkage=*/LLVM::Linkage::Internal,
1109 /*name=*/std::string("__global_alloca_") +
1110 std::to_string(getNextGlobalIdx()),
1111 /*value=*/Attribute(),
1112 /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
1113 }
1114 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
1115 return success();
1116 }
1117
1118private:
1119 static unsigned getNextGlobalIdx() {
1120 static unsigned globalIdx = 0;
1121 return globalIdx++;
1122 }
1123};
1124
1125static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1126 if (op.getV1() != op.getV2())
1127 return false;
1128 auto maskAttr = op.getMask();
1129 int64_t firstIndex = maskAttr[0];
1130 for (int64_t i = 1; i < static_cast<int64_t>(maskAttr.size()); ++i) {
1131 int64_t index = maskAttr[i];
1132 if (index != firstIndex + i)
1133 return false;
1134 }
1135 return true;
1136}
1137
1138// Input vector of a shuffle vector op extracting a contiguous slice is an
1139// illegal vector in SPIRV kernel if the vector size is > 16 elements.
1140// To legalize this case, keep applying the following transformations until no
1141// more match:
1142// 1. keep hoisting the shuffle vector op past unary element-wise operations
1143// start with fpext, fptrunc and bitcast for now.
1144// 2. merge with another shuffle vector op
1145// 3. merge with load as a smaller load
1146class HandleVectorExtractPattern
1147 : public OpRewritePattern<LLVM::ShuffleVectorOp> {
1148 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
1149
1150 void initialize() { setHasBoundedRewriteRecursion(); }
1151
1152 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1153 PatternRewriter &rewriter) const override {
1154
1155 if (!isExtractingContiguousSlice(op))
1156 return failure();
1157
1158 auto mask = op.getMask();
1159 auto loc = op.getLoc();
1160 auto ty = op.getType();
1161 // Check source operand to determine rewrite pattern.
1162 auto src = op.getV1();
1163 // 1. Hoist past unary element-wise operations
1164 if (auto srcOp = src.getDefiningOp()) {
1165 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1166 Value srcInput = srcOp->getOperand(0);
1167 // Create new shuffle vector op with unary input as source.
1168 auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
1169 auto newShuffleVecTy =
1170 VectorType::get(mask.size(), srcVecTy.getElementType());
1171 auto newShuffle = LLVM::ShuffleVectorOp::create(
1172 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1173 // Create new unary op with new shuffle as input.
1174 Value newUnaryOp;
1175 if (isa<LLVM::FPExtOp>(srcOp)) {
1176 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1177 } else {
1178 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1179 }
1180 rewriter.replaceOp(op, newUnaryOp);
1181 } else if (isa<LLVM::BitcastOp>(srcOp)) {
1182 Value srcInput = srcOp->getOperand(0);
1183 // Create new shuffle vector op with unary input as source.
1184 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
1185 auto srcInputSize = srcInputVecTy.getNumElements();
1186 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1187 auto srcResSize = srcResVecTy.getNumElements();
1188 auto maskSize = static_cast<int32_t>(mask.size());
1189 if (srcInputSize > srcResSize) {
1190 return failure();
1191 }
1192 if (srcResSize % srcInputSize != 0) {
1193 return failure();
1194 }
1195 auto maskScale = srcResSize / srcInputSize;
1196 if (maskScale != 1) {
1197 if (mask[0] % maskScale != 0) {
1198 return failure();
1199 }
1200 // Create a new mask that maps to the source vector
1201 SmallVector<int32_t> newMask;
1202 int32_t newMaskSize = maskSize / maskScale;
1203 int32_t maskStart = mask[0] / maskScale;
1204 for (int32_t i = 0; i < newMaskSize; ++i) {
1205 newMask.push_back(maskStart + i);
1206 }
1207 mask = newMask;
1208 }
1209 auto newShuffleVecTy =
1210 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1211 auto newShuffle = LLVM::ShuffleVectorOp::create(
1212 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1213 // Create new unary op with new shuffle as input.
1214 auto newBitcast =
1215 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1216 rewriter.replaceOp(op, newBitcast);
1217 } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1218 // 2. Merge with another shuffle vector op
1219 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1220 auto srcMask = srcShuffle.getMask();
1221 SmallVector<int32_t> combinedMask;
1222 for (auto index : mask) {
1223 combinedMask.push_back(srcMask[index]);
1224 }
1225 auto newShuffle = LLVM::ShuffleVectorOp::create(
1226 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1227 DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
1228 rewriter.replaceOp(op, newShuffle);
1229 } else if (isa<LLVM::LoadOp>(srcOp)) {
1230 // 3. Merge with load as a smaller load
1231 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1232 auto loadPtr = loadOp.getAddr();
1233 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1234 auto elemTy = loadTy.getElementType();
1235 auto firstIndex = mask[0];
1236 auto newVecTy = VectorType::get(mask.size(), elemTy);
1237 // GEPOp is needed if first index is not zero
1238 if (firstIndex) {
1239 auto newPtr = LLVM::GEPOp::create(
1240 rewriter, loc,
1241 LLVM::LLVMPointerType::get(rewriter.getContext(),
1242 loadPtr.getType().getAddressSpace()),
1243 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1244 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1245 rewriter.replaceOp(op, newLoad);
1246 } else {
1247 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1248 rewriter.replaceOp(op, newLoad);
1249 }
1250 } else {
1251 return failure();
1252 }
1253 }
1254 return success();
1255 }
1256};
1257
1258//===----------------------------------------------------------------------===//
1259// Pass Definition
1260//===----------------------------------------------------------------------===//
1261
1262struct ConvertXeVMToLLVMPass
1263 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1264 using Base::Base;
1265
1266 void getDependentDialects(DialectRegistry &registry) const override {
1267 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
1268 }
1269
1270 void runOnOperation() override {
1271 ConversionTarget target(getContext());
1272 RewritePatternSet patterns(&getContext());
1274 if (failed(applyPartialConversion(getOperation(), target,
1275 std::move(patterns))))
1276 signalPassFailure();
1277
1278 // Apply in-dialect lowerings to handle illegal vectors
1279 {
1280 RewritePatternSet vectorPatterns(&getContext());
1281 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
1282 GreedyRewriteConfig config{};
1283 // folding can remove ops with temporary attributes used to
1284 // represent LLVM metadata, so disable it here.
1285 // Effectively just this single pattern is applied without any
1286 // op folding patterns from dialects.
1287 config.enableFolding(false);
1288 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1289 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1290 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1291 config);
1292 }
1293 }
1294};
1295} // namespace
1296
1297//===----------------------------------------------------------------------===//
1298// Pattern Population
1299//===----------------------------------------------------------------------===//
1300
1301void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
1302 RewritePatternSet &patterns) {
1303 // some LLVM operations need to be converted.
1304 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
1305 // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
1306 // properly by SPIRV backend. It needs to be rewritten as a sequence with
1307 // llvm global.
1308 if (isa<LLVM::AllocaOp>(op)) {
1309 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1310 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1311 auto addrSpace = pTy.getAddressSpace();
1312 return addrSpace != 3;
1313 }
1314 // cache_control attribute should be converted.
1315 return !op->hasAttr("cache_control");
1316 });
1317 target.addIllegalDialect<XeVMDialect>();
1318 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1319 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1320 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1321 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1322 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1323 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1324 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1325 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1326 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1327 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1328 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1329 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1330 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1331 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1332 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1333 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1334 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1335 LaunchConfigOpToOCLPattern<GridDimXOp>,
1336 LaunchConfigOpToOCLPattern<GridDimYOp>,
1337 LaunchConfigOpToOCLPattern<GridDimZOp>,
1338 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1339 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1340 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1341 AllocaToGlobalPattern>(patterns.getContext());
1342}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:715
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:277
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:244
Block & front()
Definition Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:136
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...