MLIR 23.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
20
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Types.h"
25
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace xevm;
35
36namespace {
37
/// Bundle of LLVM function attributes applied to generated device-builtin
/// declarations (and mirrored onto their call sites by
/// createDeviceFunctionCall).
struct LLVMFuncAttributeOptions {
  bool isConvergent = false;
  bool isNoUnwind = false;
  bool isWillReturn = false;
  // Optional memory-effects attribute; only set on the function when non-null.
  LLVM::MemoryEffectsAttr memEffectsAttr{};
};
// Common presets in {convergent, nounwind, willreturn, memEffects} order.
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
    false, true, false, {}};
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
    false, true, true, {}};
static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
    true, true, true, {}};
50
51std::string getTypeMangling(Type ty, bool isUnsigned = false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
56 })
57 .Case([](Float16Type) -> std::string { return "Dh"; })
58 .Case([](Float32Type) -> std::string { return "f"; })
59 .Case([](Float64Type) -> std::string { return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
62 case 8:
63 return isUnsigned ? "h" : "c";
64 case 16:
65 return isUnsigned ? "t" : "s";
66 case 32:
67 return isUnsigned ? "j" : "i";
68 case 64:
69 return isUnsigned ? "m" : "l";
70 default:
71 llvm_unreachable("unhandled integer type");
72 }
73 })
74 .DefaultUnreachable("unhandled type for mangling");
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102static int32_t getL1CacheControl(LoadCacheControl cc) {
103 int32_t control = 0;
104 switch (cc) {
105 case LoadCacheControl::USE_DEFAULT:
106 control = -1;
107 break;
108 case LoadCacheControl::L1C_L2UC_L3UC:
109 case LoadCacheControl::L1C_L2UC_L3C:
110 case LoadCacheControl::L1C_L2C_L3UC:
111 case LoadCacheControl::L1C_L2C_L3C:
112 control = 1;
113 break;
114 case LoadCacheControl::L1S_L2UC_L3UC:
115 case LoadCacheControl::L1S_L2UC_L3C:
116 case LoadCacheControl::L1S_L2C_L3UC:
117 case LoadCacheControl::L1S_L2C_L3C:
118 control = 2;
119 break;
120 case LoadCacheControl::INVALIDATE_READ:
121 control = 3;
122 break;
123 default:
124 break;
125 }
126 return control;
127}
128
129static int32_t getL1CacheControl(StoreCacheControl cc) {
130 int32_t control = 0;
131 switch (cc) {
132 case StoreCacheControl::USE_DEFAULT:
133 control = -1;
134 break;
135 case StoreCacheControl::L1WT_L2UC_L3UC:
136 case StoreCacheControl::L1WT_L2UC_L3WB:
137 case StoreCacheControl::L1WT_L2WB_L3UC:
138 case StoreCacheControl::L1WT_L2WB_L3WB:
139 control = 1;
140 break;
141 case StoreCacheControl::L1WB_L2UC_L3UC:
142 case StoreCacheControl::L1WB_L2WB_L3UC:
143 case StoreCacheControl::L1WB_L2UC_L3WB:
144 control = 2;
145 break;
146 case StoreCacheControl::L1S_L2UC_L3UC:
147 case StoreCacheControl::L1S_L2UC_L3WB:
148 case StoreCacheControl::L1S_L2WB_L3UC:
149 case StoreCacheControl::L1S_L2WB_L3WB:
150 control = 3;
151 break;
152 default:
153 break;
154 }
155 return control;
156}
157
158static int32_t getL3CacheControl(LoadCacheControl cc) {
159 int32_t control = 0;
160 switch (cc) {
161 case LoadCacheControl::USE_DEFAULT:
162 control = -1;
163 break;
164 case LoadCacheControl::L1UC_L2UC_L3C:
165 case LoadCacheControl::L1UC_L2C_L3C:
166 case LoadCacheControl::L1C_L2UC_L3C:
167 case LoadCacheControl::L1C_L2C_L3C:
168 case LoadCacheControl::L1S_L2UC_L3C:
169 case LoadCacheControl::L1S_L2C_L3C:
170 control = 1;
171 break;
172 case LoadCacheControl::INVALIDATE_READ:
173 control = 3;
174 break;
175 default:
176 break;
177 }
178 return control;
179}
180
181static int32_t getL3CacheControl(StoreCacheControl cc) {
182 int32_t control = 0;
183 switch (cc) {
184 case StoreCacheControl::USE_DEFAULT:
185 control = -1;
186 break;
187 case StoreCacheControl::L1UC_L2UC_L3WB:
188 case StoreCacheControl::L1UC_L2WB_L3WB:
189 case StoreCacheControl::L1WT_L2UC_L3WB:
190 case StoreCacheControl::L1WT_L2WB_L3WB:
191 case StoreCacheControl::L1S_L2UC_L3WB:
192 case StoreCacheControl::L1S_L2WB_L3WB:
193 case StoreCacheControl::L1WB_L2UC_L3WB:
194 control = 2;
195 break;
196 default:
197 break;
198 }
199 return control;
200}
201
/// Uniform cache-control accessor (PrefetchOp overload) so templated helpers
/// can query any supported op the same way.
static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
  return op.getCacheControl();
}
205
/// Uniform cache-control accessor (BlockLoad2dOp overload).
static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
  return op.getCacheControl();
}
209
/// Uniform cache-control accessor (BlockLoadOp overload).
static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
  return op.getCacheControl();
}
213
/// Uniform cache-control accessor (BlockPrefetch2dOp overload).
static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
  return op.getCacheControl();
}
217
/// Uniform cache-control accessor (BlockStore2dOp overload).
static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
  return op.getCacheControl();
}
221
/// Uniform cache-control accessor (BlockStoreOp overload).
static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
  return op.getCacheControl();
}
225
226static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
227 if (op->hasAttr("cache_control")) {
228 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
229 if (!attr)
230 return std::nullopt;
231 return std::optional<LoadCacheControl>(attr.getValue());
232 }
233 return std::nullopt;
234}
235
236static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
237 if (op->hasAttr("cache_control")) {
238 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
239 if (!attr)
240 return std::nullopt;
241 return std::optional<StoreCacheControl>(attr.getValue());
242 }
243 return std::nullopt;
244}
245
/// L1 control code for `op`'s cache-control attribute. The attribute must be
/// present — callers check getCacheControl(op) first (the optional is
/// dereferenced unconditionally here).
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  return getL1CacheControl(*getCacheControl(op));
}
250
/// L3 control code for `op`'s cache-control attribute. The attribute must be
/// present — callers check getCacheControl(op) first (the optional is
/// dereferenced unconditionally here).
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  return getL3CacheControl(*getCacheControl(op));
}
255
256template <typename OpType>
257static std::optional<ArrayAttr>
258getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
259 if (!getCacheControl(op))
260 return {};
261
262 constexpr int32_t decorationCacheControlArity{3};
263 constexpr int32_t loadCacheControlKey{6442};
264 constexpr int32_t storeCacheControlKey{6443};
265 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
266 std::is_same_v<OpType, BlockPrefetch2dOp> ||
267 std::is_same_v<OpType, LLVM::LoadOp> ||
268 std::is_same_v<OpType, BlockLoadOp> ||
269 std::is_same_v<OpType, PrefetchOp>;
270
271 // If the cache control is USE_DEFAULT, then we don’t emit any metadata.
272 // Assert that if one of the L1 or L3 cache control values is USE_DEFAULT
273 // (represented as -1), then both must be USE_DEFAULT; otherwise there is a
274 // bug.
275 assert(((getL1CacheControl<OpType>(op) == -1) ==
276 (getL3CacheControl<OpType>(op) == -1)) &&
277 "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
278 "USE_DEFAULT");
279
280 if (getL1CacheControl<OpType>(op) == -1 &&
281 getL3CacheControl<OpType>(op) == -1)
282 return {};
283 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
285 controlKey, 0, getL1CacheControl<OpType>(op)};
287 controlKey, 1, getL3CacheControl<OpType>(op)};
288 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
289 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
290
291 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
292 return rewriter.getArrayAttr(combinedAttrs);
293}
294
295//===----------------------------------------------------------------------===//
296// Cache control annotation utilities
297//
298// Instead of attaching cache control as MLIR attributes and handling them
299// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
300// MLIR.
301//===----------------------------------------------------------------------===//
302
303/// Build one cache-control payload string per attribute.
304///
305/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
306/// [SPIR-V decoration token, cache level, cache control value]
307///
308/// A single entry produces a string like: {6442:"0,1"}
309/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
310/// form.
312buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
314 llvm::StringMap<bool> seen;
315
316 for (Attribute a : attrs) {
317 auto arr = dyn_cast<ArrayAttr>(a);
318 if (!arr)
319 continue;
320
321 auto vals = arr.getValue();
322 assert(vals.size() == 3 &&
323 "Expected exactly 3 integer values (Token, CacheLevel, "
324 "ControlValue) in cache control attribute.");
325
326 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
327 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
328 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
329
330 if (!tokenAttr || !secondAttr || !thirdAttr)
331 continue;
332
333 // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
334 // The quote char (0x22) is embedded literally; LLVM IR prints it as \22.
335 std::string entry =
336 llvm::formatv("{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
337 secondAttr.getValue().getZExtValue(),
338 thirdAttr.getValue().getZExtValue());
339
340 // Deduplicate identical annotations.
341 if (!seen.insert({entry, true}).second)
342 continue;
343
344 payloads.push_back(std::move(entry));
345 }
346 return payloads;
347}
348/// Counter for generating unique global variable names.
349static std::atomic<uint64_t> globalNameCounter{0};
350
351/// Get or create a global metadata string and return a !llvm.ptr<1> value
352/// pointing to it. The AddressOfOp is created at the current rewriter
353/// insertion point; the GlobalOp is created at the module start.
static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
                                     Operation *moduleOp, Location loc,
                                     StringRef value, StringRef nameHint) {
  // Build null-terminated string.
  std::string strWithNull = value.str();
  strWithNull.push_back('\0');
  StringRef strRef(strWithNull.data(), strWithNull.size());

  // Metadata globals live in address space 1 (see setAddrSpace below).
  auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);

  // Search for an existing global with the same content.
  // Only globals in the "llvm.metadata" section are candidates, so unrelated
  // constants with identical bytes are never reused.
  for (auto &op : moduleOp->getRegion(0).front()) {
    if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
      if (!existingGlobal.getSection() ||
          *existingGlobal.getSection() != "llvm.metadata")
        continue;
      if (auto strAttr =
              dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
        if (strAttr.getValue() == strRef) {
          // Reuse: emit only an AddressOf at the current insertion point.
          return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
                                           existingGlobal.getSymName());
        }
      }
    }
  }

  // Create new global at module start.
  auto i8Type = rewriter.getI8Type();
  auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
  // globalNameCounter guarantees a unique symbol name per created global.
  std::string globalName =
      llvm::formatv("{0}.{1}", nameHint,
                    globalNameCounter.fetch_add(1, std::memory_order_relaxed))
          .str();

  {
    OpBuilder::InsertionGuard guard(rewriter);
    rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());

    auto globalOp =
        LLVM::GlobalOp::create(rewriter, loc, arrayType,
                               /*isConstant=*/true, LLVM::Linkage::Private,
                               globalName, rewriter.getStringAttr(strRef));
    globalOp.setSection(StringRef("llvm.metadata"));
    globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
    globalOp.setAlignment(1);
    globalOp.setAddrSpace(1);
  }
  // InsertionGuard restores the original insertion point here.

  return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
}
405
406/// Annotate a pointer value with cache control metadata by emitting chained
407/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
408///
409/// This is the MLIR-level equivalent of handleDecorationCacheControl() from
410/// the LLVM translation layer. For each cache control attribute, it emits:
411///
412/// %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
413/// @".str.file.N", 0, null : !llvm.ptr<AS>
414///
415/// Multiple annotations are chained: the result of each annotation op is
416/// fed as the pointer input to the next one.
417///
418/// \param rewriter The pattern rewriter.
419/// \param loc Source location for created ops.
420/// \param ptr The pointer value to annotate.
421/// \param cacheControls The cache control ArrayAttr (from
422/// getCacheControlMetadata).
423/// \param moduleOp The enclosing module (for creating globals).
424/// \returns The annotated pointer value (or the original ptr if no
425/// annotations).
426static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
427 Location loc, Value ptr,
428 ArrayAttr cacheControls,
429 Operation *moduleOp) {
430 SmallVector<std::string> payloads =
431 buildCacheControlPayloads(cacheControls.getValue());
432 if (payloads.empty())
433 return ptr;
434
435 auto ptrType = cast<LLVM::LLVMPointerType>(ptr.getType());
436 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
437 auto i32Ty = rewriter.getI32Type();
438
439 // Create shared constants for all annotations on this pointer.
440 Value fileStr =
441 createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
442 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
443 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
444
445 // Chain: each annotation takes the result of the previous one as its
446 // pointer operand.
447 Value curPtr = ptr;
448 for (const std::string &payload : payloads) {
449 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
450 ".str.cachecontrol");
451 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
452 annStr, fileStr, lineVal, nullAS1);
453 curPtr = annOp.getResult();
454 }
455
456 return curPtr;
457}
458
459/// Helper to apply cache control annotation on a pointer operand of a call.
460/// Replaces the pointer argument of the call with an annotated version.
461///
462/// For operations that produce a call (like block load/store/prefetch), the
463/// pointer is typically the first argument. This function:
464/// 1. Builds the annotation chain on the pointer.
465/// 2. Replaces the pointer operand in the provided args list.
466///
467/// \param rewriter The pattern rewriter.
468/// \param loc Source location.
469/// \param ptr The original pointer value (first arg to the call).
470/// \param cacheControls The cache control metadata.
471/// \param moduleOp The enclosing module.
472/// \param args The argument list (modified in place: args[ptrIdx] is
473/// replaced).
474/// \param ptrIdx Index of the pointer in the args list (default 0).
475template <typename OpType>
476static void
477applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
478 OpType op, SmallVectorImpl<Value> &args,
479 Operation *moduleOp, unsigned ptrIdx = 0) {
480 std::optional<ArrayAttr> optCacheControls =
481 getCacheControlMetadata(rewriter, op);
482 if (!optCacheControls)
483 return;
484
485 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
486 *optCacheControls, moduleOp);
487 args[ptrIdx] = annotatedPtr;
488}
489
490//===----------------------------------------------------------------------===//
491// End cache control annotation utilities
492//===----------------------------------------------------------------------===//
493
/// Look up or declare the device builtin `funcName` and emit a call to it.
/// The declaration gets the SPIR_FUNC calling convention, the
/// convergent/nounwind/willreturn/memory-effects attributes requested in
/// `funcAttributeOptions`, and a unit attribute on each (index, name) pair in
/// `paramAttrs`. The function's attributes are then copied onto the call op.
static LLVM::CallOp createDeviceFunctionCall(
    ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
    ArrayRef<Type> argTypes, ArrayRef<Value> args,
    mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
    LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
  auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
  assert(moduleOp && "Expecting module");
  Location loc = op->getLoc();

  auto funcOpRes =
      LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
  assert(!failed(funcOpRes));
  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
  funcOp.setConvergent(funcAttributeOptions.isConvergent);
  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);

  if (funcAttributeOptions.memEffectsAttr)
    funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);

  for (auto [idx, attrName] : paramAttrs)
    funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());

  auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
  // Mirror the declaration's attributes on the call site.
  callOp->setAttrs(funcOp->getAttrs());

  return callOp;
}
523
524static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
525 switch (pTy) {
526 case xevm::ElemType::F32:
527 case xevm::ElemType::TF32:
528 return 1;
529 case xevm::ElemType::BF16:
530 case xevm::ElemType::F16:
531 return 2;
532 case xevm::ElemType::U8:
533 case xevm::ElemType::S8:
534 case xevm::ElemType::BF8:
535 case xevm::ElemType::F8:
536 return 4;
537 case xevm::ElemType::E2M1:
538 case xevm::ElemType::U4:
539 case xevm::ElemType::S4:
540 return 8;
541 default:
542 llvm_unreachable("unsupported xevm::ElemType");
543 }
544}
545
/// Lower xevm.mma to a call to the matching
/// intel_sub_group_<a>_<b>_matrix_mad_k<K> OpenCL builtin.
class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // The OCL builtin always takes an accumulator, so C is mandatory here.
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    auto precisionA = op.getTypes().getA();
    auto precisionB = op.getTypes().getB();
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    // Restrictions below mirror what the OCL builtin family supports.
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }
    if (precisionC != xevm::ElemType::S32 &&
        precisionC != xevm::ElemType::F32 &&
        precisionC != xevm::ElemType::F16 &&
        precisionC != xevm::ElemType::BF16) {
      return rewriter.notifyMatchFailure(
          op, "type of C and D must be S32, F32, F16 or BF16");
    }
    if (precisionA == xevm::ElemType::S32 ||
        precisionA == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
    }
    if (precisionB == xevm::ElemType::S32 ||
        precisionB == xevm::ElemType::F32) {
      return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
    }
    // A operands are packed into 16-bit lanes, B operands into 32-bit lanes
    // (except TF32, which stays f32 — see packedAType/packedBType below).
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();

    // Bitcast `val` to a vector of `packedType` with the same total bit
    // width, if it is not already of that type.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
      return val;
    };

    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);

    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);

    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);

    // K = systolic depth * elements packed per dword of A.
    constexpr int32_t systolicDepth{8};
    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
                      stringifyElemType(op.getTypes().getA()).str(),
                      stringifyElemType(op.getTypes().getB()).str(),
                      systolicDepth *
                          getNumOperandsPerDword(op.getTypes().getA()))
            .str();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c};

    // The builtin is a pure computation: no memory is read or written.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);

    // Cast back to bf16 if the accumulator was bitcast to i16 above.
    if (resOrigTy != resTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);

    rewriter.replaceOp(op, result);
    return success();
  }
};
650
651class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
652 using OpConversionPattern::OpConversionPattern;
653 LogicalResult
654 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
655 ConversionPatternRewriter &rewriter) const override {
656 auto loc = op.getLoc();
657 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
658
659 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
660 Value one =
661 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
662 SmallVector<Value> args{op.getPtr(), one};
663
664 // Annotate pointer with cache control before passing to the call.
665 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
666 /*ptrIdx=*/0);
667
668 SmallVector<Type> argTypes;
669 for (auto arg : args)
670 argTypes.push_back(arg.getType());
671 auto funcAttr = noUnwindAttrs;
672 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
673 /*other=*/LLVM::ModRefInfo::NoModRef,
674 /*argMem=*/LLVM::ModRefInfo::Ref,
675 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
676 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
677 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
678 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
679 funcAttr.memEffectsAttr = memAttr;
680
681 createDeviceFunctionCall(rewriter, fnName,
682 LLVM::LLVMVoidType::get(rewriter.getContext()),
683 argTypes, args, {}, funcAttr, op.getOperation());
684 rewriter.eraseOp(op);
685 return success();
686 }
687};
688
689class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
690 using OpConversionPattern::OpConversionPattern;
691 LogicalResult
692 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
693 ConversionPatternRewriter &rewriter) const override {
694 auto loc = op.getLoc();
695 const std::string fnName{"atomic_work_item_fence"};
696 int memScope, addrSpace;
697 switch (op.getAddrspace()) {
698 case xevm::AddrSpace::SHARED:
699 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
700 break;
701 case xevm::AddrSpace::GLOBAL:
702 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
703 break;
704 default:
705 // GENERIC is not supported in OpenCL
706 return rewriter.notifyMatchFailure(
707 op, "Fence only supports global and shared address spaces.");
708 }
709 switch (op.getScope()) {
710 case xevm::MemScope::WORKGROUP:
711 memScope = 1;
712 break;
713 case xevm::MemScope::DEVICE:
714 memScope = 2;
715 break;
716 default:
717 // CLUSTER and SYSTEM are not supported in OpenCL
718 return rewriter.notifyMatchFailure(
719 op, "Fence only supports workgroup and device memory scopes.");
720 }
721 Type i32Type = rewriter.getI32Type();
722 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
723 Value memScopeConst =
724 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
725 Value addrSpaceConst =
726 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
727 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
728 SmallVector<Type> argTypes{3, i32Type};
729 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
730 LLVM::LLVMVoidType::get(rewriter.getContext()),
731 argTypes, args, {}, noUnwindAttrs,
732 op.getOperation());
733 rewriter.eraseOp(op);
734 return success();
735 }
736};
/// Lower xevm 2D block load/store/prefetch ops to the OpenCL
/// intel_sub_group_2d_block_{read,write,prefetch} builtins. Loads and stores
/// pass data through a stack slot (the builtin's last pointer argument).
template <typename OpType>
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
    constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;

    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
    VectorType vecType;
    bool packReg = false;
    bool transpose = false;
    if constexpr (isLoad) {
      vecType = op.getRes().getType();
      packReg = op.getPackRegister();
      transpose = op.getTranspose();
    } else if constexpr (!isPrefetch) {
      // Store: the stored value determines the vector type.
      vecType = op.getStoredVal().getType();
    }

    // Pack the (x, y) coordinate into the <2 x i32> operand the builtin
    // expects.
    auto i32Type = rewriter.getI32Type();
    Value byteCoord =
        LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
    Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
    Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
    byteCoord = LLVM::InsertElementOp::create(
        rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
    SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
                            op.getBasePitch(), byteCoord};

    // Annotate pointer (args[0]) with cache control before the call.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);

    SmallVector<Type> retTypes;
    Value spvLoadDstPtr;
    std::string funcName{"intel_sub_group_2d_block_"};
    std::string bitWidthId;
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
    SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
    if constexpr (isPrefetch) { // Prefetch
      funcName += "prefetch";
      // Only the base pointer argument; it is read-only memory-wise.
      paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
      auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
          /*other=*/LLVM::ModRefInfo::NoModRef,
          /*argMem=*/LLVM::ModRefInfo::Ref,
          /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
          /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
          /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
      funcAttr = noUnwindAttrs;
      funcAttr.memEffectsAttr = memAttr;
    } else {
      // Load/store exchange data through an alloca passed as the builtin's
      // trailing pointer argument.
      auto vecElemType = vecType.getElementType();
      auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
      Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
                                                vecType.getNumElements());
      auto dstOrSrcPtr = LLVM::AllocaOp::create(
          rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
          vecElemType, numElems);
      args.push_back(dstOrSrcPtr);
      if constexpr (isLoad) { // Load
        funcName += "read";
        bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
        if (packReg)
          funcName += "_transform";
        else if (transpose)
          funcName += "_transpose";
        spvLoadDstPtr = dstOrSrcPtr;
        retTypes.push_back(vecType);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
        };
      } else { // Store
        funcName += "write";
        // Unsigned integer manglings by element width: j=u32, t=u16, h=u8.
        bitWidthId = (vecElemBitWidth == 32)
                         ? "j"
                         : ((vecElemBitWidth == 16) ? "t" : "h");
        LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
        paramAttrs = {
            std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
            std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
        };
      }
    }

    // Builtin name encodes element size, tile shape and vblocks, then the
    // whole thing is wrapped in a pre-computed Itanium mangling.
    funcName =
        llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
                      op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
            .str();
    std::string prefetchCode("");
    if (!isPrefetch)
      prefetchCode += "P";
    funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
                             funcName, prefetchCode, bitWidthId)
                   .str();
    SmallVector<Type> argTypes;
    for (auto arg : args) {
      argTypes.push_back(arg.getType());
    }
    createDeviceFunctionCall(
        rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
        argTypes, args, paramAttrs, funcAttr, op.getOperation());

    // Loads read the result back from the scratch slot; store/prefetch have
    // no results.
    if constexpr (isLoad)
      rewriter.replaceOp(
          op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
    else
      rewriter.eraseOp(op);
    return success();
  }
};
858
/// Lower xevm 1D block load/store ops to the OpenCL
/// intel_sub_group_block_{read,write}_u* builtins.
template <typename OpType>
class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
    auto loc = op.getLoc();
    auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();

    // Get OpenCL function name
    // https://registry.khronos.org/OpenCL/extensions/
    // intel/cl_intel_subgroup_local_block_io.html
    std::string funcName{"intel_sub_group_block_"};
    // Value or Result type can be vector or scalar
    Type valOrResTy;
    if constexpr (isStore) {
      funcName += "write_u";
      valOrResTy = op.getVal().getType();
    } else {
      funcName += "read_u";
      valOrResTy = op.getType();
    }
    // Get element type of the vector/scalar
    VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
    Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
    funcName += getTypeMangling(elemType);
    if (vecTy)
      funcName += std::to_string(vecTy.getNumElements());
    SmallVector<Type, 2> argTypes{};
    // XeVM BlockLoad/StoreOp always use signless integer types
    // but OpenCL builtins expect unsigned types
    // use unsigned types for mangling
    SmallVector<bool, 2> isUnsigned{};
    // arg0: pointer to the src/dst address
    // arg1 - only if store : vector to store
    // Prepare arguments
    SmallVector<Value, 2> args{};
    args.push_back(op.getPtr());
    argTypes.push_back(op.getPtr().getType());
    isUnsigned.push_back(true);

    // Annotate pointer (args[0]) with cache control.
    applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
                                /*ptrIdx=*/0);
    // Update argTypes[0] in case the pointer type changed (it shouldn't
    // change type, but the value is now the annotated pointer).
    argTypes[0] = args[0].getType();

    Type retType;
    if constexpr (isStore) {
      args.push_back(op.getVal());
      argTypes.push_back(op.getVal().getType());
      isUnsigned.push_back(true);
      retType = LLVM::LLVMVoidType::get(rewriter.getContext());
    } else {
      retType = valOrResTy;
    }
    // Mangling is assembled by hand: "_Z<len><name>PU3AS<n>" for the pointer
    // argument, then the unsigned element (and, for stores, value) mangling.
    funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
               "PU3AS" +
               std::to_string(op.getPtr().getType().getAddressSpace());
    funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
    if constexpr (isStore)
      funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
    LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};

    LLVM::CallOp call =
        createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
                                 {}, funcAttr, op.getOperation());

    // Stores produce no result; loads are replaced by the call's result.
    if constexpr (isStore)
      rewriter.eraseOp(op);
    else
      rewriter.replaceOp(op, call->getResult(0));
    return success();
  }
};
936
937template <typename OpType>
938class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
939 using OpConversionPattern<OpType>::OpConversionPattern;
940 LogicalResult
941 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
942 ConversionPatternRewriter &rewriter) const override {
943 if (!op->hasAttr("cache_control"))
944 return failure();
945
946 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
947 std::optional<ArrayAttr> optCacheControls =
948 getCacheControlMetadata(rewriter, op);
949 if (!optCacheControls) {
950 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr("cache_control"); });
951 return success();
952 }
953
954 // Determine which operand is the pointer.
955 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
956 unsigned ptrIdx = isStore ? 1 : 0;
957 Value ptr = op->getOperand(ptrIdx);
958
959 // Emit annotation intrinsic calls on the pointer.
960 Value annotatedPtr = annotatePtrWithCacheControl(
961 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
962
963 // Replace the pointer operand with the annotated one.
964 rewriter.modifyOpInPlace(op, [&]() {
965 op->setOperand(ptrIdx, annotatedPtr);
966 op->removeAttr("cache_control");
967 });
968 return success();
969 }
970};
971
972//===----------------------------------------------------------------------===//
973// GPU index id operations
974//===----------------------------------------------------------------------===//
975/*
976// Launch Config ops
977// dimidx - x, y, z - is fixed to i32
978// return type is set by XeVM type converter
979// get_local_id
980xevm::WorkitemIdXOp;
981xevm::WorkitemIdYOp;
982xevm::WorkitemIdZOp;
983// get_local_size
984xevm::WorkgroupDimXOp;
985xevm::WorkgroupDimYOp;
986xevm::WorkgroupDimZOp;
987// get_group_id
988xevm::WorkgroupIdXOp;
989xevm::WorkgroupIdYOp;
990xevm::WorkgroupIdZOp;
991// get_num_groups
992xevm::GridDimXOp;
993xevm::GridDimYOp;
994xevm::GridDimZOp;
995// get_global_id : to be added if needed
996*/
997
998// Helpers to get the OpenCL function name and dimension argument for each op.
999static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
1000 return {"get_local_id", 0};
1001}
1002static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
1003 return {"get_local_id", 1};
1004}
1005static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
1006 return {"get_local_id", 2};
1007}
1008static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
1009 return {"get_local_size", 0};
1010}
1011static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
1012 return {"get_local_size", 1};
1013}
1014static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
1015 return {"get_local_size", 2};
1016}
1017static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
1018 return {"get_group_id", 0};
1019}
1020static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
1021 return {"get_group_id", 1};
1022}
1023static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
1024 return {"get_group_id", 2};
1025}
1026static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
1027 return {"get_num_groups", 0};
1028}
1029static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
1030 return {"get_num_groups", 1};
1031}
1032static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1033 return {"get_num_groups", 2};
1034}
/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
/// a constant argument for the dimension - x, y or z.
template <typename OpType>
class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Location loc = op->getLoc();
    // Builtin name and dimension index (0/1/2) selected per op type.
    auto [baseName, dim] = getConfig(op);
    // The dimension argument is fixed to i32 (see comment block above).
    Type dimTy = rewriter.getI32Type();
    Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
                                            static_cast<int64_t>(dim));
    // The `true` flag presumably marks the i32 parameter unsigned for
    // mangling, matching `getTypeMangling`'s isUnsigned — TODO confirm.
    std::string func = mangle(baseName, {dimTy}, {true});
    Type resTy = op.getType();
    auto call =
        createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
                                 noUnwindWillReturnAttrs, op.getOperation());
    // Mark the call NoModRef for every memory location so LLVM can freely
    // CSE/hoist it.
    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/noModRef,
        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
        /*errnoMem=*/noModRef,
        /*targetMem0=*/noModRef,
        /*targetMem1=*/noModRef);
    call.setMemoryEffectsAttr(memAttr);
    rewriter.replaceOp(op, call);
    return success();
  }
};
1065
1066/*
1067// Subgroup ops
1068// get_sub_group_local_id
1069xevm::LaneIdOp;
1070// get_sub_group_id
1071xevm::SubgroupIdOp;
1072// get_sub_group_size
1073xevm::SubgroupSizeOp;
1074// get_num_sub_groups : to be added if needed
1075*/
1076
1077// Helpers to get the OpenCL function name for each op.
1078static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
1079static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
1080static StringRef getConfig(xevm::SubgroupSizeOp) {
1081 return "get_sub_group_size";
1082}
/// Replace subgroup query ops (`xevm.lane_id`, `xevm.subgroup_id`,
/// `xevm.subgroup_size`) with a call to the matching zero-argument OpenCL
/// builtin.
template <typename OpType>
class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Mangled name of the parameterless builtin for this op type.
    std::string func = mangle(getConfig(op).str(), {});
    Type resTy = op.getType();
    auto call =
        createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
                                 noUnwindWillReturnAttrs, op.getOperation());
    // Mark the call NoModRef for every memory location so LLVM can freely
    // reorder/deduplicate it.
    constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/noModRef,
        /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
        /*errnoMem=*/noModRef,
        /*targetMem0=*/noModRef,
        /*targetMem1=*/noModRef);
    call.setMemoryEffectsAttr(memAttr);
    rewriter.replaceOp(op, call);
    return success();
  }
};
1106
/// Lower `xevm.truncf` (currently only 16-element F16 -> BF8 vectors) by
/// discarding mantissa bytes with plain vector shuffles; no builtin call is
/// emitted for this case.
class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    // Supported source and result types are restricted for now.
    auto srcEtype = op.getSrcEtype().getEtype();
    auto dstEtype = op.getDstEtype().getEtype();
    if (auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType())) {
      if (vecSrcTy.getNumElements() != 16)
        return rewriter.notifyMatchFailure(
            op, "Only vector src of 16 elements is supported");
    } else {
      return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
    }
    if (auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType())) {
      if (vecDstTy.getNumElements() != 16)
        return rewriter.notifyMatchFailure(
            op, "Only vector dst of 16 elements is supported");
    } else {
      return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
    }
    if (srcEtype == TruncfSrcElemTypes::F16 &&
        dstEtype == TruncfDstElemTypes::BF8) {
      // BF8 is just F16 with the lower 8 bits of mantissa discarded.
      //      Sign bit  Exponent  Mantissa
      // BF8      1        5        2
      // F16      1        5        10
      // Xe arch is Little Endian so BF8 is just the second byte of the two
      // byte representation used for F16.
      // Split the 16-element source into two 8-element halves so that the
      // intermediate byte vectors stay at 16 elements.
      auto firstHalf =
          LLVM::ShuffleVectorOp::create(rewriter, op.getLoc(), op.getSrc(),
                                        op.getSrc(), {0, 1, 2, 3, 4, 5, 6, 7});
      auto secondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), op.getSrc(), op.getSrc(),
          {8, 9, 10, 11, 12, 13, 14, 15});
      // Reinterpret each f16 half as 16 bytes.
      auto firstHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          firstHalf);
      auto secondHalfCasted = LLVM::BitcastOp::create(
          rewriter, op.getLoc(), VectorType::get(16, rewriter.getI8Type()),
          secondHalf);
      // Gather just the second bytes from every two byte F16 values
      auto resFirstHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), firstHalfCasted, firstHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      auto resSecondHalf = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), secondHalfCasted, secondHalfCasted,
          {1, 3, 5, 7, 9, 11, 13, 15});
      // Concatenate the two 8-byte halves into the 16-element result.
      auto res = LLVM::ShuffleVectorOp::create(
          rewriter, op.getLoc(), resFirstHalf, resSecondHalf,
          {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15});
      rewriter.replaceOp(op, res);
    } else {
      return rewriter.notifyMatchFailure(
          op, "Unsupported src, dst element type pair.");
    }
    return success();
  }
};
1167
/// Lower `xevm.mma.mx` (scaled matrix multiply-add) to a call to the
/// corresponding OpenCL `intel_sub_group_*_scaled_matrix_mad_k*` builtin,
/// bitcasting operands into the packed layouts the builtin expects.
class MMAMxToOCLPattern : public OpConversionPattern<MMAMxOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    if (!op.getC()) {
      return rewriter.notifyMatchFailure(op, "OCL requires C operand");
    }
    auto precisionC = op.getTypes().getC();
    auto precisionD = op.getTypes().getD();
    if (precisionC != precisionD) {
      return rewriter.notifyMatchFailure(op, "type of C and D need to match");
    }

    // A is repacked into 16-bit lanes and B into 32-bit lanes (except TF32,
    // which stays f32 — see packedAType/packedBType below).
    constexpr uint32_t bitWidthPackedA{16};
    constexpr uint32_t bitWidthPackedB{32};
    auto loc = op.getLoc();

    // Bitcast `val` to a vector of `packedType` preserving the total bit
    // width; no-op if it already has that type.
    auto castIfNeeded = [&](Value val, Type packedType) -> Value {
      VectorType origTy = cast<VectorType>(val.getType());
      const uint32_t vecBitSize =
          origTy.getNumElements() *
          origTy.getElementType().getIntOrFloatBitWidth();
      VectorType newTy = VectorType::get(
          vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
      if (origTy != newTy)
        val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
      return val;
    };

    Value a = op.getA();
    Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedA);
    a = castIfNeeded(a, packedAType);

    Value b = op.getB();
    Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
                           ? cast<Type>(rewriter.getF32Type())
                           : rewriter.getIntegerType(bitWidthPackedB);
    b = castIfNeeded(b, packedBType);

    Value c = op.getC();
    VectorType cOrigTy = cast<VectorType>(c.getType());
    VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
    assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
    // OCL builtins encode bfloat16 as int16
    VectorType cTy =
        cOrigTy.getElementType().isBF16()
            ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
            : cOrigTy;
    VectorType resTy = cTy;
    if (cOrigTy != cTy)
      c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);

    // The builtin's K dimension is the systolic depth times the number of A
    // operands packed per dword.
    constexpr int32_t systolicDepth{8};
    std::string fnName =
        llvm::formatv("intel_sub_group_{0}_{1}_scaled_matrix_mad_k{2}_{3}",
                      stringifyElemType(op.getTypes().getA()).str(),
                      stringifyElemType(op.getTypes().getB()).str(),
                      systolicDepth *
                          getNumOperandsPerDword(op.getTypes().getA()),
                      stringifyElemType(op.getTypes().getC()).str())
            .str();
    auto scaleA = op.getScaleA();
    auto scaleB = op.getScaleB();
    SmallVector<Type> argTypes{a.getType(), b.getType(), cTy, scaleA.getType(),
                               scaleB.getType()};
    fnName = mangle(fnName, argTypes);
    SmallVector<Value> args{a, b, c, scaleA, scaleB};

    // Pure computation: mark the call NoModRef for all memory locations.
    auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
        /*other=*/LLVM::ModRefInfo::NoModRef,
        /*argMem=*/LLVM::ModRefInfo::NoModRef,
        /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
        /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
        /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
    auto funcAttrs = convergentNoUnwindWillReturnAttrs;
    funcAttrs.memEffectsAttr = memAttr;
    Value result =
        createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
                                 funcAttrs, op.getOperation())
            ->getResult(0);

    // Cast back if the accumulator/result was re-encoded (bf16 as i16).
    if (resOrigTy != resTy)
      result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);

    rewriter.replaceOp(op, result);
    return success();
  }
};
1260
/// Rewrite constant-sized `llvm.alloca` ops in address space 3 into a
/// module-level internal `llvm.mlir.global` plus `llvm.mlir.addressof`;
/// such allocas are marked illegal in
/// populateXeVMToLLVMConversionPatterns because the SPIRV backend does not
/// handle them properly.
class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
    auto addrSpace = ptrType.getAddressSpace();
    // Only address space 3 allocas are rewritten.
    if (addrSpace != 3)
      return failure();
    auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
    if (!symTable)
      return failure();
    // The global must be created in the body of the enclosing module,
    // either a builtin.module or a gpu.module.
    Block *moduleBody;
    if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
      moduleBody = mod.getBody();
    } else if (gpu::GPUModuleOp gpuMod =
                   dyn_cast<gpu::GPUModuleOp>(*symTable)) {
      moduleBody = gpuMod.getBody();
    } else {
      return failure();
    }
    // Only constant-sized allocas can become globals.
    auto val = op.getArraySize();
    APInt cst;
    if (!matchPattern(val, m_ConstantInt(&cst)))
      return failure();
    auto loc = op.getLoc();
    auto globalType = LLVM::LLVMArrayType::get(
        rewriter.getContext(), op.getElemType(), cst.getZExtValue());
    LLVM::GlobalOp globalVar;
    {
      // Temporarily hop to the module start to create the global, then the
      // guard restores the insertion point for the addressof replacement.
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.setInsertionPointToStart(moduleBody);
      auto alignment = op.getAlignment();
      globalVar = LLVM::GlobalOp::create(
          rewriter, loc, globalType, /*isConstant=*/false,
          /*linkage=*/LLVM::Linkage::Internal,
          /*name=*/std::string("__global_alloca_") +
              std::to_string(getNextGlobalIdx()),
          /*value=*/Attribute(),
          /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
    }
    rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
    return success();
  }

