MLIR 23.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
20
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Types.h"
25
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace xevm;
35
36namespace {
37
38struct LLVMFuncAttributeOptions {
39 bool isConvergent = false;
40 bool isNoUnwind = false;
41 bool isWillReturn = false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
43};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false, true, false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false, true, true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true, true, true, {}};
50
51std::string getTypeMangling(Type ty, bool isUnsigned = false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
56 })
57 .Case([](Float16Type) -> std::string { return "Dh"; })
58 .Case([](Float32Type) -> std::string { return "f"; })
59 .Case([](Float64Type) -> std::string { return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
62 case 8:
63 return isUnsigned ? "h" : "c";
64 case 16:
65 return isUnsigned ? "t" : "s";
66 case 32:
67 return isUnsigned ? "j" : "i";
68 case 64:
69 return isUnsigned ? "m" : "l";
70 default:
71 llvm_unreachable("unhandled integer type");
72 }
73 })
74 .DefaultUnreachable("unhandled type for mangling");
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102static int32_t getL1CacheControl(LoadCacheControl cc) {
103 int32_t control = 0;
104 switch (cc) {
105 case LoadCacheControl::L1C_L2UC_L3UC:
106 case LoadCacheControl::L1C_L2UC_L3C:
107 case LoadCacheControl::L1C_L2C_L3UC:
108 case LoadCacheControl::L1C_L2C_L3C:
109 control = 1;
110 break;
111 case LoadCacheControl::L1S_L2UC_L3UC:
112 case LoadCacheControl::L1S_L2UC_L3C:
113 case LoadCacheControl::L1S_L2C_L3UC:
114 case LoadCacheControl::L1S_L2C_L3C:
115 control = 2;
116 break;
117 case LoadCacheControl::INVALIDATE_READ:
118 control = 3;
119 break;
120 default:
121 break;
122 }
123 return control;
124}
125
126static int32_t getL1CacheControl(StoreCacheControl cc) {
127 int32_t control = 0;
128 switch (cc) {
129 case StoreCacheControl::L1WT_L2UC_L3UC:
130 case StoreCacheControl::L1WT_L2UC_L3WB:
131 case StoreCacheControl::L1WT_L2WB_L3UC:
132 case StoreCacheControl::L1WT_L2WB_L3WB:
133 control = 1;
134 break;
135 case StoreCacheControl::L1WB_L2UC_L3UC:
136 case StoreCacheControl::L1WB_L2WB_L3UC:
137 case StoreCacheControl::L1WB_L2UC_L3WB:
138 control = 2;
139 break;
140 case StoreCacheControl::L1S_L2UC_L3UC:
141 case StoreCacheControl::L1S_L2UC_L3WB:
142 case StoreCacheControl::L1S_L2WB_L3UC:
143 case StoreCacheControl::L1S_L2WB_L3WB:
144 control = 3;
145 break;
146 default:
147 break;
148 }
149 return control;
150}
151
152static int32_t getL3CacheControl(LoadCacheControl cc) {
153 int32_t control = 0;
154 switch (cc) {
155 case LoadCacheControl::L1UC_L2UC_L3C:
156 case LoadCacheControl::L1UC_L2C_L3C:
157 case LoadCacheControl::L1C_L2UC_L3C:
158 case LoadCacheControl::L1C_L2C_L3C:
159 case LoadCacheControl::L1S_L2UC_L3C:
160 case LoadCacheControl::L1S_L2C_L3C:
161 control = 1;
162 break;
163 case LoadCacheControl::INVALIDATE_READ:
164 control = 3;
165 break;
166 default:
167 break;
168 }
169 return control;
170}
171
172static int32_t getL3CacheControl(StoreCacheControl cc) {
173 int32_t control = 0;
174 switch (cc) {
175 case StoreCacheControl::L1UC_L2UC_L3WB:
176 case StoreCacheControl::L1UC_L2WB_L3WB:
177 case StoreCacheControl::L1WT_L2UC_L3WB:
178 case StoreCacheControl::L1WT_L2WB_L3WB:
179 case StoreCacheControl::L1S_L2UC_L3WB:
180 case StoreCacheControl::L1S_L2WB_L3WB:
181 case StoreCacheControl::L1WB_L2UC_L3WB:
182 control = 2;
183 break;
184 default:
185 break;
186 }
187 return control;
188}
189
190static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
191 return op.getCacheControl();
192}
193
194static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
195 return op.getCacheControl();
196}
197
198static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
199 return op.getCacheControl();
200}
201
202static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
203 return op.getCacheControl();
204}
205
206static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
207 return op.getCacheControl();
208}
209
210static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
211 return op.getCacheControl();
212}
213
214static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
215 if (op->hasAttr("cache_control")) {
216 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
217 if (!attr)
218 return std::nullopt;
219 return std::optional<LoadCacheControl>(attr.getValue());
220 }
221 return std::nullopt;
222}
223
224static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
225 if (op->hasAttr("cache_control")) {
226 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
227 if (!attr)
228 return std::nullopt;
229 return std::optional<StoreCacheControl>(attr.getValue());
230 }
231 return std::nullopt;
232}
233
/// L1 control value for the cache control attached to `op`.
///
/// Precondition: `getCacheControl(op)` holds a value; the optional is
/// dereferenced without a check.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL1CacheControl(*cacheControl);
}
238
/// L3 control value for the cache control attached to `op`.
///
/// Precondition: `getCacheControl(op)` holds a value; the optional is
/// dereferenced without a check.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  const auto cacheControl = getCacheControl(op);
  return getL3CacheControl(*cacheControl);
}
243
244template <typename OpType>
245static std::optional<ArrayAttr>
246getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
247 if (!getCacheControl(op))
248 return {};
249 constexpr int32_t decorationCacheControlArity{3};
250 constexpr int32_t loadCacheControlKey{6442};
251 constexpr int32_t storeCacheControlKey{6443};
252 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
253 std::is_same_v<OpType, BlockPrefetch2dOp> ||
254 std::is_same_v<OpType, LLVM::LoadOp> ||
255 std::is_same_v<OpType, BlockLoadOp> ||
256 std::is_same_v<OpType, PrefetchOp>;
257 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
259 controlKey, 0, getL1CacheControl<OpType>(op)};
261 controlKey, 1, getL3CacheControl<OpType>(op)};
262 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
263 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
264
265 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
266 return rewriter.getArrayAttr(combinedAttrs);
267}
268
269//===----------------------------------------------------------------------===//
270// Cache control annotation utilities
271//
272// Instead of attaching cache control as MLIR attributes and handling them
273// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
274// MLIR.
275//===----------------------------------------------------------------------===//
276
277/// Build one cache-control payload string per attribute.
278///
279/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
280/// [SPIR-V decoration token, cache level, cache control value]
281///
282/// A single entry produces a string like: {6442:"0,1"}
283/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
284/// form.
286buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
288 llvm::StringMap<bool> seen;
289
290 for (Attribute a : attrs) {
291 auto arr = dyn_cast<ArrayAttr>(a);
292 if (!arr)
293 continue;
294
295 auto vals = arr.getValue();
296 assert(vals.size() == 3 &&
297 "Expected exactly 3 integer values (Token, CacheLevel, "
298 "ControlValue) in cache control attribute.");
299
300 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
301 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
302 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
303
304 if (!tokenAttr || !secondAttr || !thirdAttr)
305 continue;
306
307 // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
308 // The quote char (0x22) is embedded literally; LLVM IR prints it as \22.
309 std::string entry =
310 llvm::formatv("{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
311 secondAttr.getValue().getZExtValue(),
312 thirdAttr.getValue().getZExtValue());
313
314 // Deduplicate identical annotations.
315 if (!seen.insert({entry, true}).second)
316 continue;
317
318 payloads.push_back(std::move(entry));
319 }
320 return payloads;
321}
322/// Counter for generating unique global variable names.
323static std::atomic<uint64_t> globalNameCounter{0};
324
325/// Get or create a global metadata string and return a !llvm.ptr<1> value
326/// pointing to it. The AddressOfOp is created at the current rewriter
327/// insertion point; the GlobalOp is created at the module start.
328static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
329 Operation *moduleOp, Location loc,
330 StringRef value, StringRef nameHint) {
331 // Build null-terminated string.
332 std::string strWithNull = value.str();
333 strWithNull.push_back('\0');
334 StringRef strRef(strWithNull.data(), strWithNull.size());
335
336 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
337
338 // Search for an existing global with the same content.
339 for (auto &op : moduleOp->getRegion(0).front()) {
340 if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
341 if (!existingGlobal.getSection() ||
342 *existingGlobal.getSection() != "llvm.metadata")
343 continue;
344 if (auto strAttr =
345 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
346 if (strAttr.getValue() == strRef) {
347 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
348 existingGlobal.getSymName());
349 }
350 }
351 }
352 }
353
354 // Create new global at module start.
355 auto i8Type = rewriter.getI8Type();
356 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
357 std::string globalName =
358 llvm::formatv("{0}.{1}", nameHint,
359 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
360 .str();
361
362 {
363 OpBuilder::InsertionGuard guard(rewriter);
364 rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
365
366 auto globalOp =
367 LLVM::GlobalOp::create(rewriter, loc, arrayType,
368 /*isConstant=*/true, LLVM::Linkage::Private,
369 globalName, rewriter.getStringAttr(strRef));
370 globalOp.setSection(StringRef("llvm.metadata"));
371 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
372 globalOp.setAlignment(1);
373 globalOp.setAddrSpace(1);
374 }
375 // InsertionGuard restores the original insertion point here.
376
377 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
378}
379
380/// Annotate a pointer value with cache control metadata by emitting chained
381/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
382///
383/// This is the MLIR-level equivalent of handleDecorationCacheControl() from
384/// the LLVM translation layer. For each cache control attribute, it emits:
385///
386/// %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
387/// @".str.file.N", 0, null : !llvm.ptr<AS>
388///
389/// Multiple annotations are chained: the result of each annotation op is
390/// fed as the pointer input to the next one.
391///
392/// \param rewriter The pattern rewriter.
393/// \param loc Source location for created ops.
394/// \param ptr The pointer value to annotate.
395/// \param cacheControls The cache control ArrayAttr (from
396/// getCacheControlMetadata).
397/// \param moduleOp The enclosing module (for creating globals).
398/// \returns The annotated pointer value (or the original ptr if no
399/// annotations).
400static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
401 Location loc, Value ptr,
402 ArrayAttr cacheControls,
403 Operation *moduleOp) {
404 SmallVector<std::string> payloads =
405 buildCacheControlPayloads(cacheControls.getValue());
406 if (payloads.empty())
407 return ptr;
408
409 auto ptrType = cast<LLVM::LLVMPointerType>(ptr.getType());
410 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
411 auto i32Ty = rewriter.getI32Type();
412
413 // Create shared constants for all annotations on this pointer.
414 Value fileStr =
415 createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
416 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
417 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
418
419 // Chain: each annotation takes the result of the previous one as its
420 // pointer operand.
421 Value curPtr = ptr;
422 for (const std::string &payload : payloads) {
423 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
424 ".str.cachecontrol");
425 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
426 annStr, fileStr, lineVal, nullAS1);
427 curPtr = annOp.getResult();
428 }
429
430 return curPtr;
431}
432
433/// Helper to apply cache control annotation on a pointer operand of a call.
434/// Replaces the pointer argument of the call with an annotated version.
435///
436/// For operations that produce a call (like block load/store/prefetch), the
437/// pointer is typically the first argument. This function:
438/// 1. Builds the annotation chain on the pointer.
439/// 2. Replaces the pointer operand in the provided args list.
440///
441/// \param rewriter The pattern rewriter.
442/// \param loc Source location.
443/// \param ptr The original pointer value (first arg to the call).
444/// \param cacheControls The cache control metadata.
445/// \param moduleOp The enclosing module.
446/// \param args The argument list (modified in place: args[ptrIdx] is
447/// replaced).
448/// \param ptrIdx Index of the pointer in the args list (default 0).
449template <typename OpType>
450static void
451applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
452 OpType op, SmallVectorImpl<Value> &args,
453 Operation *moduleOp, unsigned ptrIdx = 0) {
454 std::optional<ArrayAttr> optCacheControls =
455 getCacheControlMetadata(rewriter, op);
456 if (!optCacheControls)
457 return;
458
459 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
460 *optCacheControls, moduleOp);
461 args[ptrIdx] = annotatedPtr;
462}
463
464//===----------------------------------------------------------------------===//
465// End cache control annotation utilities
466//===----------------------------------------------------------------------===//
467
468static LLVM::CallOp createDeviceFunctionCall(
469 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
470 ArrayRef<Type> argTypes, ArrayRef<Value> args,
471 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
472 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
473 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
474 assert(moduleOp && "Expecting module");
475 Location loc = op->getLoc();
476
477 auto funcOpRes =
478 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
479 assert(!failed(funcOpRes));
480 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
481 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
482 funcOp.setConvergent(funcAttributeOptions.isConvergent);
483 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
484 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
485
486 if (funcAttributeOptions.memEffectsAttr)
487 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
488
489 for (auto [idx, attrName] : paramAttrs)
490 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
491
492 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
493 callOp->setAttrs(funcOp->getAttrs());
494
495 return callOp;
496}
497
498class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
499 using OpConversionPattern::OpConversionPattern;
500 LogicalResult
501 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
502 ConversionPatternRewriter &rewriter) const override {
503 if (!op.getC()) {
504 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
505 }
506 auto precisionA = op.getTypes().getA();
507 auto precisionB = op.getTypes().getB();
508 auto precisionC = op.getTypes().getC();
509 auto precisionD = op.getTypes().getD();
510 if (precisionC != precisionD) {
511 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
512 }
513 if (precisionC != xevm::ElemType::S32 &&
514 precisionC != xevm::ElemType::F32 &&
515 precisionC != xevm::ElemType::F16 &&
516 precisionC != xevm::ElemType::BF16) {
517 return rewriter.notifyMatchFailure(
518 op, "type of C and D must be S32, F32, F16 or BF16");
519 }
520 if (precisionA == xevm::ElemType::S32 ||
521 precisionA == xevm::ElemType::F32) {
522 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
523 }
524 if (precisionB == xevm::ElemType::S32 ||
525 precisionB == xevm::ElemType::F32) {
526 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
527 }
528 constexpr uint32_t bitWidthPackedA{16};
529 constexpr uint32_t bitWidthPackedB{32};
530 auto loc = op.getLoc();
531
532 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
533 VectorType origTy = cast<VectorType>(val.getType());
534 const uint32_t vecBitSize =
535 origTy.getNumElements() *
536 origTy.getElementType().getIntOrFloatBitWidth();
537 VectorType newTy = VectorType::get(
538 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
539 if (origTy != newTy)
540 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
541 return val;
542 };
543
544 Value a = op.getA();
545 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
546 ? cast<Type>(rewriter.getF32Type())
547 : rewriter.getIntegerType(bitWidthPackedA);
548 a = castIfNeeded(a, packedAType);
549
550 Value b = op.getB();
551 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
552 ? cast<Type>(rewriter.getF32Type())
553 : rewriter.getIntegerType(bitWidthPackedB);
554 b = castIfNeeded(b, packedBType);
555
556 Value c = op.getC();
557 VectorType cOrigTy = cast<VectorType>(c.getType());
558 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
559 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
560 // OCL builtins encode bfloat16 as int16
561 VectorType cTy =
562 cOrigTy.getElementType().isBF16()
563 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
564 : cOrigTy;
565 VectorType resTy = cTy;
566 if (cOrigTy != cTy)
567 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
568
569 constexpr int32_t systolicDepth{8};
570 std::string fnName =
571 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
572 stringifyElemType(op.getTypes().getA()).str(),
573 stringifyElemType(op.getTypes().getB()).str(),
574 systolicDepth *
575 getNumOperandsPerDword(op.getTypes().getA()))
576 .str();
577 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
578 fnName = mangle(fnName, argTypes);
579 SmallVector<Value> args{a, b, c};
580
581 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
582 /*other=*/LLVM::ModRefInfo::NoModRef,
583 /*argMem=*/LLVM::ModRefInfo::NoModRef,
584 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
585 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
586 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
587 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
588 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
589 funcAttrs.memEffectsAttr = memAttr;
590 Value result =
591 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
592 funcAttrs, op.getOperation())
593 ->getResult(0);
594
595 if (resOrigTy != resTy)
596 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
597
598 rewriter.replaceOp(op, result);
599 return success();
600 }
601
602private:
603 static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
604 switch (pTy) {
605 case xevm::ElemType::TF32:
606 return 1;
607 case xevm::ElemType::BF16:
608 case xevm::ElemType::F16:
609 return 2;
610 case xevm::ElemType::U8:
611 case xevm::ElemType::S8:
612 return 4;
613 default:
614 llvm_unreachable("unsupported xevm::ElemType");
615 }
616 }
617};
618
619class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
620 using OpConversionPattern::OpConversionPattern;
621 LogicalResult
622 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
623 ConversionPatternRewriter &rewriter) const override {
624 auto loc = op.getLoc();
625 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
626
627 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
628 Value one =
629 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
630 SmallVector<Value> args{op.getPtr(), one};
631
632 // Annotate pointer with cache control before passing to the call.
633 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
634 /*ptrIdx=*/0);
635
636 SmallVector<Type> argTypes;
637 for (auto arg : args)
638 argTypes.push_back(arg.getType());
639 auto funcAttr = noUnwindAttrs;
640 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
641 /*other=*/LLVM::ModRefInfo::NoModRef,
642 /*argMem=*/LLVM::ModRefInfo::Ref,
643 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
644 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
645 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
646 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
647 funcAttr.memEffectsAttr = memAttr;
648
649 createDeviceFunctionCall(rewriter, fnName,
650 LLVM::LLVMVoidType::get(rewriter.getContext()),
651 argTypes, args, {}, funcAttr, op.getOperation());
652 rewriter.eraseOp(op);
653 return success();
654 }
655};
656
657class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
658 using OpConversionPattern::OpConversionPattern;
659 LogicalResult
660 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
661 ConversionPatternRewriter &rewriter) const override {
662 auto loc = op.getLoc();
663 const std::string fnName{"atomic_work_item_fence"};
664 int memScope, addrSpace;
665 switch (op.getAddrspace()) {
666 case xevm::AddrSpace::SHARED:
667 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
668 break;
669 case xevm::AddrSpace::GLOBAL:
670 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
671 break;
672 default:
673 // GENERIC is not supported in OpenCL
674 return rewriter.notifyMatchFailure(
675 op, "Fence only supports global and shared address spaces.");
676 }
677 switch (op.getScope()) {
678 case xevm::MemScope::WORKGROUP:
679 memScope = 1;
680 break;
681 case xevm::MemScope::DEVICE:
682 memScope = 2;
683 break;
684 default:
685 // CLUSTER and SYSTEM are not supported in OpenCL
686 return rewriter.notifyMatchFailure(
687 op, "Fence only supports workgroup and device memory scopes.");
688 }
689 Type i32Type = rewriter.getI32Type();
690 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
691 Value memScopeConst =
692 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
693 Value addrSpaceConst =
694 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
695 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
696 SmallVector<Type> argTypes{3, i32Type};
697 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
698 LLVM::LLVMVoidType::get(rewriter.getContext()),
699 argTypes, args, {}, noUnwindAttrs,
700 op.getOperation());
701 rewriter.eraseOp(op);
702 return success();
703 }
704};
705template <typename OpType>
706class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
707 using OpConversionPattern<OpType>::OpConversionPattern;
708 LogicalResult
709 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
712 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
713
714 auto loc = op.getLoc();
715 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
716 VectorType vecType;
717 bool packReg = false;
718 bool transpose = false;
719 if constexpr (isLoad) {
720 vecType = op.getRes().getType();
721 packReg = op.getPackRegister();
722 transpose = op.getTranspose();
723 } else if constexpr (!isPrefetch) {
724 vecType = op.getStoredVal().getType();
725 }
726
727 auto i32Type = rewriter.getI32Type();
728 Value byteCoord =
729 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
730 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
731 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
732 byteCoord = LLVM::InsertElementOp::create(
733 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
734 byteCoord = LLVM::InsertElementOp::create(
735 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
736 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
737 op.getBasePitch(), byteCoord};
738
739 // Annotate pointer (args[0]) with cache control before the call.
740 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
741 /*ptrIdx=*/0);
742
743 SmallVector<Type> retTypes;
744 Value spvLoadDstPtr;
745 std::string funcName{"intel_sub_group_2d_block_"};
746 std::string bitWidthId;
747 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
748 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
749 if constexpr (isPrefetch) { // Prefetch
750 funcName += "prefetch";
751 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
752 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
753 /*other=*/LLVM::ModRefInfo::NoModRef,
754 /*argMem=*/LLVM::ModRefInfo::Ref,
755 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
756 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
757 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
758 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
759 funcAttr = noUnwindAttrs;
760 funcAttr.memEffectsAttr = memAttr;
761 } else {
762 auto vecElemType = vecType.getElementType();
763 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
764 Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
765 vecType.getNumElements());
766 auto dstOrSrcPtr = LLVM::AllocaOp::create(
767 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
768 vecElemType, numElems);
769 args.push_back(dstOrSrcPtr);
770 if constexpr (isLoad) { // Load
771 funcName += "read";
772 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
773 if (packReg)
774 funcName += "_transform";
775 else if (transpose)
776 funcName += "_transpose";
777 spvLoadDstPtr = dstOrSrcPtr;
778 retTypes.push_back(vecType);
779 paramAttrs = {
780 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
781 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
782 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
783 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
784 };
785 } else { // Store
786 funcName += "write";
787 bitWidthId = (vecElemBitWidth == 32)
788 ? "j"
789 : ((vecElemBitWidth == 16) ? "t" : "h");
790 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
791 paramAttrs = {
792 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
793 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
794 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
795 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
796 };
797 }
798 }
799
800 funcName =
801 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
802 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
803 .str();
804 std::string prefetchCode("");
805 if (!isPrefetch)
806 prefetchCode += "P";
807 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
808 funcName, prefetchCode, bitWidthId)
809 .str();
810 SmallVector<Type> argTypes;
811 for (auto arg : args) {
812 argTypes.push_back(arg.getType());
813 }
814 createDeviceFunctionCall(
815 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
816 argTypes, args, paramAttrs, funcAttr, op.getOperation());
817
818 if constexpr (isLoad)
819 rewriter.replaceOp(
820 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
821 else
822 rewriter.eraseOp(op);
823 return success();
824 }
825};
826
827template <typename OpType>
828class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
829 using OpConversionPattern<OpType>::OpConversionPattern;
830 LogicalResult
831 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
832 ConversionPatternRewriter &rewriter) const override {
833 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
834 auto loc = op.getLoc();
835 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
836
837 // Get OpenCL function name
838 // https://registry.khronos.org/OpenCL/extensions/
839 // intel/cl_intel_subgroup_local_block_io.html
840 std::string funcName{"intel_sub_group_block_"};
841 // Value or Result type can be vector or scalar
842 Type valOrResTy;
843 if constexpr (isStore) {
844 funcName += "write_u";
845 valOrResTy = op.getVal().getType();
846 } else {
847 funcName += "read_u";
848 valOrResTy = op.getType();
849 }
850 // Get element type of the vector/scalar
851 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
852 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
853 funcName += getTypeMangling(elemType);
854 if (vecTy)
855 funcName += std::to_string(vecTy.getNumElements());
856 SmallVector<Type, 2> argTypes{};
857 // XeVM BlockLoad/StoreOp always use signless integer types
858 // but OpenCL builtins expect unsigned types
859 // use unsigned types for mangling
860 SmallVector<bool, 2> isUnsigned{};
861 // arg0: pointer to the src/dst address
862 // arg1 - only if store : vector to store
863 // Prepare arguments
864 SmallVector<Value, 2> args{};
865 args.push_back(op.getPtr());
866 argTypes.push_back(op.getPtr().getType());
867 isUnsigned.push_back(true);
868
869 // Annotate pointer (args[0]) with cache control.
870 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
871 /*ptrIdx=*/0);
872 // Update argTypes[0] in case the pointer type changed (it shouldn't
873 // change type, but the value is now the annotated pointer).
874 argTypes[0] = args[0].getType();
875
876 Type retType;
877 if constexpr (isStore) {
878 args.push_back(op.getVal());
879 argTypes.push_back(op.getVal().getType());
880 isUnsigned.push_back(true);
881 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
882 } else {
883 retType = valOrResTy;
884 }
885 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
886 "PU3AS" +
887 std::to_string(op.getPtr().getType().getAddressSpace());
888 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
889 if constexpr (isStore)
890 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
891 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
892
893 LLVM::CallOp call =
894 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
895 {}, funcAttr, op.getOperation());
896
897 if constexpr (isStore)
898 rewriter.eraseOp(op);
899 else
900 rewriter.replaceOp(op, call->getResult(0));
901 return success();
902 }
903};
904
905template <typename OpType>
906class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
907 using OpConversionPattern<OpType>::OpConversionPattern;
908 LogicalResult
909 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
910 ConversionPatternRewriter &rewriter) const override {
911 if (!op->hasAttr("cache_control"))
912 return failure();
913
914 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
915 std::optional<ArrayAttr> optCacheControls =
916 getCacheControlMetadata(rewriter, op);
917 if (!optCacheControls) {
918 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr("cache_control"); });
919 return success();
920 }
921
922 // Determine which operand is the pointer.
923 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
924 unsigned ptrIdx = isStore ? 1 : 0;
925 Value ptr = op->getOperand(ptrIdx);
926
927 // Emit annotation intrinsic calls on the pointer.
928 Value annotatedPtr = annotatePtrWithCacheControl(
929 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
930
931 // Replace the pointer operand with the annotated one.
932 rewriter.modifyOpInPlace(op, [&]() {
933 op->setOperand(ptrIdx, annotatedPtr);
934 op->removeAttr("cache_control");
935 });
936 return success();
937 }
938};
939
940//===----------------------------------------------------------------------===//
941// GPU index id operations
942//===----------------------------------------------------------------------===//
943/*
944// Launch Config ops
945// dimidx - x, y, z - is fixed to i32
946// return type is set by XeVM type converter
947// get_local_id
948xevm::WorkitemIdXOp;
949xevm::WorkitemIdYOp;
950xevm::WorkitemIdZOp;
951// get_local_size
952xevm::WorkgroupDimXOp;
953xevm::WorkgroupDimYOp;
954xevm::WorkgroupDimZOp;
955// get_group_id
956xevm::WorkgroupIdXOp;
957xevm::WorkgroupIdYOp;
958xevm::WorkgroupIdZOp;
959// get_num_groups
960xevm::GridDimXOp;
961xevm::GridDimYOp;
962xevm::GridDimZOp;
963// get_global_id : to be added if needed
964*/
965
966// Helpers to get the OpenCL function name and dimension argument for each op.
967static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
968 return {"get_local_id", 0};
969}
970static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
971 return {"get_local_id", 1};
972}
973static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
974 return {"get_local_id", 2};
975}
976static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
977 return {"get_local_size", 0};
978}
979static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
980 return {"get_local_size", 1};
981}
982static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
983 return {"get_local_size", 2};
984}
985static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
986 return {"get_group_id", 0};
987}
988static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
989 return {"get_group_id", 1};
990}
991static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
992 return {"get_group_id", 2};
993}
994static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
995 return {"get_num_groups", 0};
996}
997static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
998 return {"get_num_groups", 1};
999}
1000static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1001 return {"get_num_groups", 2};
1002}
1003/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
1004/// a constant argument for the dimension - x, y or z.
1005template <typename OpType>
1006class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
1007 using OpConversionPattern<OpType>::OpConversionPattern;
1008 LogicalResult
1009 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1010 ConversionPatternRewriter &rewriter) const override {
1011 Location loc = op->getLoc();
1012 auto [baseName, dim] = getConfig(op);
1013 Type dimTy = rewriter.getI32Type();
1014 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1015 static_cast<int64_t>(dim));
1016 std::string func = mangle(baseName, {dimTy}, {true});
1017 Type resTy = op.getType();
1018 auto call =
1019 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1020 noUnwindWillReturnAttrs, op.getOperation());
1021 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1022 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1023 /*other=*/noModRef,
1024 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1025 /*errnoMem=*/noModRef,
1026 /*targetMem0=*/noModRef,
1027 /*targetMem1=*/noModRef);
1028 call.setMemoryEffectsAttr(memAttr);
1029 rewriter.replaceOp(op, call);
1030 return success();
1031 }
1032};
1033
1034/*
1035// Subgroup ops
1036// get_sub_group_local_id
1037xevm::LaneIdOp;
1038// get_sub_group_id
1039xevm::SubgroupIdOp;
1040// get_sub_group_size
1041xevm::SubgroupSizeOp;
1042// get_num_sub_groups : to be added if needed
1043*/
1044
1045// Helpers to get the OpenCL function name for each op.
1046static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
1047static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
1048static StringRef getConfig(xevm::SubgroupSizeOp) {
1049 return "get_sub_group_size";
1050}
1051template <typename OpType>
1052class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
1053 using OpConversionPattern<OpType>::OpConversionPattern;
1054 LogicalResult
1055 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1056 ConversionPatternRewriter &rewriter) const override {
1057 std::string func = mangle(getConfig(op).str(), {});
1058 Type resTy = op.getType();
1059 auto call =
1060 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1061 noUnwindWillReturnAttrs, op.getOperation());
1062 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1063 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1064 /*other=*/noModRef,
1065 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1066 /*errnoMem=*/noModRef,
1067 /*targetMem0=*/noModRef,
1068 /*targetMem1=*/noModRef);
1069 call.setMemoryEffectsAttr(memAttr);
1070 rewriter.replaceOp(op, call);
1071 return success();
1072 }
1073};
1074
1075class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
1076 using OpConversionPattern::OpConversionPattern;
1077 LogicalResult
1078 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1079 ConversionPatternRewriter &rewriter) const override {
1080 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1081 auto addrSpace = ptrType.getAddressSpace();
1082 if (addrSpace != 3)
1083 return failure();
1084 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1085 if (!symTable)
1086 return failure();
1087 Block *moduleBody;
1088 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1089 moduleBody = mod.getBody();
1090 } else if (gpu::GPUModuleOp gpuMod =
1091 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1092 moduleBody = gpuMod.getBody();
1093 } else {
1094 return failure();
1095 }
1096 auto val = op.getArraySize();
1097 APInt cst;
1098 if (!matchPattern(val, m_ConstantInt(&cst)))
1099 return failure();
1100 auto loc = op.getLoc();
1101 auto globalType = LLVM::LLVMArrayType::get(
1102 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1103 LLVM::GlobalOp globalVar;
1104 {
1105 OpBuilder::InsertionGuard guard(rewriter);
1106 rewriter.setInsertionPointToStart(moduleBody);
1107 auto alignment = op.getAlignment();
1108 globalVar = LLVM::GlobalOp::create(
1109 rewriter, loc, globalType, /*isConstant=*/false,
1110 /*linkage=*/LLVM::Linkage::Internal,
1111 /*name=*/std::string("__global_alloca_") +
1112 std::to_string(getNextGlobalIdx()),
1113 /*value=*/Attribute(),
1114 /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
1115 }
1116 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
1117 return success();
1118 }
1119
1120private:
1121 static unsigned getNextGlobalIdx() {
1122 static unsigned globalIdx = 0;
1123 return globalIdx++;
1124 }
1125};
1126
1127// Checks if shufflevector is used as a way to extract a contiguous slice
1128// from a vector.
1129// - source vector V1 and V2 are the same vector.
1130// - mask size is not greater than the source vector size
1131// - mask values represent a sequence of consecutive increasing numbers
1132// that stay in bounds of the source vector when used for indexing.
1133static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1134 if (op.getV1() != op.getV2())
1135 return false;
1136 auto maskAttr = op.getMask();
1137 int64_t maskSize = static_cast<int64_t>(maskAttr.size());
1138 int64_t sourceSize = op.getV1().getType().getNumElements();
1139 if (maskSize > sourceSize)
1140 return false;
1141 int64_t firstIndex = maskAttr[0];
1142 for (int64_t i = 1; i < maskSize; ++i) {
1143 int64_t index = maskAttr[i];
1144 if (index != firstIndex + i)
1145 return false;
1146 if (index >= sourceSize)
1147 return false;
1148 }
1149 return true;
1150}
1151
1152// Input vector of a shuffle vector op extracting a contiguous slice is an
1153// illegal vector in SPIRV kernel if the vector size is > 16 elements.
1154// To legalize this case, keep applying the following transformations until no
1155// more match:
1156// 1. keep hoisting the shuffle vector op past unary element-wise operations
1157// start with fpext, fptrunc and bitcast for now.
1158// 2. merge with another shuffle vector op
1159// 3. merge with load as a smaller load
1160class HandleVectorExtractPattern
1161 : public OpRewritePattern<LLVM::ShuffleVectorOp> {
1162 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
1163
1164 void initialize() { setHasBoundedRewriteRecursion(); }
1165
1166 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1167 PatternRewriter &rewriter) const override {
1168
1169 if (!isExtractingContiguousSlice(op))
1170 return failure();
1171
1172 auto mask = op.getMask();
1173 auto loc = op.getLoc();
1174 auto ty = op.getType();
1175 // Check source operand to determine rewrite pattern.
1176 auto src = op.getV1();
1177 // 1. Hoist past unary element-wise operations
1178 if (auto srcOp = src.getDefiningOp()) {
1179 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1180 Value srcInput = srcOp->getOperand(0);
1181 // Create new shuffle vector op with unary input as source.
1182 auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
1183 auto newShuffleVecTy =
1184 VectorType::get(mask.size(), srcVecTy.getElementType());
1185 auto newShuffle = LLVM::ShuffleVectorOp::create(
1186 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1187 // Create new unary op with new shuffle as input.
1188 Value newUnaryOp;
1189 if (isa<LLVM::FPExtOp>(srcOp)) {
1190 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1191 } else {
1192 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1193 }
1194 rewriter.replaceOp(op, newUnaryOp);
1195 } else if (isa<LLVM::BitcastOp>(srcOp)) {
1196 Value srcInput = srcOp->getOperand(0);
1197 // Create new shuffle vector op with unary input as source.
1198 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
1199 auto srcInputSize = srcInputVecTy.getNumElements();
1200 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1201 auto srcResSize = srcResVecTy.getNumElements();
1202 auto maskSize = static_cast<int32_t>(mask.size());
1203 if (srcInputSize > srcResSize) {
1204 return failure();
1205 }
1206 if (srcResSize % srcInputSize != 0) {
1207 return failure();
1208 }
1209 auto maskScale = srcResSize / srcInputSize;
1210 if (maskScale != 1) {
1211 if (mask[0] % maskScale != 0) {
1212 return failure();
1213 }
1214 // Create a new mask that maps to the source vector
1215 SmallVector<int32_t> newMask;
1216 int32_t newMaskSize = maskSize / maskScale;
1217 int32_t maskStart = mask[0] / maskScale;
1218 for (int32_t i = 0; i < newMaskSize; ++i) {
1219 newMask.push_back(maskStart + i);
1220 }
1221 mask = newMask;
1222 }
1223 auto newShuffleVecTy =
1224 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1225 auto newShuffle = LLVM::ShuffleVectorOp::create(
1226 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1227 // Create new unary op with new shuffle as input.
1228 auto newBitcast =
1229 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1230 rewriter.replaceOp(op, newBitcast);
1231 } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1232 // 2. Merge with source shuffle vector op if, the source op is
1233 // also extracting a contigous slice and create a new
1234 // shuffle vector op directly from the source of
1235 // the first shuffle.
1236 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1237 if (!isExtractingContiguousSlice(srcShuffle))
1238 return failure();
1239 auto srcMask = srcShuffle.getMask();
1240 SmallVector<int32_t> combinedMask;
1241 for (auto index : mask) {
1242 combinedMask.push_back(srcMask[index]);
1243 }
1244 auto newShuffle = LLVM::ShuffleVectorOp::create(
1245 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1246 DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
1247 rewriter.replaceOp(op, newShuffle);
1248 } else if (isa<LLVM::LoadOp>(srcOp)) {
1249 // 3. Merge with load as a smaller load
1250 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1251 auto loadPtr = loadOp.getAddr();
1252 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1253 auto elemTy = loadTy.getElementType();
1254 auto firstIndex = mask[0];
1255 auto newVecTy = VectorType::get(mask.size(), elemTy);
1256 // GEPOp is needed if first index is not zero
1257 if (firstIndex) {
1258 auto newPtr = LLVM::GEPOp::create(
1259 rewriter, loc,
1260 LLVM::LLVMPointerType::get(rewriter.getContext(),
1261 loadPtr.getType().getAddressSpace()),
1262 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1263 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1264 rewriter.replaceOp(op, newLoad);
1265 } else {
1266 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1267 rewriter.replaceOp(op, newLoad);
1268 }
1269 } else {
1270 return failure();
1271 }
1272 }
1273 return success();
1274 }
1275};
1276
1277//===----------------------------------------------------------------------===//
1278// Pass Definition
1279//===----------------------------------------------------------------------===//
1280
1281struct ConvertXeVMToLLVMPass
1282 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1283 using Base::Base;
1284
1285 void getDependentDialects(DialectRegistry &registry) const override {
1286 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
1287 }
1288
1289 void runOnOperation() override {
1290 ConversionTarget target(getContext());
1291 RewritePatternSet patterns(&getContext());
1293 if (failed(applyPartialConversion(getOperation(), target,
1294 std::move(patterns))))
1295 signalPassFailure();
1296
1297 // Apply in-dialect lowerings to handle illegal vectors
1298 {
1299 RewritePatternSet vectorPatterns(&getContext());
1300 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
1301 GreedyRewriteConfig config{};
1302 // folding can remove ops with temporary attributes used to
1303 // represent LLVM metadata, so disable it here.
1304 // Effectively just this single pattern is applied without any
1305 // op folding patterns from dialects.
1306 config.enableFolding(false);
1307 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1308 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1309 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1310 config);
1311 }
1312 }
1313};
1314} // namespace
1315
1316//===----------------------------------------------------------------------===//
1317// Pattern Population
1318//===----------------------------------------------------------------------===//
1319
1320void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
1321 RewritePatternSet &patterns) {
1322 // some LLVM operations need to be converted.
1323 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
1324 // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
1325 // properly by SPIRV backend. It needs to be rewritten as a sequence with
1326 // llvm global.
1327 if (isa<LLVM::AllocaOp>(op)) {
1328 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1329 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1330 auto addrSpace = pTy.getAddressSpace();
1331 return addrSpace != 3;
1332 }
1333 // cache_control attribute should be converted.
1334 return !op->hasAttr("cache_control");
1335 });
1336 target.addIllegalDialect<XeVMDialect>();
1337 patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1338 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1339 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
1340 MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
1341 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1342 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1343 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1344 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1345 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1346 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1347 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1348 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1349 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1350 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1351 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1352 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1353 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1354 LaunchConfigOpToOCLPattern<GridDimXOp>,
1355 LaunchConfigOpToOCLPattern<GridDimYOp>,
1356 LaunchConfigOpToOCLPattern<GridDimZOp>,
1357 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1358 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1359 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
1360 AllocaToGlobalPattern>(patterns.getContext());
1361}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Block & front()
Definition Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...