MLIR  22.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1 //===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Types.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace xevm;
32 
33 namespace {
34 
35 struct LLVMFuncAttributeOptions {
36  bool isConvergent = false;
37  bool isNoUnwind = false;
38  bool isWillReturn = false;
39  LLVM::MemoryEffectsAttr memEffectsAttr{};
40 };
41 static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42  false, true, false, {}};
43 static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44  false, true, true, {}};
45 static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46  true, true, true, {}};
47 
48 std::string getTypeMangling(Type ty, bool isUnsigned = false) {
50  .Case([isUnsigned](VectorType ty) -> std::string {
51  return "Dv" + std::to_string(ty.getNumElements()) + "_" +
52  getTypeMangling(ty.getElementType(), isUnsigned);
53  })
54  .Case([](Float16Type) -> std::string { return "Dh"; })
55  .Case([](Float32Type) -> std::string { return "f"; })
56  .Case([](Float64Type) -> std::string { return "d"; })
57  .Case([isUnsigned](IntegerType ty) -> std::string {
58  switch (ty.getWidth()) {
59  case 8:
60  return isUnsigned ? "h" : "c";
61  case 16:
62  return isUnsigned ? "t" : "s";
63  case 32:
64  return isUnsigned ? "j" : "i";
65  case 64:
66  return isUnsigned ? "m" : "l";
67  default:
68  llvm_unreachable("unhandled integer type");
69  }
70  })
71  .DefaultUnreachable("unhandled type for mangling");
72 }
73 
74 std::string mangle(StringRef baseName, ArrayRef<Type> types,
75  ArrayRef<bool> isUnsigned = {}) {
76  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
77  "Signedness info doesn't match");
78  std::string s;
79  llvm::raw_string_ostream os(s);
80  llvm::SmallDenseMap<Type, unsigned> substitutions;
81  os << "_Z" << baseName.size() << baseName;
82  for (auto [idx, type] : llvm::enumerate(types)) {
83  auto it = substitutions.find(type);
84  if (it != substitutions.end()) {
85  os << "S";
86  // First substitution is `S_`, second is `S0_`, and so on.
87  if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
88  os << firstIdx - 1;
89  os << "_";
90  } else {
91  if (!type.isIntOrFloat())
92  substitutions[type] = substitutions.size();
93  os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
94  }
95  }
96  return os.str();
97 }
98 
99 static int32_t getL1CacheControl(LoadCacheControl cc) {
100  int32_t control = 0;
101  switch (cc) {
102  case LoadCacheControl::L1UC_L2UC_L3UC:
103  case LoadCacheControl::L1UC_L2UC_L3C:
104  case LoadCacheControl::L1UC_L2C_L3UC:
105  case LoadCacheControl::L1UC_L2C_L3C:
106  control = 1;
107  break;
108  case LoadCacheControl::L1C_L2UC_L3UC:
109  case LoadCacheControl::L1C_L2UC_L3C:
110  case LoadCacheControl::L1C_L2C_L3UC:
111  case LoadCacheControl::L1C_L2C_L3C:
112  control = 2;
113  break;
114  case LoadCacheControl::L1S_L2UC_L3UC:
115  case LoadCacheControl::L1S_L2UC_L3C:
116  case LoadCacheControl::L1S_L2C_L3UC:
117  case LoadCacheControl::L1S_L2C_L3C:
118  control = 3;
119  break;
120  case LoadCacheControl::INVALIDATE_READ:
121  control = 4;
122  break;
123  }
124  return control;
125 }
126 
127 static int32_t getL1CacheControl(StoreCacheControl cc) {
128  int32_t control = 0;
129  switch (cc) {
130  case StoreCacheControl::L1UC_L2UC_L3UC:
131  case StoreCacheControl::L1UC_L2UC_L3WB:
132  case StoreCacheControl::L1UC_L2WB_L3UC:
133  case StoreCacheControl::L1UC_L2WB_L3WB:
134  control = 1;
135  break;
136  case StoreCacheControl::L1WT_L2UC_L3UC:
137  case StoreCacheControl::L1WT_L2UC_L3WB:
138  case StoreCacheControl::L1WT_L2WB_L3UC:
139  case StoreCacheControl::L1WT_L2WB_L3WB:
140  control = 2;
141  break;
142  case StoreCacheControl::L1S_L2UC_L3UC:
143  case StoreCacheControl::L1S_L2UC_L3WB:
144  case StoreCacheControl::L1S_L2WB_L3UC:
145  case StoreCacheControl::L1S_L2WB_L3WB:
146  control = 3;
147  break;
148  case StoreCacheControl::L1WB_L2UC_L3UC:
149  case StoreCacheControl::L1WB_L2WB_L3UC:
150  case StoreCacheControl::L1WB_L2UC_L3WB:
151  control = 4;
152  break;
153  }
154  return control;
155 }
156 
157 static int32_t getL3CacheControl(LoadCacheControl cc) {
158  int32_t control = 0;
159  switch (cc) {
160  case LoadCacheControl::L1UC_L2UC_L3UC:
161  case LoadCacheControl::L1UC_L2C_L3UC:
162  case LoadCacheControl::L1C_L2UC_L3UC:
163  case LoadCacheControl::L1C_L2C_L3UC:
164  case LoadCacheControl::L1S_L2UC_L3UC:
165  case LoadCacheControl::L1S_L2C_L3UC:
166  control = 1;
167  break;
168  case LoadCacheControl::L1UC_L2UC_L3C:
169  case LoadCacheControl::L1UC_L2C_L3C:
170  case LoadCacheControl::L1C_L2UC_L3C:
171  case LoadCacheControl::L1C_L2C_L3C:
172  case LoadCacheControl::L1S_L2UC_L3C:
173  case LoadCacheControl::L1S_L2C_L3C:
174  control = 2;
175  break;
176  case LoadCacheControl::INVALIDATE_READ:
177  control = 4;
178  break;
179  }
180  return control;
181 }
182 
183 static int32_t getL3CacheControl(StoreCacheControl cc) {
184  int32_t control = 0;
185  switch (cc) {
186  case StoreCacheControl::L1UC_L2UC_L3UC:
187  case StoreCacheControl::L1UC_L2WB_L3UC:
188  case StoreCacheControl::L1WT_L2UC_L3UC:
189  case StoreCacheControl::L1WT_L2WB_L3UC:
190  case StoreCacheControl::L1S_L2UC_L3UC:
191  case StoreCacheControl::L1S_L2WB_L3UC:
192  case StoreCacheControl::L1WB_L2UC_L3UC:
193  case StoreCacheControl::L1WB_L2WB_L3UC:
194  control = 1;
195  break;
196  case StoreCacheControl::L1UC_L2UC_L3WB:
197  case StoreCacheControl::L1UC_L2WB_L3WB:
198  case StoreCacheControl::L1WT_L2UC_L3WB:
199  case StoreCacheControl::L1WT_L2WB_L3WB:
200  case StoreCacheControl::L1S_L2UC_L3WB:
201  case StoreCacheControl::L1S_L2WB_L3WB:
202  case StoreCacheControl::L1WB_L2UC_L3WB:
203  control = 2;
204  break;
205  }
206  return control;
207 }
208 
209 static std::optional<LoadCacheControl> getCacheControl(PrefetchOp op) {
210  return op.getCacheControl();
211 }
212 
213 static std::optional<LoadCacheControl> getCacheControl(BlockLoad2dOp op) {
214  return op.getCacheControl();
215 }
216 
217 static std::optional<LoadCacheControl> getCacheControl(BlockLoadOp op) {
218  return op.getCacheControl();
219 }
220 
221 static std::optional<LoadCacheControl> getCacheControl(BlockPrefetch2dOp op) {
222  return op.getCacheControl();
223 }
224 
225 static std::optional<StoreCacheControl> getCacheControl(BlockStore2dOp op) {
226  return op.getCacheControl();
227 }
228 
229 static std::optional<StoreCacheControl> getCacheControl(BlockStoreOp op) {
230  return op.getCacheControl();
231 }
232 
233 static std::optional<LoadCacheControl> getCacheControl(LLVM::LoadOp op) {
234  if (op->hasAttr("cache_control")) {
235  auto attr = op->getAttrOfType<xevm::LoadCacheControlAttr>("cache_control");
236  if (!attr)
237  return std::nullopt;
238  return std::optional<LoadCacheControl>(attr.getValue());
239  }
240  return std::nullopt;
241 }
242 
243 static std::optional<StoreCacheControl> getCacheControl(LLVM::StoreOp op) {
244  if (op->hasAttr("cache_control")) {
245  auto attr = op->getAttrOfType<xevm::StoreCacheControlAttr>("cache_control");
246  if (!attr)
247  return std::nullopt;
248  return std::optional<StoreCacheControl>(attr.getValue());
249  }
250  return std::nullopt;
251 }
252 
/// Returns the L1 cache-control code for `op`.
///
/// Precondition: `getCacheControl(op)` holds a value; callers must check this
/// before calling.
template <typename OpType>
int32_t getL1CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "op has no cache control to translate");
  return getL1CacheControl(*cacheControl);
}
257 
/// Returns the L3 cache-control code for `op`.
///
/// Precondition: `getCacheControl(op)` holds a value; callers must check this
/// before calling.
template <typename OpType>
int32_t getL3CacheControl(OpType op) {
  auto cacheControl = getCacheControl(op);
  assert(cacheControl && "op has no cache control to translate");
  return getL3CacheControl(*cacheControl);
}
262 
263 template <typename OpType>
264 static std::optional<ArrayAttr>
265 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
266  if (!getCacheControl(op))
267  return {};
268  constexpr int32_t decorationCacheControlArity{4};
269  constexpr int32_t loadCacheControlKey{6442};
270  constexpr int32_t storeCacheControlKey{6443};
271  constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp> ||
272  std::is_same_v<OpType, BlockPrefetch2dOp> ||
273  std::is_same_v<OpType, LLVM::LoadOp> ||
274  std::is_same_v<OpType, BlockLoadOp> ||
275  std::is_same_v<OpType, PrefetchOp>;
276  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
278  controlKey, 0, getL1CacheControl<OpType>(op), 0};
280  controlKey, 1, getL3CacheControl<OpType>(op), 0};
281  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
282  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
283 
284  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
285  return rewriter.getArrayAttr(combinedAttrs);
286 }
287 
288 static LLVM::CallOp createDeviceFunctionCall(
289  ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
290  ArrayRef<Type> argTypes, ArrayRef<Value> args,
291  mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
292  LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
293  auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
294  assert(moduleOp && "Expecting module");
295  Location loc = op->getLoc();
296 
297  auto funcOpRes =
298  LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
299  assert(!failed(funcOpRes));
300  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
301  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
302  funcOp.setConvergent(funcAttributeOptions.isConvergent);
303  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
304  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
305 
306  if (funcAttributeOptions.memEffectsAttr)
307  funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
308 
309  for (auto [idx, attrName] : paramAttrs)
310  funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
311 
312  auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
313  callOp->setAttrs(funcOp->getAttrs());
314 
315  return callOp;
316 }
317 
318 class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
320  LogicalResult
321  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
322  ConversionPatternRewriter &rewriter) const override {
323  if (!op.getC()) {
324  return rewriter.notifyMatchFailure(op, "OCL requires C operand");
325  }
326  auto precisionA = op.getTypes().getA();
327  auto precisionB = op.getTypes().getB();
328  auto precisionC = op.getTypes().getC();
329  auto precisionD = op.getTypes().getD();
330  if (precisionC != precisionD) {
331  return rewriter.notifyMatchFailure(op, "type of C and D need to match");
332  }
333  if (precisionC != xevm::ElemType::S32 &&
334  precisionC != xevm::ElemType::F32 &&
335  precisionC != xevm::ElemType::F16 &&
336  precisionC != xevm::ElemType::BF16) {
337  return rewriter.notifyMatchFailure(
338  op, "type of C and D must be S32, F32, F16 or BF16");
339  }
340  if (precisionA == xevm::ElemType::S32 ||
341  precisionA == xevm::ElemType::F32) {
342  return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
343  }
344  if (precisionB == xevm::ElemType::S32 ||
345  precisionB == xevm::ElemType::F32) {
346  return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
347  }
348  constexpr uint32_t bitWidthPackedA{16};
349  constexpr uint32_t bitWidthPackedB{32};
350  auto loc = op.getLoc();
351 
352  auto castIfNeeded = [&](Value val, Type packedType) -> Value {
353  VectorType origTy = cast<VectorType>(val.getType());
354  const uint32_t vecBitSize =
355  origTy.getNumElements() *
356  origTy.getElementType().getIntOrFloatBitWidth();
357  VectorType newTy = VectorType::get(
358  vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
359  if (origTy != newTy)
360  val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
361  return val;
362  };
363 
364  Value a = op.getA();
365  Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
366  ? cast<Type>(rewriter.getF32Type())
367  : rewriter.getIntegerType(bitWidthPackedA);
368  a = castIfNeeded(a, packedAType);
369 
370  Value b = op.getB();
371  Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
372  ? cast<Type>(rewriter.getF32Type())
373  : rewriter.getIntegerType(bitWidthPackedB);
374  b = castIfNeeded(b, packedBType);
375 
376  Value c = op.getC();
377  VectorType cOrigTy = cast<VectorType>(c.getType());
378  VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
379  assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
380  // OCL builtins encode bfloat16 as int16
381  VectorType cTy =
382  cOrigTy.getElementType().isBF16()
383  ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
384  : cOrigTy;
385  VectorType resTy = cTy;
386  if (cOrigTy != cTy)
387  c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
388 
389  constexpr int32_t systolicDepth{8};
390  std::string fnName =
391  llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
392  stringifyElemType(op.getTypes().getA()).str(),
393  stringifyElemType(op.getTypes().getB()).str(),
394  systolicDepth *
395  getNumOperandsPerDword(op.getTypes().getA()))
396  .str();
397  SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
398  fnName = mangle(fnName, argTypes);
399  SmallVector<Value> args{a, b, c};
400 
401  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
402  /*other=*/LLVM::ModRefInfo::NoModRef,
403  /*argMem=*/LLVM::ModRefInfo::NoModRef,
404  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
405  auto funcAttrs = convergentNoUnwindWillReturnAttrs;
406  funcAttrs.memEffectsAttr = memAttr;
407  Value result =
408  createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
409  funcAttrs, op.getOperation())
410  ->getResult(0);
411 
412  if (resOrigTy != resTy)
413  result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
414 
415  rewriter.replaceOp(op, result);
416  return success();
417  }
418 
419 private:
420  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
421  switch (pTy) {
422  case xevm::ElemType::TF32:
423  return 1;
424  case xevm::ElemType::BF16:
425  case xevm::ElemType::F16:
426  return 2;
427  case xevm::ElemType::U8:
428  case xevm::ElemType::S8:
429  return 4;
430  default:
431  llvm_unreachable("unsupported xevm::ElemType");
432  }
433  }
434 };
435 
436 class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
438  LogicalResult
439  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
440  ConversionPatternRewriter &rewriter) const override {
441  auto loc = op.getLoc();
442  const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
443  Value one =
444  LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
445  SmallVector<Value> args{op.getPtr(), one};
446  SmallVector<Type> argTypes;
447  for (auto arg : args)
448  argTypes.push_back(arg.getType());
449  auto funcAttr = noUnwindAttrs;
450  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
451  /*other=*/LLVM::ModRefInfo::NoModRef,
452  /*argMem=*/LLVM::ModRefInfo::Ref,
453  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
454  funcAttr.memEffectsAttr = memAttr;
455 
456  LLVM::CallOp call = createDeviceFunctionCall(
457  rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
458  argTypes, args, {}, funcAttr, op.getOperation());
459  if (std::optional<ArrayAttr> optCacheControls =
460  getCacheControlMetadata(rewriter, op))
461  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
462  rewriter.eraseOp(op);
463  return success();
464  }
465 };
466 
467 class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
469  LogicalResult
470  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
471  ConversionPatternRewriter &rewriter) const override {
472  auto loc = op.getLoc();
473  const std::string fnName{"atomic_work_item_fence"};
474  int memScope, addrSpace;
475  switch (op.getAddrspace()) {
476  case xevm::AddrSpace::SHARED:
477  addrSpace = 1; // CLK_LOCAL_MEM_FENCE
478  break;
479  case xevm::AddrSpace::GLOBAL:
480  addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
481  break;
482  default:
483  // GENERIC is not supported in OpenCL
484  return rewriter.notifyMatchFailure(
485  op, "Fence only supports global and shared address spaces.");
486  }
487  switch (op.getScope()) {
488  case xevm::MemScope::WORKGROUP:
489  memScope = 1;
490  break;
491  case xevm::MemScope::DEVICE:
492  memScope = 2;
493  break;
494  default:
495  // CLUSTER and SYSTEM are not supported in OpenCL
496  return rewriter.notifyMatchFailure(
497  op, "Fence only supports workgroup and device memory scopes.");
498  }
499  Type i32Type = rewriter.getI32Type();
500  Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
501  Value memScopeConst =
502  LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
503  Value addrSpaceConst =
504  LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
505  SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
506  SmallVector<Type> argTypes{3, i32Type};
507  createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
509  argTypes, args, {}, noUnwindAttrs,
510  op.getOperation());
511  rewriter.eraseOp(op);
512  return success();
513  }
514 };
515 template <typename OpType>
516 class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
518  LogicalResult
519  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
520  ConversionPatternRewriter &rewriter) const override {
521  constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
522  constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
523 
524  auto loc = op.getLoc();
525  VectorType vecType;
526  bool packReg = false;
527  bool transpose = false;
528  if constexpr (isLoad) {
529  vecType = op.getRes().getType();
530  packReg = op.getPackRegister();
531  transpose = op.getTranspose();
532  } else if constexpr (!isPrefetch) {
533  vecType = op.getStoredVal().getType();
534  }
535 
536  auto i32Type = rewriter.getI32Type();
537  Value byteCoord =
538  LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
539  Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
540  Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
541  byteCoord = LLVM::InsertElementOp::create(
542  rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
543  byteCoord = LLVM::InsertElementOp::create(
544  rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
545  SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
546  op.getBasePitch(), byteCoord};
547  SmallVector<Type> retTypes;
548  Value spvLoadDstPtr;
549  std::string funcName{"intel_sub_group_2d_block_"};
550  std::string bitWidthId;
551  LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
553  if constexpr (isPrefetch) { // Prefetch
554  funcName += "prefetch";
555  paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
556  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
557  /*other=*/LLVM::ModRefInfo::NoModRef,
558  /*argMem=*/LLVM::ModRefInfo::Ref,
559  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
560  funcAttr = noUnwindAttrs;
561  funcAttr.memEffectsAttr = memAttr;
562  } else {
563  auto vecElemType = vecType.getElementType();
564  auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
565  Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
566  vecType.getNumElements());
567  auto dstOrSrcPtr = LLVM::AllocaOp::create(
568  rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
569  vecElemType, numElems);
570  args.push_back(dstOrSrcPtr);
571  if constexpr (isLoad) { // Load
572  funcName += "read";
573  bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
574  if (packReg)
575  funcName += "_transform";
576  else if (transpose)
577  funcName += "_transpose";
578  spvLoadDstPtr = dstOrSrcPtr;
579  retTypes.push_back(vecType);
580  paramAttrs = {
581  std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
582  std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
583  std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
584  std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
585  };
586  } else { // Store
587  funcName += "write";
588  bitWidthId = (vecElemBitWidth == 32)
589  ? "j"
590  : ((vecElemBitWidth == 16) ? "t" : "h");
591  LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
592  paramAttrs = {
593  std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
594  std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
595  std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
596  std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
597  };
598  }
599  }
600 
601  funcName =
602  llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
603  op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
604  .str();
605  std::string prefetchCode("");
606  if (!isPrefetch)
607  prefetchCode += "P";
608  funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
609  funcName, prefetchCode, bitWidthId)
610  .str();
611  SmallVector<Type> argTypes;
612  for (auto arg : args) {
613  argTypes.push_back(arg.getType());
614  }
615  LLVM::CallOp call = createDeviceFunctionCall(
616  rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
617  argTypes, args, paramAttrs, funcAttr, op.getOperation());
618  if (std::optional<ArrayAttr> optCacheControls =
619  getCacheControlMetadata(rewriter, op)) {
620  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
621  }
622  if constexpr (isLoad)
623  rewriter.replaceOp(
624  op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
625  else
626  rewriter.eraseOp(op);
627  return success();
628  }
629 };
630 
631 template <typename OpType>
632 class BlockLoadStore1DToOCLPattern : public OpConversionPattern<OpType> {
634  LogicalResult
635  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
636  ConversionPatternRewriter &rewriter) const override {
637  constexpr bool isStore = std::is_same_v<OpType, xevm::BlockStoreOp>;
638  // Get OpenCL function name
639  // https://registry.khronos.org/OpenCL/extensions/
640  // intel/cl_intel_subgroup_local_block_io.html
641  std::string funcName{"intel_sub_group_block_"};
642  // Value or Result type can be vector or scalar
643  Type valOrResTy;
644  if constexpr (isStore) {
645  funcName += "write_u";
646  valOrResTy = op.getVal().getType();
647  } else {
648  funcName += "read_u";
649  valOrResTy = op.getType();
650  }
651  // Get element type of the vector/scalar
652  VectorType vecTy = dyn_cast<VectorType>(valOrResTy);
653  Type elemType = vecTy ? vecTy.getElementType() : valOrResTy;
654  funcName += getTypeMangling(elemType);
655  if (vecTy)
656  funcName += std::to_string(vecTy.getNumElements());
657  SmallVector<Type, 2> argTypes{};
658  // XeVM BlockLoad/StoreOp always use signless integer types
659  // but OpenCL builtins expect unsigned types
660  // use unsigned types for mangling
661  SmallVector<bool, 2> isUnsigned{};
662  // arg0: pointer to the src/dst address
663  // arg1 - only if store : vector to store
664  // Prepare arguments
665  SmallVector<Value, 2> args{};
666  args.push_back(op.getPtr());
667  argTypes.push_back(op.getPtr().getType());
668  isUnsigned.push_back(true);
669  Type retType;
670  if constexpr (isStore) {
671  args.push_back(op.getVal());
672  argTypes.push_back(op.getVal().getType());
673  isUnsigned.push_back(true);
674  retType = LLVM::LLVMVoidType::get(rewriter.getContext());
675  } else {
676  retType = valOrResTy;
677  }
678  funcName = std::string("_Z") + std::to_string(funcName.size()) + funcName +
679  "PU3AS" +
680  std::to_string(op.getPtr().getType().getAddressSpace());
681  funcName += getTypeMangling(elemType, /*isUnsigned=*/true);
682  if constexpr (isStore)
683  funcName += getTypeMangling(valOrResTy, /*isUnsigned=*/true);
684  LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
685 
686  LLVM::CallOp call =
687  createDeviceFunctionCall(rewriter, funcName, retType, argTypes, args,
688  {}, funcAttr, op.getOperation());
689  if (std::optional<ArrayAttr> optCacheControls =
690  getCacheControlMetadata(rewriter, op)) {
691  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
692  }
693  if constexpr (isStore)
694  rewriter.eraseOp(op);
695  else
696  rewriter.replaceOp(op, call->getResult(0));
697  return success();
698  }
699 };
700 
701 template <typename OpType>
702 class LLVMLoadStoreToOCLPattern : public OpConversionPattern<OpType> {
704  LogicalResult
705  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
706  ConversionPatternRewriter &rewriter) const override {
707  if (!op->hasAttr("cache_control"))
708  return failure();
709  std::optional<ArrayAttr> optCacheControls =
710  getCacheControlMetadata(rewriter, op);
711  op->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
712  op->removeAttr("cache_control");
713  return success();
714  }
715 };
716 
717 //===----------------------------------------------------------------------===//
718 // GPU index id operations
719 //===----------------------------------------------------------------------===//
720 /*
721 // Launch Config ops
722 // dimidx - x, y, z - is fixed to i32
723 // return type is set by XeVM type converter
724 // get_local_id
725 xevm::WorkitemIdXOp;
726 xevm::WorkitemIdYOp;
727 xevm::WorkitemIdZOp;
728 // get_local_size
729 xevm::WorkgroupDimXOp;
730 xevm::WorkgroupDimYOp;
731 xevm::WorkgroupDimZOp;
732 // get_group_id
733 xevm::WorkgroupIdXOp;
734 xevm::WorkgroupIdYOp;
735 xevm::WorkgroupIdZOp;
736 // get_num_groups
737 xevm::GridDimXOp;
738 xevm::GridDimYOp;
739 xevm::GridDimZOp;
740 // get_global_id : to be added if needed
741 */
742 
743 // Helpers to get the OpenCL function name and dimension argument for each op.
744 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdXOp) {
745  return {"get_local_id", 0};
746 }
747 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdYOp) {
748  return {"get_local_id", 1};
749 }
750 static std::pair<StringRef, int64_t> getConfig(xevm::WorkitemIdZOp) {
751  return {"get_local_id", 2};
752 }
753 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimXOp) {
754  return {"get_local_size", 0};
755 }
756 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimYOp) {
757  return {"get_local_size", 1};
758 }
759 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupDimZOp) {
760  return {"get_local_size", 2};
761 }
762 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdXOp) {
763  return {"get_group_id", 0};
764 }
765 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdYOp) {
766  return {"get_group_id", 1};
767 }
768 static std::pair<StringRef, int64_t> getConfig(xevm::WorkgroupIdZOp) {
769  return {"get_group_id", 2};
770 }
771 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimXOp) {
772  return {"get_num_groups", 0};
773 }
774 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimYOp) {
775  return {"get_num_groups", 1};
776 }
777 static std::pair<StringRef, int64_t> getConfig(xevm::GridDimZOp) {
778  return {"get_num_groups", 2};
779 }
780 /// Replace `xevm.*` with an `llvm.call` to the corresponding OpenCL func with
781 /// a constant argument for the dimension - x, y or z.
782 template <typename OpType>
783 class LaunchConfigOpToOCLPattern : public OpConversionPattern<OpType> {
785  LogicalResult
786  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
787  ConversionPatternRewriter &rewriter) const override {
788  Location loc = op->getLoc();
789  auto [baseName, dim] = getConfig(op);
790  Type dimTy = rewriter.getI32Type();
791  Value dimVal = LLVM::ConstantOp::create(rewriter, loc, dimTy,
792  static_cast<int64_t>(dim));
793  std::string func = mangle(baseName, {dimTy}, {true});
794  Type resTy = op.getType();
795  auto call =
796  createDeviceFunctionCall(rewriter, func, resTy, {dimTy}, {dimVal}, {},
797  noUnwindWillReturnAttrs, op.getOperation());
798  constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
799  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
800  /*other=*/noModRef,
801  /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
802  call.setMemoryEffectsAttr(memAttr);
803  rewriter.replaceOp(op, call);
804  return success();
805  }
806 };
807 
808 /*
809 // Subgroup ops
810 // get_sub_group_local_id
811 xevm::LaneIdOp;
812 // get_sub_group_id
813 xevm::SubgroupIdOp;
814 // get_sub_group_size
815 xevm::SubgroupSizeOp;
816 // get_num_sub_groups : to be added if needed
817 */
818 
819 // Helpers to get the OpenCL function name for each op.
820 static StringRef getConfig(xevm::LaneIdOp) { return "get_sub_group_local_id"; }
821 static StringRef getConfig(xevm::SubgroupIdOp) { return "get_sub_group_id"; }
822 static StringRef getConfig(xevm::SubgroupSizeOp) {
823  return "get_sub_group_size";
824 }
825 template <typename OpType>
826 class SubgroupOpWorkitemOpToOCLPattern : public OpConversionPattern<OpType> {
828  LogicalResult
829  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
830  ConversionPatternRewriter &rewriter) const override {
831  std::string func = mangle(getConfig(op).str(), {});
832  Type resTy = op.getType();
833  auto call =
834  createDeviceFunctionCall(rewriter, func, resTy, {}, {}, {},
835  noUnwindWillReturnAttrs, op.getOperation());
836  constexpr auto noModRef = LLVM::ModRefInfo::NoModRef;
837  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
838  /*other=*/noModRef,
839  /*argMem=*/noModRef, /*inaccessibleMem=*/noModRef);
840  call.setMemoryEffectsAttr(memAttr);
841  rewriter.replaceOp(op, call);
842  return success();
843  }
844 };
845 
846 //===----------------------------------------------------------------------===//
847 // Pass Definition
848 //===----------------------------------------------------------------------===//
849 
// Pass that applies a partial dialect conversion to the operation it runs on
// and signals pass failure when the conversion fails.
850 struct ConvertXeVMToLLVMPass
851  : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
852  using Base::Base;
853 
// Declares the LLVM and XeVM dialects as dependencies of this pass.
854  void getDependentDialects(DialectRegistry &registry) const override {
855  registry.insert<LLVM::LLVMDialect, XeVMDialect>();
856  }
857 
858  void runOnOperation() override {
859  ConversionTarget target(getContext());
// NOTE(review): `patterns` is used below but never declared in this listing;
// the lines that declare and populate it appear to be missing here - restore
// them from the original file.
862  if (failed(applyPartialConversion(getOperation(), target,
863  std::move(patterns))))
864  signalPassFailure();
865  }
866 };
867 } // namespace
868 
869 //===----------------------------------------------------------------------===//
870 // ConvertToLLVMPatternInterface implementation
871 //===----------------------------------------------------------------------===//
872 
873 namespace {
874 /// Implement the interface to convert XeVM to LLVM.
// NOTE(review): a line appears to be missing right after the struct header
// in this listing (presumably the inherited-constructor declaration) -
// confirm against the original file.
875 struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
// Loads the LLVM dialect into the given context.
877  void loadDependentDialects(MLIRContext *context) const final {
878  context->loadDialect<LLVM::LLVMDialect>();
879  }
880 
881  /// Hook for derived dialect interface to provide conversion patterns
882  /// and mark dialect legal for the conversion target.
883  void populateConvertToLLVMConversionPatterns(
884  ConversionTarget &target, LLVMTypeConverter &typeConverter,
885  RewritePatternSet &patterns) const final {
// NOTE(review): the body is empty in this listing; the statement that fills
// `target` and `patterns` appears to be missing - restore it from the
// original file.
887  }
888 };
889 } // namespace
890 
891 //===----------------------------------------------------------------------===//
892 // Pattern Population
893 //===----------------------------------------------------------------------===//
894 
// NOTE(review): the signature line(s) of this function are missing from this
// listing; the body uses a conversion target named `target` and a pattern set
// named `patterns` - restore the header from the original file.
//
// LLVM dialect ops are legal unless they still carry a `cache_control`
// attribute; every XeVM dialect op is illegal.
897  target.addDynamicallyLegalDialect<LLVM::LLVMDialect>(
898  [](Operation *op) { return !op->hasAttr("cache_control"); });
899  target.addIllegalDialect<XeVMDialect>();
// Register every lowering pattern defined in this file.
900  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
901  LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
902  LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
903  MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern,
904  LLVMLoadStoreToOCLPattern<LLVM::LoadOp>,
905  LLVMLoadStoreToOCLPattern<LLVM::StoreOp>,
906  BlockLoadStore1DToOCLPattern<BlockLoadOp>,
907  BlockLoadStore1DToOCLPattern<BlockStoreOp>,
908  LaunchConfigOpToOCLPattern<WorkitemIdXOp>,
909  LaunchConfigOpToOCLPattern<WorkitemIdYOp>,
910  LaunchConfigOpToOCLPattern<WorkitemIdZOp>,
911  LaunchConfigOpToOCLPattern<WorkgroupDimXOp>,
912  LaunchConfigOpToOCLPattern<WorkgroupDimYOp>,
913  LaunchConfigOpToOCLPattern<WorkgroupDimZOp>,
914  LaunchConfigOpToOCLPattern<WorkgroupIdXOp>,
915  LaunchConfigOpToOCLPattern<WorkgroupIdYOp>,
916  LaunchConfigOpToOCLPattern<WorkgroupIdZOp>,
917  LaunchConfigOpToOCLPattern<GridDimXOp>,
918  LaunchConfigOpToOCLPattern<GridDimYOp>,
919  LaunchConfigOpToOCLPattern<GridDimZOp>,
920  SubgroupOpWorkitemOpToOCLPattern<LaneIdOp>,
921  SubgroupOpWorkitemOpToOCLPattern<SubgroupIdOp>,
922  SubgroupOpWorkitemOpToOCLPattern<SubgroupSizeOp>>(
923  patterns.getContext());
924 }
925 
// NOTE(review): the signature line of this function is missing from this
// listing; the body uses a dialect registry named `registry` - restore the
// header from the original file.
//
// Adds a registry extension that attaches XeVMToLLVMDialectInterface to the
// XeVM dialect.
927  registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
928  dialect->addInterfaces<XeVMToLLVMDialectInterface>();
929  });
930 }
static MLIRContext * getContext(OpFoldResult val)
UnitAttr getUnitAttr()
Definition: Builders.cpp:98
FloatType getF32Type()
Definition: Builders.cpp:43
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:276
IntegerType getI64Type()
Definition: Builders.cpp:65
IntegerType getI32Type()
Definition: Builders.cpp:63
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:67
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:266
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
void addDynamicallyLegalDialect(const DynamicLegalityCallbackFn &callback, StringRef name, Names... names)
Register the operations of the given dialects as dynamically legal, i.e. requiring custom handling by the callback.
void addIllegalDialect(StringRef name, Names... names)
Register the operations of the given dialects as illegal, i.e. operations known to not be supported by this target.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:561
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(ConversionTarget &target, RewritePatternSet &patterns)
Definition: XeVMToLLVM.cpp:895
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry)
Definition: XeVMToLLVM.cpp:926