MLIR 23.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1//===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2//
3// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
10
16#include "mlir/Pass/Pass.h"
17#include "mlir/Support/LLVM.h"
18#include "llvm/ADT/ArrayRef.h"
19#include "llvm/Support/FormatVariadic.h"
20
22#include "mlir/IR/Matchers.h"
23#include "mlir/IR/Types.h"
25
26#include "llvm/ADT/TypeSwitch.h"
27
28namespace mlir {
29#define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
30#include "mlir/Conversion/Passes.h.inc"
31} // namespace mlir
32
33using namespace mlir;
34using namespace xevm;
35
36namespace {
37
38struct LLVMFuncAttributeOptions {
39 bool isConvergent = false;
40 bool isNoUnwind = false;
41 bool isWillReturn = false;
42 LLVM::MemoryEffectsAttr memEffectsAttr{};
43};
44static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
45 false, true, false, {}};
46static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
47 false, true, true, {}};
48static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
49 true, true, true, {}};
50
51std::string getTypeMangling(Type ty, bool isUnsigned = false) {
53 .Case([isUnsigned](VectorType ty) -> std::string {
54 return "Dv" + std::to_string(ty.getNumElements()) + "_" +
55 getTypeMangling(ty.getElementType(), isUnsigned);
56 })
57 .Case([](Float16Type) -> std::string { return "Dh"; })
58 .Case([](Float32Type) -> std::string { return "f"; })
59 .Case([](Float64Type) -> std::string { return "d"; })
60 .Case([isUnsigned](IntegerType ty) -> std::string {
61 switch (ty.getWidth()) {
62 case 8:
63 return isUnsigned ? "h" : "c";
64 case 16:
65 return isUnsigned ? "t" : "s";
66 case 32:
67 return isUnsigned ? "j" : "i";
68 case 64:
69 return isUnsigned ? "m" : "l";
70 default:
71 llvm_unreachable("unhandled integer type");
72 }
73 })
74 .DefaultUnreachable("unhandled type for mangling");
75}
76
77std::string mangle(StringRef baseName, ArrayRef<Type> types,
78 ArrayRef<bool> isUnsigned = {}) {
79 assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
80 "Signedness info doesn't match");
81 std::string s;
82 llvm::raw_string_ostream os(s);
83 llvm::SmallDenseMap<Type, unsigned> substitutions;
84 os << "_Z" << baseName.size() << baseName;
85 for (auto [idx, type] : llvm::enumerate(types)) {
86 auto it = substitutions.find(type);
87 if (it != substitutions.end()) {
88 os << "S";
89 // First substitution is `S_`, second is `S0_`, and so on.
90 if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
91 os << firstIdx - 1;
92 os << "_";
93 } else {
94 if (!type.isIntOrFloat())
95 substitutions[type] = substitutions.size();
96 os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
97 }
98 }
99 return os.str();
100}
101
102std::string builtinElemType(ElemType elemType) {
103 switch (elemType) {
104 case ElemType::BF8:
105 return "bf8";
106 case ElemType::F8:
107 return "hf8";
108 case ElemType::BF16:
109 return "bf";
110 case ElemType::F16:
111 return "hf";
112 case ElemType::F32:
113 return "f";
114 default:
115 return stringifyElemType(elemType).str();
116 }
117}
118
119static int32_t getL1CacheControl(LoadCacheControl cc) {
120 int32_t control = 0;
121 switch (cc) {
122 case LoadCacheControl::USE_DEFAULT:
123 control = -1;
124 break;
125 case LoadCacheControl::L1C_L2UC_L3UC:
126 case LoadCacheControl::L1C_L2UC_L3C:
127 case LoadCacheControl::L1C_L2C_L3UC:
128 case LoadCacheControl::L1C_L2C_L3C:
129 control = 1;
130 break;
131 case LoadCacheControl::L1S_L2UC_L3UC:
132 case LoadCacheControl::L1S_L2UC_L3C:
133 case LoadCacheControl::L1S_L2C_L3UC:
134 case LoadCacheControl::L1S_L2C_L3C:
135 control = 2;
136 break;
137 case LoadCacheControl::INVALIDATE_READ:
138 control = 3;
139 break;
140 default:
141 break;
142 }
143 return control;
144}
145
146static int32_t getL1CacheControl(StoreCacheControl cc) {
147 int32_t control = 0;
148 switch (cc) {
149 case StoreCacheControl::USE_DEFAULT:
150 control = -1;
151 break;
152 case StoreCacheControl::L1WT_L2UC_L3UC:
153 case StoreCacheControl::L1WT_L2UC_L3WB:
154 case StoreCacheControl::L1WT_L2WB_L3UC:
155 case StoreCacheControl::L1WT_L2WB_L3WB:
156 control = 1;
157 break;
158 case StoreCacheControl::L1WB_L2UC_L3UC:
159 case StoreCacheControl::L1WB_L2WB_L3UC:
160 case StoreCacheControl::L1WB_L2UC_L3WB:
161 control = 2;
162 break;
163 case StoreCacheControl::L1S_L2UC_L3UC:
164 case StoreCacheControl::L1S_L2UC_L3WB:
165 case StoreCacheControl::L1S_L2WB_L3UC:
166 case StoreCacheControl::L1S_L2WB_L3WB:
167 control = 3;
168 break;
169 default:
170 break;
171 }
172 return control;
173}
174
175static int32_t getL3CacheControl(LoadCacheControl cc) {
176 int32_t control = 0;
177 switch (cc) {
178 case LoadCacheControl::USE_DEFAULT:
179 control = -1;
180 break;
181 case LoadCacheControl::L1UC_L2UC_L3C:
182 case LoadCacheControl::L1UC_L2C_L3C:
183 case LoadCacheControl::L1C_L2UC_L3C:
184 case LoadCacheControl::L1C_L2C_L3C:
185 case LoadCacheControl::L1S_L2UC_L3C:
186 case LoadCacheControl::L1S_L2C_L3C:
187 control = 1;
188 break;
189 case LoadCacheControl::INVALIDATE_READ:
190 control = 3;
191 break;
192 default:
193 break;
194 }
195 return control;
196}
197
198static int32_t getL3CacheControl(StoreCacheControl cc) {
199 int32_t control = 0;
200 switch (cc) {
201 case StoreCacheControl::USE_DEFAULT:
202 control = -1;
203 break;
204 case StoreCacheControl::L1UC_L2UC_L3WB:
205 case StoreCacheControl::L1UC_L2WB_L3WB:
206 case StoreCacheControl::L1WT_L2UC_L3WB:
207 case StoreCacheControl::L1WT_L2WB_L3WB:
208 case StoreCacheControl::L1S_L2UC_L3WB:
209 case StoreCacheControl::L1S_L2WB_L3WB:
210 case StoreCacheControl::L1WB_L2UC_L3WB:
211 control = 2;
212 break;
213 default:
214 break;
215 }
216 return control;
217}
218
219static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
220 return op.getCacheControl();
221}
222
223static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
224 return op.getCacheControl();
225}
226
227static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
228 return op.getCacheControl();
229}
230
231static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
232 return op.getCacheControl();
233}
234
235static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
236 return op.getCacheControl();
237}
238
239static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
240 return op.getCacheControl();
241}
242
243static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
244 if (op->hasAttr("cache_control")) {
245 auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
246 if (!attr)
247 return std::nullopt;
248 return std::optional<LoadCacheControl>(attr.getValue());
249 }
250 return std::nullopt;
251}
252
253static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
254 if (op->hasAttr("cache_control")) {
255 auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
256 if (!attr)
257 return std::nullopt;
258 return std::optional<StoreCacheControl>(attr.getValue());
259 }
260 return std::nullopt;
261}
262
/// Returns the L1 control value for the cache-control policy carried by `op`.
///
/// Precondition: `getCacheControl(op)` holds a value. Callers must check this
/// first; the assertion below catches violations in assert-enabled builds.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "expected op to carry a cache control policy");
  return getL1CacheControl(*cacheControl);
}
267
/// Returns the L3 control value for the cache-control policy carried by `op`.
///
/// Precondition: `getCacheControl(op)` holds a value. Callers must check this
/// first; the assertion below catches violations in assert-enabled builds.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "expected op to carry a cache control policy");
  return getL3CacheControl(*cacheControl);
}
272
273template <typename OpType>
274static std::optional<ArrayAttr>
275getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
276 if (!getCacheControl(op))
277 return {};
278
279 constexpr int32_t decorationCacheControlArity{3};
280 constexpr int32_t loadCacheControlKey{6442};
281 constexpr int32_t storeCacheControlKey{6443};
282 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
283 std::is_same_v<OpType, BlockPrefetch2dOp> ||
284 std::is_same_v<OpType, LLVM::LoadOp> ||
285 std::is_same_v<OpType, BlockLoadOp> ||
286 std::is_same_v<OpType, PrefetchOp>;
287
288 // If the cache control is USE_DEFAULT, then we don’t emit any metadata.
289 // Assert that if one of the L1 or L3 cache control values is USE_DEFAULT
290 // (represented as -1), then both must be USE_DEFAULT; otherwise there is a
291 // bug.
292 assert(((getL1CacheControl<OpType>(op) == -1) ==
293 (getL3CacheControl<OpType>(op) == -1)) &&
294 "If one of L1 or L3 cache control is USE_DEFAULT, both must be "
295 "USE_DEFAULT");
296
297 if (getL1CacheControl<OpType>(op) == -1 &&
298 getL3CacheControl<OpType>(op) == -1)
299 return {};
300 const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
302 controlKey, 0, getL1CacheControl<OpType>(op)};
304 controlKey, 1, getL3CacheControl<OpType>(op)};
305 auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
306 auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
307
308 SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
309 return rewriter.getArrayAttr(combinedAttrs);
310}
311
312//===----------------------------------------------------------------------===//
313// Cache control annotation utilities
314//
315// Instead of attaching cache control as MLIR attributes and handling them
316// during LLVM translation, we directly emit llvm.intr.ptr.annotation op in
317// MLIR.
318//===----------------------------------------------------------------------===//
319
320/// Build one cache-control payload string per attribute.
321///
322/// Each Attribute is expected to be an ArrayAttr of 3 IntegerAttr values:
323/// [SPIR-V decoration token, cache level, cache control value]
324///
325/// A single entry produces a string like: {6442:"0,1"}
326/// where the quote characters (0x22) will appear as \22 in LLVM IR textual
327/// form.
329buildCacheControlPayloads(ArrayRef<Attribute> attrs) {
331 llvm::StringMap<bool> seen;
332
333 for (Attribute a : attrs) {
334 auto arr = dyn_cast<ArrayAttr>(a);
335 if (!arr)
336 continue;
337
338 auto vals = arr.getValue();
339 assert(vals.size() == 3 &&
340 "Expected exactly 3 integer values (Token, CacheLevel, "
341 "ControlValue) in cache control attribute.");
342
343 auto tokenAttr = dyn_cast<IntegerAttr>(vals[0]);
344 auto secondAttr = dyn_cast<IntegerAttr>(vals[1]);
345 auto thirdAttr = dyn_cast<IntegerAttr>(vals[2]);
346
347 if (!tokenAttr || !secondAttr || !thirdAttr)
348 continue;
349
350 // Produce: {SPIR-V decoration token:"L1 cache control,L3 cache control"}
351 // The quote char (0x22) is embedded literally; LLVM IR prints it as \22.
352 std::string entry =
353 llvm::formatv("{{{0}:\"{1},{2}\"}", tokenAttr.getValue().getZExtValue(),
354 secondAttr.getValue().getZExtValue(),
355 thirdAttr.getValue().getZExtValue());
356
357 // Deduplicate identical annotations.
358 if (!seen.insert({entry, true}).second)
359 continue;
360
361 payloads.push_back(std::move(entry));
362 }
363 return payloads;
364}
365/// Counter for generating unique global variable names.
366static std::atomic<uint64_t> globalNameCounter{0};
367
368/// Get or create a global metadata string and return a !llvm.ptr<1> value
369/// pointing to it. The AddressOfOp is created at the current rewriter
370/// insertion point; the GlobalOp is created at the module start.
371static Value createMetadataStringPtr(ConversionPatternRewriter &rewriter,
372 Operation *moduleOp, Location loc,
373 StringRef value, StringRef nameHint) {
374 // Build null-terminated string.
375 std::string strWithNull = value.str();
376 strWithNull.push_back('\0');
377 StringRef strRef(strWithNull.data(), strWithNull.size());
378
379 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
380
381 // Search for an existing global with the same content.
382 for (auto &op : moduleOp->getRegion(0).front()) {
383 if (auto existingGlobal = dyn_cast<LLVM::GlobalOp>(&op)) {
384 if (!existingGlobal.getSection() ||
385 *existingGlobal.getSection() != "llvm.metadata")
386 continue;
387 if (auto strAttr =
388 dyn_cast_or_null<StringAttr>(existingGlobal.getValueOrNull())) {
389 if (strAttr.getValue() == strRef) {
390 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy,
391 existingGlobal.getSymName());
392 }
393 }
394 }
395 }
396
397 // Create new global at module start.
398 auto i8Type = rewriter.getI8Type();
399 auto arrayType = LLVM::LLVMArrayType::get(i8Type, strWithNull.size());
400 std::string globalName =
401 llvm::formatv("{0}.{1}", nameHint,
402 globalNameCounter.fetch_add(1, std::memory_order_relaxed))
403 .str();
404
405 {
406 OpBuilder::InsertionGuard guard(rewriter);
407 rewriter.setInsertionPointToStart(&moduleOp->getRegion(0).front());
408
409 auto globalOp =
410 LLVM::GlobalOp::create(rewriter, loc, arrayType,
411 /*isConstant=*/true, LLVM::Linkage::Private,
412 globalName, rewriter.getStringAttr(strRef));
413 globalOp.setSection(StringRef("llvm.metadata"));
414 globalOp.setUnnamedAddr(LLVM::UnnamedAddr::Global);
415 globalOp.setAlignment(1);
416 globalOp.setAddrSpace(1);
417 }
418 // InsertionGuard restores the original insertion point here.
419
420 return LLVM::AddressOfOp::create(rewriter, loc, as1PtrTy, globalName);
421}
422
423/// Annotate a pointer value with cache control metadata by emitting chained
424/// `llvm.intr.ptr.annotation` ops (LLVM::PtrAnnotation).
425///
426/// This is the MLIR-level equivalent of handleDecorationCacheControl() from
427/// the LLVM translation layer. For each cache control attribute, it emits:
428///
429/// %ann = llvm.intr.ptr.annotation %ptr, @".str.cachecontrol.N",
430/// @".str.file.N", 0, null : !llvm.ptr<AS>
431///
432/// Multiple annotations are chained: the result of each annotation op is
433/// fed as the pointer input to the next one.
434///
435/// \param rewriter The pattern rewriter.
436/// \param loc Source location for created ops.
437/// \param ptr The pointer value to annotate.
438/// \param cacheControls The cache control ArrayAttr (from
439/// getCacheControlMetadata).
440/// \param moduleOp The enclosing module (for creating globals).
441/// \returns The annotated pointer value (or the original ptr if no
442/// annotations).
443static Value annotatePtrWithCacheControl(ConversionPatternRewriter &rewriter,
444 Location loc, Value ptr,
445 ArrayAttr cacheControls,
446 Operation *moduleOp) {
447 SmallVector<std::string> payloads =
448 buildCacheControlPayloads(cacheControls.getValue());
449 if (payloads.empty())
450 return ptr;
451
452 auto ptrType = cast<LLVM::LLVMPointerType>(ptr.getType());
453 auto as1PtrTy = LLVM::LLVMPointerType::get(rewriter.getContext(), 1);
454 auto i32Ty = rewriter.getI32Type();
455
456 // Create shared constants for all annotations on this pointer.
457 Value fileStr =
458 createMetadataStringPtr(rewriter, moduleOp, loc, "", ".str.file");
459 Value lineVal = LLVM::ConstantOp::create(rewriter, loc, i32Ty, 0);
460 Value nullAS1 = LLVM::ZeroOp::create(rewriter, loc, as1PtrTy);
461
462 // Chain: each annotation takes the result of the previous one as its
463 // pointer operand.
464 Value curPtr = ptr;
465 for (const std::string &payload : payloads) {
466 Value annStr = createMetadataStringPtr(rewriter, moduleOp, loc, payload,
467 ".str.cachecontrol");
468 auto annOp = LLVM::PtrAnnotation::create(rewriter, loc, ptrType, curPtr,
469 annStr, fileStr, lineVal, nullAS1);
470 curPtr = annOp.getResult();
471 }
472
473 return curPtr;
474}
475
476/// Helper to apply cache control annotation on a pointer operand of a call.
477/// Replaces the pointer argument of the call with an annotated version.
478///
479/// For operations that produce a call (like block load/store/prefetch), the
480/// pointer is typically the first argument. This function:
481/// 1. Builds the annotation chain on the pointer.
482/// 2. Replaces the pointer operand in the provided args list.
483///
484/// \param rewriter The pattern rewriter.
485/// \param loc Source location.
486/// \param ptr The original pointer value (first arg to the call).
487/// \param cacheControls The cache control metadata.
488/// \param moduleOp The enclosing module.
489/// \param args The argument list (modified in place: args[ptrIdx] is
490/// replaced).
491/// \param ptrIdx Index of the pointer in the args list (default 0).
492template <typename OpType>
493static void
494applyCacheControlAnnotation(ConversionPatternRewriter &rewriter, Location loc,
495 OpType op, SmallVectorImpl<Value> &args,
496 Operation *moduleOp, unsigned ptrIdx = 0) {
497 std::optional<ArrayAttr> optCacheControls =
498 getCacheControlMetadata(rewriter, op);
499 if (!optCacheControls)
500 return;
501
502 Value annotatedPtr = annotatePtrWithCacheControl(rewriter, loc, args[ptrIdx],
503 *optCacheControls, moduleOp);
504 args[ptrIdx] = annotatedPtr;
505}
506
507//===----------------------------------------------------------------------===//
508// End cache control annotation utilities
509//===----------------------------------------------------------------------===//
510
511static LLVM::CallOp createDeviceFunctionCall(
512 ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
513 ArrayRef<Type> argTypes, ArrayRef<Value> args,
514 mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
515 LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
516 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
517 assert(moduleOp && "Expecting module");
518 Location loc = op->getLoc();
519
520 auto funcOpRes =
521 LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
522 assert(!failed(funcOpRes));
523 LLVM::LLVMFuncOp funcOp = funcOpRes.value();
524 funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
525 funcOp.setConvergent(funcAttributeOptions.isConvergent);
526 funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
527 funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
528
529 if (funcAttributeOptions.memEffectsAttr)
530 funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
531
532 for (auto [idx, attrName] : paramAttrs)
533 funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
534
535 auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
536 callOp->setAttrs(funcOp->getAttrs());
537
538 return callOp;
539}
540
541static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
542 switch (pTy) {
543 case xevm::ElemType::F32:
544 case xevm::ElemType::TF32:
545 return 1;
546 case xevm::ElemType::BF16:
547 case xevm::ElemType::F16:
548 return 2;
549 case xevm::ElemType::U8:
550 case xevm::ElemType::S8:
551 case xevm::ElemType::BF8:
552 case xevm::ElemType::F8:
553 return 4;
554 case xevm::ElemType::E2M1:
555 case xevm::ElemType::U4:
556 case xevm::ElemType::S4:
557 return 8;
558 default:
559 llvm_unreachable("unsupported xevm::ElemType");
560 }
561}
562
563class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
564 using OpConversionPattern::OpConversionPattern;
565 LogicalResult
566 matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
567 ConversionPatternRewriter &rewriter) const override {
568 if (!op.getC()) {
569 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
570 }
571 auto precisionA = op.getTypes().getA();
572 auto precisionB = op.getTypes().getB();
573 auto precisionC = op.getTypes().getC();
574 auto precisionD = op.getTypes().getD();
575 if (precisionC != precisionD) {
576 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
577 }
578 if (precisionC != xevm::ElemType::S32 &&
579 precisionC != xevm::ElemType::F32 &&
580 precisionC != xevm::ElemType::F16 &&
581 precisionC != xevm::ElemType::BF16) {
582 return rewriter.notifyMatchFailure(
583 op, "type of C and D must be S32, F32, F16 or BF16");
584 }
585 if (precisionA == xevm::ElemType::S32 ||
586 precisionA == xevm::ElemType::F32) {
587 return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
588 }
589 if (precisionB == xevm::ElemType::S32 ||
590 precisionB == xevm::ElemType::F32) {
591 return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
592 }
593 constexpr uint32_t bitWidthPackedA{16};
594 constexpr uint32_t bitWidthPackedB{32};
595 auto loc = op.getLoc();
596
597 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
598 VectorType origTy = cast<VectorType>(val.getType());
599 const uint32_t vecBitSize =
600 origTy.getNumElements() *
601 origTy.getElementType().getIntOrFloatBitWidth();
602 VectorType newTy = VectorType::get(
603 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
604 if (origTy != newTy)
605 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
606 return val;
607 };
608
609 Value a = op.getA();
610 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
611 ? cast<Type>(rewriter.getF32Type())
612 : rewriter.getIntegerType(bitWidthPackedA);
613 a = castIfNeeded(a, packedAType);
614
615 Value b = op.getB();
616 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
617 ? cast<Type>(rewriter.getF32Type())
618 : rewriter.getIntegerType(bitWidthPackedB);
619 b = castIfNeeded(b, packedBType);
620
621 Value c = op.getC();
622 VectorType cOrigTy = cast<VectorType>(c.getType());
623 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
624 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
625 // OCL builtins encode bfloat16 as int16
626 VectorType cTy =
627 cOrigTy.getElementType().isBF16()
628 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
629 : cOrigTy;
630 VectorType resTy = cTy;
631 if (cOrigTy != cTy)
632 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
633
634 constexpr int32_t systolicDepth{8};
635 std::string fnName =
636 llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
637 stringifyElemType(op.getTypes().getA()).str(),
638 stringifyElemType(op.getTypes().getB()).str(),
639 systolicDepth *
640 getNumOperandsPerDword(op.getTypes().getA()))
641 .str();
642 SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
643 fnName = mangle(fnName, argTypes);
644 SmallVector<Value> args{a, b, c};
645
646 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
647 /*other=*/LLVM::ModRefInfo::NoModRef,
648 /*argMem=*/LLVM::ModRefInfo::NoModRef,
649 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
650 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
651 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
652 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
653 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
654 funcAttrs.memEffectsAttr = memAttr;
655 Value result =
656 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
657 funcAttrs, op.getOperation())
658 ->getResult(0);
659
660 if (resOrigTy != resTy)
661 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
662
663 rewriter.replaceOp(op, result);
664 return success();
665 }
666};
667
668class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
669 using OpConversionPattern::OpConversionPattern;
670 LogicalResult
671 matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
672 ConversionPatternRewriter &rewriter) const override {
673 auto loc = op.getLoc();
674 auto *moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
675
676 const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
677 Value one =
678 LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
679 SmallVector<Value> args{op.getPtr(), one};
680
681 // Annotate pointer with cache control before passing to the call.
682 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
683 /*ptrIdx=*/0);
684
685 SmallVector<Type> argTypes;
686 for (auto arg : args)
687 argTypes.push_back(arg.getType());
688 auto funcAttr = noUnwindAttrs;
689 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
690 /*other=*/LLVM::ModRefInfo::NoModRef,
691 /*argMem=*/LLVM::ModRefInfo::Ref,
692 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
693 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
694 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
695 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
696 funcAttr.memEffectsAttr = memAttr;
697
698 createDeviceFunctionCall(rewriter, fnName,
699 LLVM::LLVMVoidType::get(rewriter.getContext()),
700 argTypes, args, {}, funcAttr, op.getOperation());
701 rewriter.eraseOp(op);
702 return success();
703 }
704};
705
706class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
707 using OpConversionPattern::OpConversionPattern;
708 LogicalResult
709 matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
710 ConversionPatternRewriter &rewriter) const override {
711 auto loc = op.getLoc();
712 const std::string fnName{"atomic_work_item_fence"};
713 int memScope, addrSpace;
714 switch (op.getAddrspace()) {
715 case xevm::AddrSpace::SHARED:
716 addrSpace = 1; // CLK_LOCAL_MEM_FENCE
717 break;
718 case xevm::AddrSpace::GLOBAL:
719 addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
720 break;
721 default:
722 // GENERIC is not supported in OpenCL
723 return rewriter.notifyMatchFailure(
724 op, "Fence only supports global and shared address spaces.");
725 }
726 switch (op.getScope()) {
727 case xevm::MemScope::WORKGROUP:
728 memScope = 1;
729 break;
730 case xevm::MemScope::DEVICE:
731 memScope = 2;
732 break;
733 default:
734 // CLUSTER and SYSTEM are not supported in OpenCL
735 return rewriter.notifyMatchFailure(
736 op, "Fence only supports workgroup and device memory scopes.");
737 }
738 Type i32Type = rewriter.getI32Type();
739 Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
740 Value memScopeConst =
741 LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
742 Value addrSpaceConst =
743 LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
744 SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
745 SmallVector<Type> argTypes{3, i32Type};
746 createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
747 LLVM::LLVMVoidType::get(rewriter.getContext()),
748 argTypes, args, {}, noUnwindAttrs,
749 op.getOperation());
750 rewriter.eraseOp(op);
751 return success();
752 }
753};
754template <typename OpType>
755class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
756 using OpConversionPattern<OpType>::OpConversionPattern;
757 LogicalResult
758 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
759 ConversionPatternRewriter &rewriter) const override {
760 constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
761 constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
762
763 auto loc = op.getLoc();
764 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
765 VectorType vecType;
766 bool packReg = false;
767 bool transpose = false;
768 if constexpr (isLoad) {
769 vecType = op.getRes().getType();
770 packReg = op.getPackRegister();
771 transpose = op.getTranspose();
772 } else if constexpr (!isPrefetch) {
773 vecType = op.getStoredVal().getType();
774 }
775
776 auto i32Type = rewriter.getI32Type();
777 Value byteCoord =
778 LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
779 Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
780 Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
781 byteCoord = LLVM::InsertElementOp::create(
782 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
783 byteCoord = LLVM::InsertElementOp::create(
784 rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
785 SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
786 op.getBasePitch(), byteCoord};
787
788 // Annotate pointer (args[0]) with cache control before the call.
789 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
790 /*ptrIdx=*/0);
791
792 SmallVector<Type> retTypes;
793 Value spvLoadDstPtr;
794 std::string funcName{"intel_sub_group_2d_block_"};
795 std::string bitWidthId;
796 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
797 SmallVector<std::pair<unsigned, StringRef>, 4> paramAttrs;
798 if constexpr (isPrefetch) { // Prefetch
799 funcName += "prefetch";
800 paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
801 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
802 /*other=*/LLVM::ModRefInfo::NoModRef,
803 /*argMem=*/LLVM::ModRefInfo::Ref,
804 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
805 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
806 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
807 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
808 funcAttr = noUnwindAttrs;
809 funcAttr.memEffectsAttr = memAttr;
810 } else {
811 auto vecElemType = vecType.getElementType();
812 auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
813 auto vecNumElems = vecType.getNumElements();
814 // OpenCL Intel 2D block load has a special case
815 // when element bit size is 8 and tile width is 32, which is twice
816 // the subgroup size, loaded element is packed as i16.
817 // To reflect this, element bit size is updated to 16 and
818 // vector length is reduced by half.
819 if (op.getElemSizeInBits() == 8 && op.getTileWidth() == 32) {
820 vecElemBitWidth = 16;
821 vecElemType = rewriter.getI16Type();
822 vecNumElems = vecNumElems / 2;
823 }
824 Value numElems =
825 LLVM::ConstantOp::create(rewriter, loc, i32Type, vecNumElems);
826 auto dstOrSrcPtr = LLVM::AllocaOp::create(
827 rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
828 vecElemType, numElems);
829 args.push_back(dstOrSrcPtr);
830 if constexpr (isLoad) { // Load
831 funcName += "read";
832 bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
833 if (packReg)
834 funcName += "_transform";
835 else if (transpose)
836 funcName += "_transpose";
837 spvLoadDstPtr = dstOrSrcPtr;
838 retTypes.push_back(vecType);
839 paramAttrs = {
840 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
841 std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
842 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
843 std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
844 };
845 } else { // Store
846 funcName += "write";
847 bitWidthId = (vecElemBitWidth == 32)
848 ? "j"
849 : ((vecElemBitWidth == 16) ? "t" : "h");
850 LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
851 paramAttrs = {
852 std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
853 std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
854 std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
855 std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
856 };
857 }
858 }
859
860 funcName =
861 llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
862 op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
863 .str();
864 std::string prefetchCode("");
865 if (!isPrefetch)
866 prefetchCode += "P";
867 funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
868 funcName, prefetchCode, bitWidthId)
869 .str();
870 SmallVector<Type> argTypes;
871 for (auto arg : args) {
872 argTypes.push_back(arg.getType());
873 }
874 createDeviceFunctionCall(
875 rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
876 argTypes, args, paramAttrs, funcAttr, op.getOperation());
877
878 if constexpr (isLoad)
879 rewriter.replaceOp(
880 op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
881 else
882 rewriter.eraseOp(op);
883 return success();
884 }
885};
886
887template <typename OpType>
888class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
889 using OpConversionPattern<OpType>::OpConversionPattern;
890 LogicalResult
891 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
892 ConversionPatternRewriter &rewriter) const override {
893 constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
894 auto loc = op.getLoc();
895 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
896
897 // Get OpenCL function name
898 // https://registry.khronos.org/OpenCL/extensions/
899 // intel/cl_intel_subgroup_local_block_io.html
900 std::string funcName{"intel_sub_group_block_"};
901 // Value or Result type can be vector or scalar
902 Type valOrResTy;
903 if constexpr (isStore) {
904 funcName += "write_u";
905 valOrResTy = op.getVal().getType();
906 } else {
907 funcName += "read_u";
908 valOrResTy = op.getType();
909 }
910 // Get element type of the vector/scalar
911 VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
912 Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
913 funcName += getTypeMangling(elemType);
914 if (vecTy)
915 funcName += std::to_string(vecTy.getNumElements());
916 SmallVector<Type, 2> argTypes{};
917 // XeVM BlockLoad/StoreOp always use signless integer types
918 // but OpenCL builtins expect unsigned types
919 // use unsigned types for mangling
920 SmallVector<bool, 2> isUnsigned{};
921 // arg0: pointer to the src/dst address
922 // arg1 - only if store : vector to store
923 // Prepare arguments
924 SmallVector<Value, 2> args{};
925 args.push_back(op.getPtr());
926 argTypes.push_back(op.getPtr().getType());
927 isUnsigned.push_back(true);
928
929 // Annotate pointer (args[0]) with cache control.
930 applyCacheControlAnnotation(rewriter, loc, op, args, moduleOp,
931 /*ptrIdx=*/0);
932 // Update argTypes[0] in case the pointer type changed (it shouldn't
933 // change type, but the value is now the annotated pointer).
934 argTypes[0] = args[0].getType();
935
936 Type retType;
937 if constexpr (isStore) {
938 args.push_back(op.getVal());
939 argTypes.push_back(op.getVal().getType());
940 isUnsigned.push_back(true);
941 retType = LLVM::LLVMVoidType::get(rewriter.getContext());
942 } else {
943 retType = valOrResTy;
944 }
945 funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
946 "PU3AS" +
947 std::to_string(op.getPtr().getType().getAddressSpace());
948 funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
949 if constexpr (isStore)
950 funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
951 LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
952
953 LLVM::CallOp call =
954 createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
955 {}, funcAttr, op.getOperation());
956
957 if constexpr (isStore)
958 rewriter.eraseOp(op);
959 else
960 rewriter.replaceOp(op, call->getResult(0));
961 return success();
962 }
963};
964
965template <typename OpType>
966class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
967 using OpConversionPattern<OpType>::OpConversionPattern;
968 LogicalResult
969 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
970 ConversionPatternRewriter &rewriter) const override {
971 if (!op->hasAttr("cache_control"))
972 return failure();
973
974 auto *moduleOp = op->template getParentWithTrait<OpTrait::SymbolTable>();
975 std::optional<ArrayAttr> optCacheControls =
976 getCacheControlMetadata(rewriter, op);
977 if (!optCacheControls) {
978 rewriter.modifyOpInPlace(op, [&]() { op->removeAttr("cache_control"); });
979 return success();
980 }
981
982 // Determine which operand is the pointer.
983 constexpr bool isStore = std::is_same_v<OpType, LLVM::StoreOp>;
984 unsigned ptrIdx = isStore ? 1 : 0;
985 Value ptr = op->getOperand(ptrIdx);
986
987 // Emit annotation intrinsic calls on the pointer.
988 Value annotatedPtr = annotatePtrWithCacheControl(
989 rewriter, op->getLoc(), ptr, *optCacheControls, moduleOp);
990
991 // Replace the pointer operand with the annotated one.
992 rewriter.modifyOpInPlace(op, [&]() {
993 op->setOperand(ptrIdx, annotatedPtr);
994 op->removeAttr("cache_control");
995 });
996 return success();
997 }
998};
999
1000//===----------------------------------------------------------------------===//
1001// GPU index id operations
1002//===----------------------------------------------------------------------===//
1003/*
1004// Launch Config ops
1005// dimidx - x, y, z - is fixed to i32
1006// return type is set by XeVM type converter
1007// get_local_id
1008xevm::WorkitemIdXOp;
1009xevm::WorkitemIdYOp;
1010xevm::WorkitemIdZOp;
1011// get_local_size
1012xevm::WorkgroupDimXOp;
1013xevm::WorkgroupDimYOp;
1014xevm::WorkgroupDimZOp;
1015// get_group_id
1016xevm::WorkgroupIdXOp;
1017xevm::WorkgroupIdYOp;
1018xevm::WorkgroupIdZOp;
1019// get_num_groups
1020xevm::GridDimXOp;
1021xevm::GridDimYOp;
1022xevm::GridDimZOp;
1023// get_global_id : to be added if needed
1024*/
1025
1026// Helpers to get the OpenCL function name and dimension argument for each op.
1027static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
1028 return {"get_local_id", 0};
1029}
1030static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
1031 return {"get_local_id", 1};
1032}
1033static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
1034 return {"get_local_id", 2};
1035}
1036static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
1037 return {"get_local_size", 0};
1038}
1039static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
1040 return {"get_local_size", 1};
1041}
1042static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
1043 return {"get_local_size", 2};
1044}
1045static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
1046 return {"get_group_id", 0};
1047}
1048static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
1049 return {"get_group_id", 1};
1050}
1051static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
1052 return {"get_group_id", 2};
1053}
1054static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
1055 return {"get_num_groups", 0};
1056}
1057static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
1058 return {"get_num_groups", 1};
1059}
1060static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
1061 return {"get_num_groups", 2};
1062}
1063/// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
1064/// a constant argument for the dimension - x, y or z.
1065template <typename OpType>
1066class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
1067 using OpConversionPattern<OpType>::OpConversionPattern;
1068 LogicalResult
1069 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1070 ConversionPatternRewriter &rewriter) const override {
1071 Location loc = op->getLoc();
1072 auto [baseName, dim] = getConfig(op);
1073 Type dimTy = rewriter.getI32Type();
1074 Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
1075 static_cast<int64_t>(dim));
1076 std::string func = mangle(baseName, {dimTy}, {true});
1077 Type resTy = op.getType();
1078 auto call =
1079 createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
1080 noUnwindWillReturnAttrs, op.getOperation());
1081 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1082 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1083 /*other=*/noModRef,
1084 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1085 /*errnoMem=*/noModRef,
1086 /*targetMem0=*/noModRef,
1087 /*targetMem1=*/noModRef);
1088 call.setMemoryEffectsAttr(memAttr);
1089 rewriter.replaceOp(op, call);
1090 return success();
1091 }
1092};
1093
1094/*
1095// Subgroup ops
1096// get_sub_group_local_id
1097xevm::LaneIdOp;
1098// get_sub_group_id
1099xevm::SubgroupIdOp;
1100// get_sub_group_size
1101xevm::SubgroupSizeOp;
1102// get_num_sub_groups : to be added if needed
1103*/
1104
1105// Helpers to get the OpenCL function name for each op.
1106static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
1107static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
1108static StringRef getConfig(xevm::SubgroupSizeOp) {
1109 return "get_sub_group_size";
1110}
1111template <typename OpType>
1112class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
1113 using OpConversionPattern<OpType>::OpConversionPattern;
1114 LogicalResult
1115 matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
1116 ConversionPatternRewriter &rewriter) const override {
1117 std::string func = mangle(getConfig(op).str(), {});
1118 Type resTy = op.getType();
1119 auto call =
1120 createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
1121 noUnwindWillReturnAttrs, op.getOperation());
1122 constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
1123 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1124 /*other=*/noModRef,
1125 /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef,
1126 /*errnoMem=*/noModRef,
1127 /*targetMem0=*/noModRef,
1128 /*targetMem1=*/noModRef);
1129 call.setMemoryEffectsAttr(memAttr);
1130 rewriter.replaceOp(op, call);
1131 return success();
1132 }
1133};
1134
1135class TruncfToOCLPattern : public OpConversionPattern<TruncfOp> {
1136 using OpConversionPattern::OpConversionPattern;
1137 LogicalResult
1138 matchAndRewrite(TruncfOp op, TruncfOp::Adaptor adaptor,
1139 ConversionPatternRewriter &rewriter) const override {
1140 // Supported source and result types are resticted for now.
1141 auto srcEtype = op.getSrcEtype().getEtype();
1142 auto dstEtype = op.getDstEtype().getEtype();
1143 // Currently only 16 input elements are supported as
1144 // - Any vector beyond 16 elements not a valid OpenCL vector.
1145 // - 2D block load can only load up to 16 16bit elements per lane.
1146 // Widest load is 8x16xi32 with 16 lanes, which is 16 16bit
1147 // elements per lane.
1148 // - mma_mx A and B operands need more than 16 elements per lane
1149 //
1150 // Conversion is done in batches depending on the dst type.
1151 // batch_size =
1152 // 16 if dst type == fp8
1153 // 8 if dst type == fp4
1154 // For num_elem > batch_size
1155 // convert batch of batch_size
1156 // cast batch to i32 elem type vector
1157 // concat batches by shufflevector
1158 // For num_elem = batch_size
1159 // use API for conversion
1160 // Scalar case is not supported until usage case become clear.
1161 auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
1162 if (!vecSrcTy) {
1163 return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
1164 }
1165 if (vecSrcTy.getNumElements() != 16)
1166 return rewriter.notifyMatchFailure(
1167 op, "Only vector src of 16 elements is supported");
1168 auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
1169 if (!vecDstTy)
1170 return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
1171 Value src = op.getSrc();
1172 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1173 /*other=*/LLVM::ModRefInfo::NoModRef,
1174 /*argMem=*/LLVM::ModRefInfo::NoModRef,
1175 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
1176 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
1177 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
1178 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
1179 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1180 funcAttrs.memEffectsAttr = memAttr;
1181
1182 // Handle the case where dst type is fp4 first.
1183 if (dstEtype == TruncfDstElemTypes::E2M1) {
1184 // Convert 8 elements at a time.
1185 // To convert 8 elements, vector<8xf16>:
1186 // Use:
1187 // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 0)
1188 // uint __builtin_IB_dnscl_hf16(uint, uint, 1, 3)
1189 // llvm.or
1190 Value cast = LLVM::BitcastOp::create(
1191 rewriter, op.getLoc(), VectorType::get(8, rewriter.getI32Type()),
1192 src);
1193
1194 std::string fnName = "__builtin_IB_dnscl_";
1195 fnName += (srcEtype == TruncfSrcElemTypes::F16) ? "hf16" : "bf16";
1196 auto genDnscl = [&](Value input, Value idx0, Value idx1, Value dstTy,
1197 Value mode) -> Value {
1198 Value arg1 =
1199 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx0)
1200 ->getResult(0);
1201 Value arg2 =
1202 LLVM::ExtractElementOp::create(rewriter, op.getLoc(), input, idx1)
1203 ->getResult(0);
1204 SmallVector<Type> argTypes{arg1.getType(), arg2.getType(),
1205 dstTy.getType(), mode.getType()};
1206 SmallVector<Value> args{arg1, arg2, dstTy, mode};
1207 Value dnscl = createDeviceFunctionCall(
1208 rewriter, fnName, rewriter.getI32Type(), argTypes,
1209 args, {}, funcAttrs, op.getOperation())
1210 ->getResult(0);
1211 return dnscl;
1212 };
1213
1214 Value zero = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1215 rewriter.getI32Type(), 0);
1216 Value one = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1217 rewriter.getI32Type(), 1);
1218 Value two = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1219 rewriter.getI32Type(), 2);
1220 Value three = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1221 rewriter.getI32Type(), 3);
1222 Value even = genDnscl(cast, zero, two, one, zero);
1223 Value odd = genDnscl(cast, one, three, one, two);
1224 Value firstHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
1225 Value four = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1226 rewriter.getI32Type(), 4);
1227 Value five = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1228 rewriter.getI32Type(), 5);
1229 Value six = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1230 rewriter.getI32Type(), 6);
1231 Value seven = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1232 rewriter.getI32Type(), 7);
1233 even = genDnscl(cast, four, six, one, zero);
1234 odd = genDnscl(cast, five, seven, one, two);
1235 Value secondHalf = LLVM::OrOp::create(rewriter, op.getLoc(), even, odd);
1236 // Create vector<2xi32> from two i32 values and then bitcast to
1237 // vector<8xi8> to match the dst type.
1238 Value combined = LLVM::UndefOp::create(
1239 rewriter, op.getLoc(), VectorType::get(2, rewriter.getI32Type()));
1240 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1241 firstHalf, zero)
1242 ->getResult(0);
1243 combined = LLVM::InsertElementOp::create(rewriter, op.getLoc(), combined,
1244 secondHalf, one)
1245 ->getResult(0);
1246 Value result =
1247 LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, combined);
1248 rewriter.replaceOp(op, result);
1249 return success();
1250 }
1251
1252 // Handle the case where dst type is fp8.
1253 // BF16 type needs some preprocessing before conversion,
1254 // First extended to F32 and then truncated to F16.
1255 if (srcEtype == TruncfSrcElemTypes::BF16) {
1256 // Step 1: Extend to F32
1257 // Use float16 __builtin_IB_bftof_16(short16)
1258 src = LLVM::BitcastOp::create(
1259 rewriter, op.getLoc(),
1260 VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type()), src);
1261 std::string fnName = "__builtin_IB_bftof_16";
1262 SmallVector<Type> argTypes{src.getType()};
1263 SmallVector<Value> args{src};
1264 Type resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
1265 src = createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args,
1266 {}, funcAttrs, op.getOperation())
1267 ->getResult(0);
1268 // Step 2: Truncf to F16
1269 // Use half16 convert_half16(float16)
1270 std::string truncFnName = "convert_half16";
1271 SmallVector<Type> truncArgTypes{src.getType()};
1272 SmallVector<Value> truncArgs{src};
1273 truncFnName = mangle(truncFnName, truncArgTypes);
1274 resTy = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
1275 src =
1276 createDeviceFunctionCall(rewriter, truncFnName, resTy, truncArgTypes,
1277 truncArgs, {}, funcAttrs, op.getOperation())
1278 ->getResult(0);
1279 }
1280 if (dstEtype == TruncfDstElemTypes::BF8) { // Float8E5M2Type
1281 // Use char16 __builtin_IB_hftobf8_16(half16)
1282 std::string fnName = "__builtin_IB_hftobf8_16";
1283 SmallVector<Type> argTypes{src.getType()};
1284 SmallVector<Value> args{src};
1285 Value result =
1286 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1287 {}, funcAttrs, op.getOperation())
1288 ->getResult(0);
1289
1290 rewriter.replaceOp(op, result);
1291 } else if (dstEtype == TruncfDstElemTypes::F8) { // Float8E4M3FNType
1292 // Use char16 __builtin_IB_hftohf8_16(half16)
1293 std::string fnName = "__builtin_IB_hftohf8_16";
1294 SmallVector<Type> argTypes{src.getType()};
1295 SmallVector<Value> args{src};
1296 Value result =
1297 createDeviceFunctionCall(rewriter, fnName, vecDstTy, argTypes, args,
1298 {}, funcAttrs, op.getOperation())
1299 ->getResult(0);
1300
1301 rewriter.replaceOp(op, result);
1302 } else {
1303 return rewriter.notifyMatchFailure(
1304 op, "Unsupported src, dst element type pair.");
1305 }
1306 return success();
1307 }
1308};
1309
1310class ExtfToOCLPattern : public OpConversionPattern<ExtfOp> {
1311 using OpConversionPattern::OpConversionPattern;
1312 LogicalResult
1313 matchAndRewrite(ExtfOp op, ExtfOp::Adaptor adaptor,
1314 ConversionPatternRewriter &rewriter) const override {
1315 // `xevm.extf` is the inverse of `xevm.truncf`. Supported source and result
1316 // types are restricted for now, mirroring the truncf lowering.
1317 auto srcEtype = op.getSrcEtype().getEtype();
1318 auto dstEtype = op.getDstEtype().getEtype();
1319 // Scalar case is not supported until usage case become clear.
1320 auto vecSrcTy = dyn_cast<VectorType>(op.getSrc().getType());
1321 if (!vecSrcTy)
1322 return rewriter.notifyMatchFailure(op, "Scalar src is not supported.");
1323 auto vecDstTy = dyn_cast<VectorType>(op.getDst().getType());
1324 if (!vecDstTy)
1325 return rewriter.notifyMatchFailure(op, "Scalar dst is not supported.");
1326 Value src = op.getSrc();
1327 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1328 /*other=*/LLVM::ModRefInfo::NoModRef,
1329 /*argMem=*/LLVM::ModRefInfo::NoModRef,
1330 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
1331 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
1332 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
1333 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
1334 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1335 funcAttrs.memEffectsAttr = memAttr;
1336
1337 // Handle the case where src type is fp4 (e2m1) first.
1338 if (srcEtype == ExtfSrcElemTypes::E2M1) {
1339 // 16 fp4 values are packed into vector<8xi8>, the result is a
1340 // vector<16xf16> or vector<16xbf16>.
1341 // Use:
1342 // uint16 __builtin_IB_shfl_idx4_lut(int lut_index)
1343 // uint8 __builtin_IB_shfl_idx4_to_fp16_8_packed(uint16 lut,
1344 // char8 source)
1345 // The lookup table selects the target format:
1346 // 7 = e2m1 -> f16, 5 = e2m1 -> bf16.
1347 if (vecSrcTy.getNumElements() != 8 || vecDstTy.getNumElements() != 16)
1348 return rewriter.notifyMatchFailure(
1349 op, "fp4 src expects a vector<8xi8> src and a 16 element dst");
1350 constexpr int kLutE2M1ToF16 = 7;
1351 constexpr int kLutE2M1ToBF16 = 5;
1352 int lutIndex =
1353 (dstEtype == ExtfDstElemTypes::F16) ? kLutE2M1ToF16 : kLutE2M1ToBF16;
1354 Value lutIdx = LLVM::ConstantOp::create(rewriter, op.getLoc(),
1355 rewriter.getI32Type(), lutIndex);
1356 Type lutTy = VectorType::get(16, rewriter.getI32Type());
1357 Value lut =
1358 createDeviceFunctionCall(rewriter, "__builtin_IB_shfl_idx4_lut",
1359 lutTy, {lutIdx.getType()}, {lutIdx}, {},
1360 funcAttrs, op.getOperation())
1361 ->getResult(0);
1362 Type packedResTy = VectorType::get(8, rewriter.getI32Type());
1363 SmallVector<Type> convArgTypes{lut.getType(), src.getType()};
1364 SmallVector<Value> convArgs{lut, src};
1365 Value result =
1366 createDeviceFunctionCall(
1367 rewriter, "__builtin_IB_shfl_idx4_to_fp16_8_packed", packedResTy,
1368 convArgTypes, convArgs, {}, funcAttrs, op.getOperation())
1369 ->getResult(0);
1370 // The builtin returns the f16/bf16 bits packed as i32, bitcast to the
1371 // f16/bf16 dst type.
1372 result = LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, result);
1373 rewriter.replaceOp(op, result);
1374 return success();
1375 }
1376
1377 // Handle the case where src type is fp8 (bf8/hf8).
1378 // Only 16 input elements are supported, see TruncfToOCLPattern for details.
1379 if (vecSrcTy.getNumElements() != 16)
1380 return rewriter.notifyMatchFailure(
1381 op, "Only vector src of 16 elements is supported");
1382
1383 // Step 1: Extend fp8 (bf8/hf8) to F16.
1384 // bf8 -> half: half16 __builtin_IB_bf8tohf_16(char16)
1385 // hf8 -> half: half16 __builtin_IB_hf8tohf_16(char16)
1386 std::string fnName = (srcEtype == ExtfSrcElemTypes::BF8)
1387 ? "__builtin_IB_bf8tohf_16"
1388 : "__builtin_IB_hf8tohf_16";
1389 Type f16Ty = VectorType::get(vecSrcTy.getShape(), rewriter.getF16Type());
1390 SmallVector<Type> argTypes{src.getType()};
1391 SmallVector<Value> args{src};
1392 Value result =
1393 createDeviceFunctionCall(rewriter, fnName, f16Ty, argTypes, args, {},
1394 funcAttrs, op.getOperation())
1395 ->getResult(0);
1396
1397 // When the destination is F16, we are done.
1398 if (dstEtype == ExtfDstElemTypes::F16) {
1399 rewriter.replaceOp(op, result);
1400 return success();
1401 }
1402
1403 // BF16 destination needs some postprocessing.
1404 // First extend F16 to F32 and then truncate to BF16.
1405 // Step 2: Extend to F32.
1406 // Use float16 convert_float16(half16)
1407 std::string convFnName = "convert_float16";
1408 SmallVector<Type> convArgTypes{result.getType()};
1409 SmallVector<Value> convArgs{result};
1410 convFnName = mangle(convFnName, convArgTypes);
1411 Type f32Ty = VectorType::get(vecSrcTy.getShape(), rewriter.getF32Type());
1412 result =
1413 createDeviceFunctionCall(rewriter, convFnName, f32Ty, convArgTypes,
1414 convArgs, {}, funcAttrs, op.getOperation())
1415 ->getResult(0);
1416 // Step 3: Truncate F32 to BF16.
1417 // Use short16 __builtin_IB_ftobf_16(float16)
1418 constexpr StringRef ftobfFnName = "__builtin_IB_ftobf_16";
1419 SmallVector<Type> ftobfArgTypes{result.getType()};
1420 SmallVector<Value> ftobfArgs{result};
1421 Type i16Ty = VectorType::get(vecSrcTy.getShape(), rewriter.getI16Type());
1422 result =
1423 createDeviceFunctionCall(rewriter, ftobfFnName, i16Ty, ftobfArgTypes,
1424 ftobfArgs, {}, funcAttrs, op.getOperation())
1425 ->getResult(0);
1426 // The builtin returns the bf16 bits as i16, bitcast to the bf16 dst type.
1427 result = LLVM::BitcastOp::create(rewriter, op.getLoc(), vecDstTy, result);
1428 rewriter.replaceOp(op, result);
1429 return success();
1430 }
1431};
1432
1433class MMAMxToOCLPattern : public OpConversionPattern<MMAMxOp> {
1434 using OpConversionPattern::OpConversionPattern;
1435 LogicalResult
1436 matchAndRewrite(MMAMxOp op, MMAMxOp::Adaptor adaptor,
1437 ConversionPatternRewriter &rewriter) const override {
1438 if (!op.getC()) {
1439 return rewriter.notifyMatchFailure(op, "OCL requires C operand");
1440 }
1441 auto precisionC = op.getTypes().getC();
1442 auto precisionD = op.getTypes().getD();
1443 if (precisionC != precisionD) {
1444 return rewriter.notifyMatchFailure(op, "type of C and D need to match");
1445 }
1446
1447 constexpr uint32_t bitWidthPackedA{16};
1448 constexpr uint32_t bitWidthPackedB{32};
1449 auto loc = op.getLoc();
1450
1451 auto castIfNeeded = [&](Value val, Type packedType) -> Value {
1452 VectorType origTy = cast<VectorType>(val.getType());
1453 const uint32_t vecBitSize =
1454 origTy.getNumElements() *
1455 origTy.getElementType().getIntOrFloatBitWidth();
1456 VectorType newTy = VectorType::get(
1457 vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
1458 if (origTy != newTy)
1459 val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
1460 return val;
1461 };
1462
1463 Value a = op.getA();
1464 Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
1465 ? cast<Type>(rewriter.getF32Type())
1466 : rewriter.getIntegerType(bitWidthPackedA);
1467 a = castIfNeeded(a, packedAType);
1468
1469 Value b = op.getB();
1470 Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
1471 ? cast<Type>(rewriter.getF32Type())
1472 : rewriter.getIntegerType(bitWidthPackedB);
1473 b = castIfNeeded(b, packedBType);
1474
1475 Value c = op.getC();
1476 VectorType cOrigTy = cast<VectorType>(c.getType());
1477 VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
1478 assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
1479 // OCL builtins encode bfloat16 as int16
1480 VectorType cTy =
1481 cOrigTy.getElementType().isBF16()
1482 ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
1483 : cOrigTy;
1484 VectorType resTy = cTy;
1485 if (cOrigTy != cTy)
1486 c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
1487
1488 std::string fnName =
1489 llvm::formatv("__builtin_IB_sub_group16_bdpas_{0}_{1}_{2}_{3}_8_8",
1490 builtinElemType(op.getTypes().getD()),
1491 builtinElemType(op.getTypes().getC()),
1492 builtinElemType(op.getTypes().getA()),
1493 builtinElemType(op.getTypes().getB()))
1494 .str();
1495 auto scaleA = op.getScaleA();
1496 auto scaleB = op.getScaleB();
1497 SmallVector<Type> argTypes{cTy, a.getType(), b.getType(), scaleA.getType(),
1498 scaleB.getType()};
1499 SmallVector<Value> args{c, a, b, scaleA, scaleB};
1500
1501 auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
1502 /*other=*/LLVM::ModRefInfo::NoModRef,
1503 /*argMem=*/LLVM::ModRefInfo::NoModRef,
1504 /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef,
1505 /*errnoMem=*/LLVM::ModRefInfo::NoModRef,
1506 /*targetMem0=*/LLVM::ModRefInfo::NoModRef,
1507 /*targetMem1=*/LLVM::ModRefInfo::NoModRef);
1508 auto funcAttrs = convergentNoUnwindWillReturnAttrs;
1509 funcAttrs.memEffectsAttr = memAttr;
1510 Value result =
1511 createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
1512 funcAttrs, op.getOperation())
1513 ->getResult(0);
1514
1515 if (resOrigTy != resTy)
1516 result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
1517
1518 rewriter.replaceOp(op, result);
1519 return success();
1520 }
1521};
1522
1523class AllocaToGlobalPattern : public OpConversionPattern<LLVM::AllocaOp> {
1524 using OpConversionPattern::OpConversionPattern;
1525 LogicalResult
1526 matchAndRewrite(LLVM::AllocaOp op, LLVM::AllocaOp::Adaptor adaptor,
1527 ConversionPatternRewriter &rewriter) const override {
1528 auto ptrType = cast<LLVM::LLVMPointerType>(op.getType());
1529 auto addrSpace = ptrType.getAddressSpace();
1530 if (addrSpace != 3)
1531 return failure();
1532 auto symTable = op->getParentWithTrait<OpTrait::SymbolTable>();
1533 if (!symTable)
1534 return failure();
1535 Block *moduleBody;
1536 if (ModuleOp mod = dyn_cast<ModuleOp>(*symTable)) {
1537 moduleBody = mod.getBody();
1538 } else if (gpu::GPUModuleOp gpuMod =
1539 dyn_cast<gpu::GPUModuleOp>(*symTable)) {
1540 moduleBody = gpuMod.getBody();
1541 } else {
1542 return failure();
1543 }
1544 auto val = op.getArraySize();
1545 APInt cst;
1546 if (!matchPattern(val, m_ConstantInt(&cst)))
1547 return failure();
1548 auto loc = op.getLoc();
1549 auto globalType = LLVM::LLVMArrayType::get(
1550 rewriter.getContext(), op.getElemType(), cst.getZExtValue());
1551 LLVM::GlobalOp globalVar;
1552 {
1553 OpBuilder::InsertionGuard guard(rewriter);
1554 rewriter.setInsertionPointToStart(moduleBody);
1555 auto alignment = op.getAlignment();
1556 globalVar = LLVM::GlobalOp::create(
1557 rewriter, loc, globalType, /*isConstant=*/false,
1558 /*linkage=*/LLVM::Linkage::Internal,
1559 /*name=*/std::string("__global_alloca_") +
1560 std::to_string(getNextGlobalIdx()),
1561 /*value=*/Attribute(),
1562 /*alignment=*/alignment ? *alignment : 0, /*addrSpace=*/addrSpace);
1563 }
1564 rewriter.replaceOpWithNewOp<LLVM::AddressOfOp>(op, globalVar);
1565 return success();
1566 }
1567
1568private:
1569 static unsigned getNextGlobalIdx() {
1570 static unsigned globalIdx = 0;
1571 return globalIdx++;
1572 }
1573};
1574
1575// Checks if shufflevector is used as a way to extract a contiguous slice
1576// from a vector.
1577// - source vector V1 and V2 are the same vector.
1578// - mask size is not greater than the source vector size
1579// - mask values represent a sequence of consecutive increasing numbers
1580// that stay in bounds of the source vector when used for indexing.
1581static bool isExtractingContiguousSlice(LLVM::ShuffleVectorOp op) {
1582 if (op.getV1() != op.getV2())
1583 return false;
1584 auto maskAttr = op.getMask();
1585 int64_t maskSize = static_cast<int64_t>(maskAttr.size());
1586 int64_t sourceSize = op.getV1().getType().getNumElements();
1587 if (maskSize > sourceSize)
1588 return false;
1589 int64_t firstIndex = maskAttr[0];
1590 for (int64_t i = 1; i < maskSize; ++i) {
1591 int64_t index = maskAttr[i];
1592 if (index != firstIndex + i)
1593 return false;
1594 if (index >= sourceSize)
1595 return false;
1596 }
1597 return true;
1598}
1599
1600// Input vector of a shuffle vector op extracting a contiguous slice is an
1601// illegal vector in SPIRV kernel if the vector size is > 16 elements.
1602// To legalize this case, keep applying the following transformations until no
1603// more match:
1604// 1. keep hoisting the shuffle vector op past unary element-wise operations
1605// start with fpext, fptrunc and bitcast for now.
1606// 2. merge with another shuffle vector op
1607// 3. merge with load as a smaller load
1608class HandleVectorExtractPattern
1609 : public OpRewritePattern<LLVM::ShuffleVectorOp> {
1610 using OpRewritePattern<LLVM::ShuffleVectorOp>::OpRewritePattern;
1611
1612 void initialize() { setHasBoundedRewriteRecursion(); }
1613
1614 LogicalResult matchAndRewrite(LLVM::ShuffleVectorOp op,
1615 PatternRewriter &rewriter) const override {
1616
1617 if (!isExtractingContiguousSlice(op))
1618 return failure();
1619
1620 auto mask = op.getMask();
1621 auto loc = op.getLoc();
1622 auto ty = op.getType();
1623 // Check source operand to determine rewrite pattern.
1624 auto src = op.getV1();
1625 // 1. Hoist past unary element-wise operations
1626 if (auto srcOp = src.getDefiningOp()) {
1627 if (isa<LLVM::FPExtOp>(srcOp) || isa<LLVM::FPTruncOp>(srcOp)) {
1628 Value srcInput = srcOp->getOperand(0);
1629 // Create new shuffle vector op with unary input as source.
1630 auto srcVecTy = dyn_cast<VectorType>(srcInput.getType());
1631 auto newShuffleVecTy =
1632 VectorType::get(mask.size(), srcVecTy.getElementType());
1633 auto newShuffle = LLVM::ShuffleVectorOp::create(
1634 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1635 // Create new unary op with new shuffle as input.
1636 Value newUnaryOp;
1637 if (isa<LLVM::FPExtOp>(srcOp)) {
1638 newUnaryOp = LLVM::FPExtOp::create(rewriter, loc, ty, newShuffle);
1639 } else {
1640 newUnaryOp = LLVM::FPTruncOp::create(rewriter, loc, ty, newShuffle);
1641 }
1642 rewriter.replaceOp(op, newUnaryOp);
1643 } else if (isa<LLVM::BitcastOp>(srcOp)) {
1644 Value srcInput = srcOp->getOperand(0);
1645 // Create new shuffle vector op with unary input as source.
1646 auto srcInputVecTy = dyn_cast<VectorType>(srcInput.getType());
1647 auto srcInputSize = srcInputVecTy.getNumElements();
1648 auto srcResVecTy = dyn_cast<VectorType>(srcOp->getResult(0).getType());
1649 auto srcResSize = srcResVecTy.getNumElements();
1650 auto maskSize = static_cast<int32_t>(mask.size());
1651 if (srcInputSize > srcResSize) {
1652 return failure();
1653 }
1654 if (srcResSize % srcInputSize != 0) {
1655 return failure();
1656 }
1657 auto maskScale = srcResSize / srcInputSize;
1658 if (maskScale != 1) {
1659 if (mask[0] % maskScale != 0) {
1660 return failure();
1661 }
1662 // Create a new mask that maps to the source vector
1663 SmallVector<int32_t> newMask;
1664 int32_t newMaskSize = maskSize / maskScale;
1665 int32_t maskStart = mask[0] / maskScale;
1666 for (int32_t i = 0; i < newMaskSize; ++i) {
1667 newMask.push_back(maskStart + i);
1668 }
1669 mask = newMask;
1670 }
1671 auto newShuffleVecTy =
1672 VectorType::get(srcInputSize, srcInputVecTy.getElementType());
1673 auto newShuffle = LLVM::ShuffleVectorOp::create(
1674 rewriter, loc, newShuffleVecTy, srcInput, srcInput, mask);
1675 // Create new unary op with new shuffle as input.
1676 auto newBitcast =
1677 LLVM::BitcastOp::create(rewriter, loc, ty, newShuffle);
1678 rewriter.replaceOp(op, newBitcast);
1679 } else if (isa<LLVM::ShuffleVectorOp>(srcOp)) {
1680 // 2. Merge with source shuffle vector op if, the source op is
1681 // also extracting a contigous slice and create a new
1682 // shuffle vector op directly from the source of
1683 // the first shuffle.
1684 auto srcShuffle = cast<LLVM::ShuffleVectorOp>(srcOp);
1685 if (!isExtractingContiguousSlice(srcShuffle))
1686 return failure();
1687 auto srcMask = srcShuffle.getMask();
1688 SmallVector<int32_t> combinedMask;
1689 for (auto index : mask) {
1690 combinedMask.push_back(srcMask[index]);
1691 }
1692 auto newShuffle = LLVM::ShuffleVectorOp::create(
1693 rewriter, loc, ty, srcShuffle.getV1(), srcShuffle.getV1(),
1694 DenseI32ArrayAttr::get(rewriter.getContext(), combinedMask));
1695 rewriter.replaceOp(op, newShuffle);
1696 } else if (isa<LLVM::LoadOp>(srcOp)) {
1697 // 3. Merge with load as a smaller load
1698 auto loadOp = cast<LLVM::LoadOp>(srcOp);
1699 auto loadPtr = loadOp.getAddr();
1700 auto loadAddrSpace = loadPtr.getType().getAddressSpace();
1701 if (loadAddrSpace != 0)
1702 return failure();
1703 auto loadTy = dyn_cast<VectorType>(loadOp.getType());
1704 auto elemTy = loadTy.getElementType();
1705 auto firstIndex = mask[0];
1706 auto newVecTy = VectorType::get(mask.size(), elemTy);
1707 // GEPOp is needed if first index is not zero
1708 if (firstIndex) {
1709 auto newPtr = LLVM::GEPOp::create(
1710 rewriter, loc,
1711 LLVM::LLVMPointerType::get(rewriter.getContext(), loadAddrSpace),
1712 elemTy, loadPtr, ArrayRef<LLVM::GEPArg>{firstIndex});
1713 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, newPtr);
1714 rewriter.replaceOp(op, newLoad);
1715 } else {
1716 auto newLoad = LLVM::LoadOp::create(rewriter, loc, newVecTy, loadPtr);
1717 rewriter.replaceOp(op, newLoad);
1718 }
1719 } else {
1720 return failure();
1721 }
1722 } else {
1723 // No defining op (e.g. function argument): nothing to hoist/merge.
1724 return failure();
1725 }
1726 return success();
1727 }
1728};
1729
1730//===----------------------------------------------------------------------===//
1731// Pass Definition
1732//===----------------------------------------------------------------------===//
1733
1734struct ConvertXeVMToLLVMPass
1735 : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
1736 using Base::Base;
1737
1738 void getDependentDialects(DialectRegistry &registry) const override {
1739 registry.insert<LLVM::LLVMDialect, XeVMDialect>();
1740 }
1741
1742 void runOnOperation() override {
1743 ConversionTarget target(getContext());
1744 RewritePatternSet patterns(&getContext());
1746 if (failed(applyPartialConversion(getOperation(), target,
1747 std::move(patterns))))
1748 signalPassFailure();
1749
1750 // Apply in-dialect lowerings to handle illegal vectors
1751 {
1752 RewritePatternSet vectorPatterns(&getContext());
1753 vectorPatterns.add<HandleVectorExtractPattern>(&getContext());
1754 GreedyRewriteConfig config{};
1755 // folding can remove ops with temporary attributes used to
1756 // represent LLVM metadata, so disable it here.
1757 // Effectively just this single pattern is applied without any
1758 // op folding patterns from dialects.
1759 config.enableFolding(false);
1760 // config.setMaxIterations(GreedyRewriteConfig::kNoLimit);
1761 // config.setMaxNumRewrites(GreedyRewriteConfig::kNoLimit);
1762 (void)applyPatternsGreedily(getOperation(), std::move(vectorPatterns),
1763 config);
1764 }
1765 }
1766};
1767} // namespace
1768
1769//===----------------------------------------------------------------------===//
1770// Pattern Population
1771//===----------------------------------------------------------------------===//
1772
1773void ::mlir::populateXeVMToLLVMConversionPatterns(ConversionTarget &target,
1774 RewritePatternSet &patterns) {
1775 // some LLVM operations need to be converted.
1776 target.addDynamicallyLegalDialect<LLVM::LLVMDialect>([](Operation *op) {
1777 // llvm alloca op with addrspace 3 for OpenCL (Workgroup) is not handled
1778 // properly by SPIRV backend. It needs to be rewritten as a sequence with
1779 // llvm global.
1780 if (isa<LLVM::AllocaOp>(op)) {
1781 LLVM::AllocaOp aOp = cast<LLVM::AllocaOp>(op);
1782 LLVM::LLVMPointerType pTy = cast<LLVM::LLVMPointerType>(aOp.getType());
1783 auto addrSpace = pTy.getAddressSpace();
1784 return addrSpace != 3;
1785 }
1786 // cache_control attribute should be converted.
1787 return !op->hasAttr("cache_control");
1788 });
1789 target.addIllegalDialect<XeVMDialect>();
1790 patterns
1791 .add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
1792 LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
1793 LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, MMAToOCLPattern,
1794 MemfenceToOCLPattern, PrefetchToOCLPattern,
1795 LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
1796 LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
1797 BlockLoadStore1DToOCLPattern<BlockLoadOp>,
1798 BlockLoadStore1DToOCLPattern<BlockStoreOp>,
1799 LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
1800 LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
1801 LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
1802 LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
1803 LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
1804 LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
1805 LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
1806 LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
1807 LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
1808 LaunchConfigOpToOCLPattern<GridDimXOp>,
1809 LaunchConfigOpToOCLPattern<GridDimYOp>,
1810 LaunchConfigOpToOCLPattern<GridDimZOp>,
1811 SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
1812 SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
1813 SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>, TruncfToOCLPattern,
1814 ExtfToOCLPattern, MMAMxToOCLPattern, AllocaToGlobalPattern>(
1815 patterns.getContext());
1816}
return success()
LogicalResult initialize(unsigned origNumLoops, ArrayRef< ReassociationIndices > foldedIterationDims)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
Attributes are known-constant values of operations.
Definition Attributes.h:25
MLIRContext * getContext() const
Definition Builders.h:56
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
A trait used to provide symbol table functionalities to a region operation.
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition Operation.h:273
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Block & front()
Definition Region.h:65
MLIRContext * getContext() const
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::constant_int_value_binder m_ConstantInt(IntegerAttr::ValueType *bind_value)
Matches a constant holding a scalar/vector/tensor integer (splat) and writes the integer value to bin...
Definition Matchers.h:527
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...