private:
  // Monotonic counter giving each generated global a unique name.
  // NOTE(review): the unsynchronized function-local static is not
  // thread-safe if this pattern ever runs concurrently from multiple
  // threads — confirm the conversion is applied serially per symbol table.
  static unsigned getNextGlobalIdx() {
    static unsigned globalIdx = 0;
    return globalIdx++;
  }
};
1312
1313// Checks if shufflevector is used as a way to extract a contiguous slice
1314// from a vector.
1315// - source vector V1 and V2 are the same vector.
1316// - mask size is not greater than the source vector size
1317// - mask values represent a sequence of consecutive increasing numbers
1318// that stay in bounds of the source vector when used for indexing.
1319static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1320 if (op.getV1() != op.getV2())
1321 return false;
1322 auto maskAttr = op.getMask();
1323 int64_t maskSize = static_cast<int64_t>(maskAttr.size());
1324 int64_t sourceSize = op.getV1().getType().getNumElements();
1325 if (maskSize > sourceSize)
1326 return false;
1327 int64_t firstIndex = maskAttr[0];
1328 for (int64_t i = 1; i < maskSize; ++i) {
1329 int64_t index = maskAttr[i];
1330 if (index != firstIndex + i)
1331 return false;
1332 if (index >= sourceSize)
1333 return false;
1334 }
1335 return true;
1336}
1337
// Input vector of a shuffle vector op extracting a contiguous slice is an
// illegal vector in SPIRV kernel if the vector size is > 16 elements.
// To legalize this case, keep applying the following transformations until no
// more match:
// 1. keep hoisting the shuffle vector op past unary element-wise operations
//    start with fpext, fptrunc and bitcast for now.
// 2. merge with another shuffle vector op
// 3. merge with load as a smaller load
class HandleVectorExtractPattern
    : public OpRewritePattern<LLVM::ShuffleVectorOp> {
  using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;

  // The pattern re-creates shuffle ops it may match again; bound the
  // recursion so the greedy driver terminates.
  void initialize() { setHasBoundedRewriteRecursion(); }

  LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
                                PatternRewriter &rewriter) const override {

    if (!isExtractingContiguousSlice(op))
      return failure();

    auto mask = op.getMask();
    auto loc = op.getLoc();
    auto ty = op.getType();
    // Check source operand to determine rewrite pattern.
    auto src = op.getV1();
    // 1. Hoist past unary element-wise operations
    if (auto srcOp = src.getDefiningOp()) {
      if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Create new shuffle vector op with unary input as source.
        auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto newShuffleVecTy =
            VectorType::get(mask.size(), srcVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Create new unary op with new shuffle as input.
        Value newUnaryOp;
        if (isa<LLVM::FPExtOp>(srcOp)) {
          newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
        } else {
          newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
        }
        rewriter.replaceOp(op, newUnaryOp);
      } else if (isa<LLVM::BitcastOp>(srcOp)) {
        Value srcInput = srcOp->getOperand(0);
        // Create new shuffle vector op with unary input as source.
        auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
        auto srcInputSize = srcInputVecTy.getNumElements();
        auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
        auto srcResSize = srcResVecTy.getNumElements();
        auto maskSize = static_cast<int32_t>(mask.size());
        // Only handle bitcasts that widen elements (fewer, larger input
        // elements), i.e. srcInputSize <= srcResSize with an integer ratio.
        if (srcInputSize > srcResSize) {
          return failure();
        }
        if (srcResSize % srcInputSize != 0) {
          return failure();
        }
        auto maskScale = srcResSize / srcInputSize;
        if (maskScale != 1) {
          // The slice must start on an input-element boundary.
          if (mask[0] % maskScale != 0) {
            return failure();
          }
          // Create a new mask that maps to the source vector
          SmallVector<int32_t> newMask;
          int32_t newMaskSize = maskSize / maskScale;
          int32_t maskStart = mask[0] / maskScale;
          for (int32_t i = 0; i < newMaskSize; ++i) {
            newMask.push_back(maskStart + i);
          }
          mask = newMask;
        }
        // NOTE(review): the declared result type uses srcInputSize elements
        // while the (possibly rescaled) mask has mask.size() elements;
        // these only agree when the shuffle covers the whole bitcast input
        // — confirm partial extracts are formed correctly here.
        auto newShuffleVecTy =
            VectorType::get(srcInputSize, srcInputVecTy.getElementType());
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
        // Create new unary op with new shuffle as input.
        auto newBitcast =
            LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
        rewriter.replaceOp(op, newBitcast);
      } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
        // 2. Merge with source shuffle vector op if, the source op is
        // also extracting a contiguous slice and create a new
        // shuffle vector op directly from the source of
        // the first shuffle.
        auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
        if (!isExtractingContiguousSlice(srcShuffle))
          return failure();
        auto srcMask = srcShuffle.getMask();
        // Compose the two masks: index through the source shuffle's mask.
        SmallVector<int32_t> combinedMask;
        for (auto index : mask) {
          combinedMask.push_back(srcMask[index]);
        }
        auto newShuffle = LLVM::ShuffleVectorOp::create(
            rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
            DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
        rewriter.replaceOp(op, newShuffle);
      } else if (isa<LLVM::LoadOp>(srcOp)) {
        // 3. Merge with load as a smaller load
        auto loadOp = cast<LLVM::LoadOp>(srcOp);
        auto loadPtr = loadOp.getAddr();
        auto loadAddrSpace = loadPtr.getType().getAddressSpace();
        // Only rewrite loads from the default address space.
        if (loadAddrSpace != 0)
          return failure();
        auto loadTy = dyn_cast<VectorType>(loadOp.getType());
        auto elemTy = loadTy.getElementType();
        auto firstIndex = mask[0];
        auto newVecTy = VectorType::get(mask.size(), elemTy);
        // GEPOp is needed if first index is not zero
        if (firstIndex) {
          auto newPtr = LLVM::GEPOp::create(
              rewriter, loc,
              LLVM::LLVMPointerType::get(rewriter.getContext(), loadAddrSpace),
              elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
          rewriter.replaceOp(op, newLoad);
        } else {
          auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
          rewriter.replaceOp(op, newLoad);
        }
      } else {
        return failure();
      }
    }
    return success();
  }
};
1464
1465//===----------------------------------------------------------------------===//
1466// Pass Definition
1467//===----------------------------------------------------------------------===//
1468
1469struct ConvertXeVMToLLVMPass
1470 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1471 using Base::Base;
1472
1473 void getDependentDialects(DialectRegistry &registry) const override {
1474 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
1475 }
1476
1477 void runOnOperation() override {
1478 ConversionTarget target(getContext());
1479 RewritePatternSet patterns(&getContext());
1481 if (failed(applyPartialConversion(getOperation(), target,
1482 std::move(patterns))))
1483 signalPassFailure();
1484
1485 // Apply in-dialect lowerings to handle illegal vectors
1486 {
1487 RewritePatternSet vectorPatterns(&getContext());
1488 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
1489 GreedyRewriteConfig config{};
1490 // folding can remove ops with temporary attributes used to
1491 // represent LLVM metadata, so disable it here.
1492 // Effectively just this single pattern is applied without any
1493 // op folding patterns from dialects.
1494 config.enableFolding(false);
1495 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1496 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1497 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1498 config);
1499 }
1500 }
1501};
1502} // namespace
1503
1504//===----------------------------------------------------------------------===//
1505// Pattern Population
1506//===----------------------------------------------------------------------===//
1507
/// Configure `target` legality and register all XeVM-to-LLVM conversion
/// patterns on `patterns`. XeVM ops become calls to OpenCL builtins; a few
/// LLVM ops are additionally rewritten to work around SPIRV backend
/// limitations.
void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
                                                  RewritePatternSet &patterns) {
  // some LLVM operations need to be converted.
  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
    // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
    // properly by SPIRV backend. It needs to be rewritten as a sequence with
    // llvm global.
    if (isa<LLVM::AllocaOp>(op)) {
      LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
      LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
      auto addrSpace = pTy.getAddressSpace();
      return addrSpace != 3;
    }
    // cache_control attribute should be converted.
    return !op->hasAttr("cache_control");
  });
  // Every XeVM op must be lowered away by this conversion.
  target.addIllegalDialect<XeVMDialect>();
  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
               LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
               LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
               MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
               LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
               LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
               BlockLoadStore1DToOCLPattern<BlockLoadOp>,
               BlockLoadStore1DToOCLPattern<BlockStoreOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
               LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
               LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
               LaunchConfigOpToOCLPattern<GridDimXOp>,
               LaunchConfigOpToOCLPattern<GridDimYOp>,
               LaunchConfigOpToOCLPattern<GridDimZOp>,
               SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
               SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>,
               TruncfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
      patterns.getContext());
}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:274
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Block & front()
Definition Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature resultType(paramTypes) and name name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...