MLIR  22.0.0git
XeVMToLLVM.cpp
Go to the documentation of this file.
1 //===-- XeVMToLLVM.cpp - XeVM to LLVM dialect conversion --------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
16 #include "mlir/Pass/Pass.h"
17 #include "mlir/Support/LLVM.h"
18 #include "llvm/Support/FormatVariadic.h"
19 
20 #include "mlir/IR/BuiltinTypes.h"
21 #include "mlir/IR/Types.h"
22 
23 #include "llvm/ADT/TypeSwitch.h"
24 
25 namespace mlir {
26 #define GEN_PASS_DEF_CONVERTXEVMTOLLVMPASS
27 #include "mlir/Conversion/Passes.h.inc"
28 } // namespace mlir
29 
30 using namespace mlir;
31 using namespace xevm;
32 
33 namespace {
34 
35 struct LLVMFuncAttributeOptions {
36  bool isConvergent = false;
37  bool isNoUnwind = false;
38  bool isWillReturn = false;
39  LLVM::MemoryEffectsAttr memEffectsAttr{};
40 };
41 static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
42  false, true, false, {}};
43 static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
44  false, true, true, {}};
45 static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
46  true, true, true, {}};
47 
48 std::string getTypeMangling(Type ty, bool isUnsigned = false) {
50  .Case([isUnsigned](VectorType ty) -> std::string {
51  return "Dv" + std::to_string(ty.getNumElements()) + "_" +
52  getTypeMangling(ty.getElementType(), isUnsigned);
53  })
54  .Case([](Float16Type) -> std::string { return "Dh"; })
55  .Case([](Float32Type) -> std::string { return "f"; })
56  .Case([](Float64Type) -> std::string { return "d"; })
57  .Case([isUnsigned](IntegerType ty) -> std::string {
58  switch (ty.getWidth()) {
59  case 8:
60  return isUnsigned ? "h" : "c";
61  case 16:
62  return isUnsigned ? "t" : "s";
63  case 32:
64  return isUnsigned ? "j" : "i";
65  case 64:
66  return isUnsigned ? "m" : "l";
67  default:
68  llvm_unreachable("unhandled integer type");
69  }
70  })
71  .Default([](Type) -> std::string {
72  llvm_unreachable("unhandled type for mangling");
73  });
74 }
75 
76 std::string mangle(StringRef baseName, ArrayRef<Type> types,
77  ArrayRef<bool> isUnsigned = {}) {
78  assert((isUnsigned.empty() || isUnsigned.size() == types.size()) &&
79  "Signedness info doesn't match");
80  std::string s;
81  llvm::raw_string_ostream os(s);
82  llvm::SmallDenseMap<Type, unsigned> substitutions;
83  os << "_Z" << baseName.size() << baseName;
84  for (auto [idx, type] : llvm::enumerate(types)) {
85  auto it = substitutions.find(type);
86  if (it != substitutions.end()) {
87  os << "S";
88  // First substitution is `S_`, second is `S0_`, and so on.
89  if (unsigned firstIdx = it->getSecond(); firstIdx > 0)
90  os << firstIdx - 1;
91  os << "_";
92  } else {
93  if (!type.isIntOrFloat())
94  substitutions[type] = substitutions.size();
95  os << getTypeMangling(type, isUnsigned.empty() ? false : isUnsigned[idx]);
96  }
97  }
98  return os.str();
99 }
100 
101 template <bool isLoad, typename OpType>
102 int32_t getL1CacheControl(OpType op) {
103  int32_t control = 0;
104  if constexpr (isLoad) {
105  switch (*op.getCacheControl()) {
106  case LoadCacheControl::L1UC_L2UC_L3UC:
107  case LoadCacheControl::L1UC_L2UC_L3C:
108  case LoadCacheControl::L1UC_L2C_L3UC:
109  case LoadCacheControl::L1UC_L2C_L3C:
110  control = 1;
111  break;
112  case LoadCacheControl::L1C_L2UC_L3UC:
113  case LoadCacheControl::L1C_L2UC_L3C:
114  case LoadCacheControl::L1C_L2C_L3UC:
115  case LoadCacheControl::L1C_L2C_L3C:
116  control = 2;
117  break;
118  case LoadCacheControl::L1S_L2UC_L3UC:
119  case LoadCacheControl::L1S_L2UC_L3C:
120  case LoadCacheControl::L1S_L2C_L3UC:
121  case LoadCacheControl::L1S_L2C_L3C:
122  control = 3;
123  break;
124  case LoadCacheControl::INVALIDATE_READ:
125  control = 4;
126  break;
127  }
128  } else {
129  switch (*op.getCacheControl()) {
130  case StoreCacheControl::L1UC_L2UC_L3UC:
131  case StoreCacheControl::L1UC_L2UC_L3WB:
132  case StoreCacheControl::L1UC_L2WB_L3UC:
133  case StoreCacheControl::L1UC_L2WB_L3WB:
134  control = 1;
135  break;
136  case StoreCacheControl::L1WT_L2UC_L3UC:
137  case StoreCacheControl::L1WT_L2UC_L3WB:
138  case StoreCacheControl::L1WT_L2WB_L3UC:
139  case StoreCacheControl::L1WT_L2WB_L3WB:
140  control = 2;
141  break;
142  case StoreCacheControl::L1S_L2UC_L3UC:
143  case StoreCacheControl::L1S_L2UC_L3WB:
144  case StoreCacheControl::L1S_L2WB_L3UC:
145  case StoreCacheControl::L1S_L2WB_L3WB:
146  control = 3;
147  break;
148  case StoreCacheControl::L1WB_L2UC_L3UC:
149  case StoreCacheControl::L1WB_L2WB_L3UC:
150  case StoreCacheControl::L1WB_L2UC_L3WB:
151  control = 4;
152  break;
153  }
154  }
155  return control;
156 }
157 
158 template <bool isLoad, typename OpType>
159 int32_t getL3CacheControl(OpType op) {
160  int32_t control = 0;
161  if constexpr (isLoad) {
162  switch (*op.getCacheControl()) {
163  case LoadCacheControl::L1UC_L2UC_L3UC:
164  case LoadCacheControl::L1UC_L2C_L3UC:
165  case LoadCacheControl::L1C_L2UC_L3UC:
166  case LoadCacheControl::L1C_L2C_L3UC:
167  case LoadCacheControl::L1S_L2UC_L3UC:
168  case LoadCacheControl::L1S_L2C_L3UC:
169  control = 1;
170  break;
171  case LoadCacheControl::L1UC_L2UC_L3C:
172  case LoadCacheControl::L1UC_L2C_L3C:
173  case LoadCacheControl::L1C_L2UC_L3C:
174  case LoadCacheControl::L1C_L2C_L3C:
175  case LoadCacheControl::L1S_L2UC_L3C:
176  case LoadCacheControl::L1S_L2C_L3C:
177  control = 2;
178  break;
179  case LoadCacheControl::INVALIDATE_READ:
180  control = 4;
181  break;
182  }
183  } else {
184  switch (*op.getCacheControl()) {
185  case StoreCacheControl::L1UC_L2UC_L3UC:
186  case StoreCacheControl::L1UC_L2WB_L3UC:
187  case StoreCacheControl::L1WT_L2UC_L3UC:
188  case StoreCacheControl::L1WT_L2WB_L3UC:
189  case StoreCacheControl::L1S_L2UC_L3UC:
190  case StoreCacheControl::L1S_L2WB_L3UC:
191  case StoreCacheControl::L1WB_L2UC_L3UC:
192  case StoreCacheControl::L1WB_L2WB_L3UC:
193  control = 1;
194  break;
195  case StoreCacheControl::L1UC_L2UC_L3WB:
196  case StoreCacheControl::L1UC_L2WB_L3WB:
197  case StoreCacheControl::L1WT_L2UC_L3WB:
198  case StoreCacheControl::L1WT_L2WB_L3WB:
199  case StoreCacheControl::L1S_L2UC_L3WB:
200  case StoreCacheControl::L1S_L2WB_L3WB:
201  case StoreCacheControl::L1WB_L2UC_L3WB:
202  control = 2;
203  break;
204  }
205  }
206  return control;
207 }
208 
209 template <bool isLoad, typename OpType>
210 static std::optional<ArrayAttr>
211 getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op) {
212  if (!op.getCacheControl())
213  return {};
214  constexpr int32_t decorationCacheControlArity{4};
215  constexpr int32_t loadCacheControlKey{6442};
216  constexpr int32_t storeCacheControlKey{6443};
217  const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
219  controlKey, 0, getL1CacheControl<isLoad, OpType>(op), 0};
221  controlKey, 1, getL3CacheControl<isLoad, OpType>(op), 0};
222  auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
223  auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
224 
225  SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
226  return rewriter.getArrayAttr(combinedAttrs);
227 }
228 
229 static LLVM::CallOp createDeviceFunctionCall(
230  ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
231  ArrayRef<Type> argTypes, ArrayRef<Value> args,
232  mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
233  LLVMFuncAttributeOptions funcAttributeOptions, Operation *op) {
234  auto moduleOp = op->getParentWithTrait<OpTrait::SymbolTable>();
235  assert(moduleOp && "Expecting module");
236  Location loc = op->getLoc();
237 
238  auto funcOpRes =
239  LLVM::lookupOrCreateFn(rewriter, moduleOp, funcName, argTypes, retType);
240  assert(!failed(funcOpRes));
241  LLVM::LLVMFuncOp funcOp = funcOpRes.value();
242  funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
243  funcOp.setConvergent(funcAttributeOptions.isConvergent);
244  funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
245  funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
246 
247  if (funcAttributeOptions.memEffectsAttr)
248  funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
249 
250  for (auto [idx, attrName] : paramAttrs)
251  funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
252 
253  auto callOp = LLVM::CallOp::create(rewriter, loc, funcOp, args);
254  callOp->setAttrs(funcOp->getAttrs());
255 
256  return callOp;
257 }
258 
259 class MMAToOCLPattern : public OpConversionPattern<xevm::MMAOp> {
261  LogicalResult
262  matchAndRewrite(xevm::MMAOp op, xevm::MMAOp::Adaptor adaptor,
263  ConversionPatternRewriter &rewriter) const override {
264  if (!op.getC()) {
265  return rewriter.notifyMatchFailure(op, "OCL requires C operand");
266  }
267  auto precisionA = op.getTypes().getA();
268  auto precisionB = op.getTypes().getB();
269  auto precisionC = op.getTypes().getC();
270  auto precisionD = op.getTypes().getD();
271  if (precisionC != precisionD) {
272  return rewriter.notifyMatchFailure(op, "type of C and D need to match");
273  }
274  if (precisionC != xevm::ElemType::S32 &&
275  precisionC != xevm::ElemType::F32 &&
276  precisionC != xevm::ElemType::F16 &&
277  precisionC != xevm::ElemType::BF16) {
278  return rewriter.notifyMatchFailure(
279  op, "type of C and D must be S32, F32, F16 or BF16");
280  }
281  if (precisionA == xevm::ElemType::S32 ||
282  precisionA == xevm::ElemType::F32) {
283  return rewriter.notifyMatchFailure(op, "type of A cannot be S32 or F32");
284  }
285  if (precisionB == xevm::ElemType::S32 ||
286  precisionB == xevm::ElemType::F32) {
287  return rewriter.notifyMatchFailure(op, "type of B cannot be S32 or F32");
288  }
289  constexpr uint32_t bitWidthPackedA{16};
290  constexpr uint32_t bitWidthPackedB{32};
291  auto loc = op.getLoc();
292 
293  auto castIfNeeded = [&](Value val, Type packedType) -> Value {
294  VectorType origTy = cast<VectorType>(val.getType());
295  const uint32_t vecBitSize =
296  origTy.getNumElements() *
297  origTy.getElementType().getIntOrFloatBitWidth();
298  VectorType newTy = VectorType::get(
299  vecBitSize / packedType.getIntOrFloatBitWidth(), packedType);
300  if (origTy != newTy)
301  val = LLVM::BitcastOp::create(rewriter, loc, newTy, val);
302  return val;
303  };
304 
305  Value a = op.getA();
306  Type packedAType = (op.getTypes().getA() == xevm::ElemType::TF32)
307  ? cast<Type>(rewriter.getF32Type())
308  : rewriter.getIntegerType(bitWidthPackedA);
309  a = castIfNeeded(a, packedAType);
310 
311  Value b = op.getB();
312  Type packedBType = (op.getTypes().getB() == xevm::ElemType::TF32)
313  ? cast<Type>(rewriter.getF32Type())
314  : rewriter.getIntegerType(bitWidthPackedB);
315  b = castIfNeeded(b, packedBType);
316 
317  Value c = op.getC();
318  VectorType cOrigTy = cast<VectorType>(c.getType());
319  VectorType resOrigTy = cast<VectorType>(op->getResultTypes()[0]);
320  assert(cOrigTy == resOrigTy && "Accumulator and result type mismatch");
321  // OCL builtins encode bfloat16 as int16
322  VectorType cTy =
323  cOrigTy.getElementType().isBF16()
324  ? VectorType::get(cOrigTy.getShape(), rewriter.getIntegerType(16))
325  : cOrigTy;
326  VectorType resTy = cTy;
327  if (cOrigTy != cTy)
328  c = LLVM::BitcastOp::create(rewriter, loc, cTy, c);
329 
330  constexpr int32_t systolicDepth{8};
331  std::string fnName =
332  llvm::formatv("intel_sub_group_{0}_{1}_matrix_mad_k{2}",
333  stringifyElemType(op.getTypes().getA()).str(),
334  stringifyElemType(op.getTypes().getB()).str(),
335  systolicDepth *
336  getNumOperandsPerDword(op.getTypes().getA()))
337  .str();
338  SmallVector<Type> argTypes{a.getType(), b.getType(), cTy};
339  fnName = mangle(fnName, argTypes);
340  SmallVector<Value> args{a, b, c};
341 
342  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
343  /*other=*/LLVM::ModRefInfo::NoModRef,
344  /*argMem=*/LLVM::ModRefInfo::NoModRef,
345  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
346  auto funcAttrs = convergentNoUnwindWillReturnAttrs;
347  funcAttrs.memEffectsAttr = memAttr;
348  Value result =
349  createDeviceFunctionCall(rewriter, fnName, resTy, argTypes, args, {},
350  funcAttrs, op.getOperation())
351  ->getResult(0);
352 
353  if (resOrigTy != resTy)
354  result = LLVM::BitcastOp::create(rewriter, loc, resOrigTy, result);
355 
356  rewriter.replaceOp(op, result);
357  return success();
358  }
359 
360 private:
361  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
362  switch (pTy) {
363  case xevm::ElemType::TF32:
364  return 1;
365  case xevm::ElemType::BF16:
366  case xevm::ElemType::F16:
367  return 2;
368  case xevm::ElemType::U8:
369  case xevm::ElemType::S8:
370  return 4;
371  default:
372  llvm_unreachable("unsupported xevm::ElemType");
373  }
374  }
375 };
376 
377 class PrefetchToOCLPattern : public OpConversionPattern<PrefetchOp> {
379  LogicalResult
380  matchAndRewrite(PrefetchOp op, PrefetchOp::Adaptor adaptor,
381  ConversionPatternRewriter &rewriter) const override {
382  auto loc = op.getLoc();
383  const std::string fnName{"_Z8prefetchPU3AS1Kcm"};
384  Value one =
385  LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(), 1);
386  SmallVector<Value> args{op.getPtr(), one};
387  SmallVector<Type> argTypes;
388  for (auto arg : args)
389  argTypes.push_back(arg.getType());
390  auto funcAttr = noUnwindAttrs;
391  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
392  /*other=*/LLVM::ModRefInfo::NoModRef,
393  /*argMem=*/LLVM::ModRefInfo::Ref,
394  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
395  funcAttr.memEffectsAttr = memAttr;
396 
397  LLVM::CallOp call = createDeviceFunctionCall(
398  rewriter, fnName, LLVM::LLVMVoidType::get(rewriter.getContext()),
399  argTypes, args, {}, funcAttr, op.getOperation());
400  if (std::optional<ArrayAttr> optCacheControls =
401  getCacheControlMetadata<true>(rewriter, op))
402  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
403  rewriter.eraseOp(op);
404  return success();
405  }
406 };
407 
408 class MemfenceToOCLPattern : public OpConversionPattern<MemfenceOp> {
410  LogicalResult
411  matchAndRewrite(MemfenceOp op, MemfenceOp::Adaptor adaptor,
412  ConversionPatternRewriter &rewriter) const override {
413  auto loc = op.getLoc();
414  const std::string fnName{"atomic_work_item_fence"};
415  int memScope, addrSpace;
416  switch (op.getAddrspace()) {
417  case xevm::AddrSpace::SHARED:
418  addrSpace = 1; // CLK_LOCAL_MEM_FENCE
419  break;
420  case xevm::AddrSpace::GLOBAL:
421  addrSpace = 2; // CLK_GLOBAL_MEM_FENCE
422  break;
423  default:
424  // GENERIC is not supported in OpenCL
425  return rewriter.notifyMatchFailure(
426  op, "Fence only supports global and shared address spaces.");
427  }
428  switch (op.getScope()) {
429  case xevm::MemScope::WORKGROUP:
430  memScope = 1;
431  break;
432  case xevm::MemScope::DEVICE:
433  memScope = 2;
434  break;
435  default:
436  // CLUSTER and SYSTEM are not supported in OpenCL
437  return rewriter.notifyMatchFailure(
438  op, "Fence only supports workgroup and device memory scopes.");
439  }
440  Type i32Type = rewriter.getI32Type();
441  Value acqRel = LLVM::ConstantOp::create(rewriter, loc, i32Type, 4);
442  Value memScopeConst =
443  LLVM::ConstantOp::create(rewriter, loc, i32Type, memScope);
444  Value addrSpaceConst =
445  LLVM::ConstantOp::create(rewriter, loc, i32Type, addrSpace);
446  SmallVector<Value> args{addrSpaceConst, acqRel, memScopeConst};
447  SmallVector<Type> argTypes{3, i32Type};
448  createDeviceFunctionCall(rewriter, mangle(fnName, argTypes),
450  argTypes, args, {}, noUnwindAttrs,
451  op.getOperation());
452  rewriter.eraseOp(op);
453  return success();
454  }
455 };
456 template <typename OpType>
457 class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
459  LogicalResult
460  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
461  ConversionPatternRewriter &rewriter) const override {
462  constexpr bool isLoad = std::is_same_v<OpType, BlockLoad2dOp>;
463  constexpr bool isPrefetch = std::is_same_v<OpType, BlockPrefetch2dOp>;
464 
465  auto loc = op.getLoc();
466  VectorType vecType;
467  bool packReg = false;
468  bool transpose = false;
469  if constexpr (isLoad) {
470  vecType = op.getRes().getType();
471  packReg = op.getPackRegister();
472  transpose = op.getTranspose();
473  } else if constexpr (!isPrefetch) {
474  vecType = op.getStoredVal().getType();
475  }
476 
477  auto i32Type = rewriter.getI32Type();
478  Value byteCoord =
479  LLVM::UndefOp::create(rewriter, loc, VectorType::get(2, i32Type));
480  Value zero = LLVM::ConstantOp::create(rewriter, loc, i32Type, 0);
481  Value one = LLVM::ConstantOp::create(rewriter, loc, i32Type, 1);
482  byteCoord = LLVM::InsertElementOp::create(
483  rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
484  byteCoord = LLVM::InsertElementOp::create(
485  rewriter, loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
486  SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
487  op.getBasePitch(), byteCoord};
488  SmallVector<Type> retTypes;
489  Value spvLoadDstPtr;
490  std::string funcName{"intel_sub_group_2d_block_"};
491  std::string bitWidthId;
492  LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
494  if constexpr (isPrefetch) { // Prefetch
495  funcName += "prefetch";
496  paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
497  auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
498  /*other=*/LLVM::ModRefInfo::NoModRef,
499  /*argMem=*/LLVM::ModRefInfo::Ref,
500  /*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
501  funcAttr = noUnwindAttrs;
502  funcAttr.memEffectsAttr = memAttr;
503  } else {
504  auto vecElemType = vecType.getElementType();
505  auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
506  Value numElems = LLVM::ConstantOp::create(rewriter, loc, i32Type,
507  vecType.getNumElements());
508  auto dstOrSrcPtr = LLVM::AllocaOp::create(
509  rewriter, loc, LLVM::LLVMPointerType::get(rewriter.getContext()),
510  vecElemType, numElems);
511  args.push_back(dstOrSrcPtr);
512  if constexpr (isLoad) { // Load
513  funcName += "read";
514  bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
515  if (packReg)
516  funcName += "_transform";
517  else if (transpose)
518  funcName += "_transpose";
519  spvLoadDstPtr = dstOrSrcPtr;
520  retTypes.push_back(vecType);
521  paramAttrs = {
522  std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
523  std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
524  std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
525  std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
526  };
527  } else { // Store
528  funcName += "write";
529  bitWidthId = (vecElemBitWidth == 32)
530  ? "j"
531  : ((vecElemBitWidth == 16) ? "t" : "h");
532  LLVM::StoreOp::create(rewriter, loc, op.getStoredVal(), dstOrSrcPtr);
533  paramAttrs = {
534  std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
535  std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
536  std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
537  std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
538  };
539  }
540  }
541 
542  funcName =
543  llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
544  op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
545  .str();
546  std::string prefetchCode("");
547  if (!isPrefetch)
548  prefetchCode += "P";
549  funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
550  funcName, prefetchCode, bitWidthId)
551  .str();
552  SmallVector<Type> argTypes;
553  for (auto arg : args) {
554  argTypes.push_back(arg.getType());
555  }
556  LLVM::CallOp call = createDeviceFunctionCall(
557  rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
558  argTypes, args, paramAttrs, funcAttr, op.getOperation());
559  if (std::optional<ArrayAttr> optCacheControls =
560  getCacheControlMetadata < isLoad || isPrefetch > (rewriter, op)) {
561  call->setAttr(XeVMDialect::getCacheControlsAttrName(), *optCacheControls);
562  }
563  if constexpr (isLoad)
564  rewriter.replaceOp(
565  op, LLVM::LoadOp::create(rewriter, loc, vecType, spvLoadDstPtr));
566  else
567  rewriter.eraseOp(op);
568  return success();
569  }
570 };
571 
572 //===----------------------------------------------------------------------===//
573 // Pass Definition
574 //===----------------------------------------------------------------------===//
575 
576 struct ConvertXeVMToLLVMPass
577  : public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
578  using Base::Base;
579 
580  void getDependentDialects(DialectRegistry &registry) const override {
581  registry.insert<LLVM::LLVMDialect, XeVMDialect>();
582  }
583 
584  void runOnOperation() override {
585  ConversionTarget target(getContext());
586  target.addLegalDialect<LLVM::LLVMDialect>();
587  target.addIllegalDialect<XeVMDialect>();
590  if (failed(applyPartialConversion(getOperation(), target,
591  std::move(patterns))))
592  signalPassFailure();
593  }
594 };
595 } // namespace
596 
597 //===----------------------------------------------------------------------===//
598 // ConvertToLLVMPatternInterface implementation
599 //===----------------------------------------------------------------------===//
600 
601 namespace {
602 /// Implement the interface to convert XeVM to LLVM.
603 struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
605  void loadDependentDialects(MLIRContext *context) const final {
606  context->loadDialect<LLVM::LLVMDialect>();
607  }
608 
609  /// Hook for derived dialect interface to provide conversion patterns
610  /// and mark dialect legal for the conversion target.
611  void populateConvertToLLVMConversionPatterns(
612  ConversionTarget &target, LLVMTypeConverter &typeConverter,
613  RewritePatternSet &patterns) const final {
615  }
616 };
617 } // namespace
618 
619 //===----------------------------------------------------------------------===//
620 // Pattern Population
621 //===----------------------------------------------------------------------===//
622 
624  patterns.add<LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
625  LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
626  LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>,
627  MMAToOCLPattern, MemfenceToOCLPattern, PrefetchToOCLPattern>(
628  patterns.getContext());
629 }
630 
632  registry.addExtension(+[](MLIRContext *ctx, XeVMDialect *dialect) {
633  dialect->addInterfaces<XeVMToLLVMDialectInterface>();
634  });
635 }
static MLIRContext * getContext(OpFoldResult val)
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
FloatType getF32Type()
Definition: Builders.cpp:42
ArrayAttr getI32ArrayAttr(ArrayRef< int32_t > values)
Definition: Builders.cpp:275
IntegerType getI64Type()
Definition: Builders.cpp:64
IntegerType getI32Type()
Definition: Builders.cpp:62
IntegerType getIntegerType(unsigned width)
Definition: Builders.cpp:66
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getArrayAttr(ArrayRef< Attribute > value)
Definition: Builders.cpp:265
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition: Builders.h:98
This class implements a pattern rewriter for use with ConversionPatterns.
void replaceOp(Operation *op, ValueRange newValues) override
Replace the given operation with the new values.
void eraseOp(Operation *op) override
PatternRewriter hook for erasing a dead operation.
This class describes a specific conversion target.
Base class for dialect interfaces providing translation to LLVM IR.
ConvertToLLVMPatternInterface(Dialect *dialect)
The DialectRegistry maps a dialect namespace to a constructor for the matching dialect.
bool addExtension(TypeID extensionID, std::unique_ptr< DialectExtensionBase > extension)
Add the given extension to the registry.
Conversion from types to the LLVM IR dialect.
Definition: TypeConverter.h:35
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
OpConversionPattern is a wrapper around ConversionPattern that allows for matching and rewriting agai...
OpConversionPattern(MLIRContext *context, PatternBenefit benefit=1)
A trait used to provide symbol table functionalities to a region operation.
Definition: SymbolTable.h:452
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentWithTrait()
Returns the closest surrounding parent operation with trait Trait.
Definition: Operation.h:248
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:716
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
FailureOr< LLVM::LLVMFuncOp > lookupOrCreateFn(OpBuilder &b, Operation *moduleOp, StringRef name, ArrayRef< Type > paramTypes={}, Type resultType={}, bool isVarArg=false, bool isReserved=false, SymbolTableCollection *symbolTables=nullptr)
Create a FuncOp with signature `resultType(paramTypes)` and name `name`.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
Include the generated interface declarations.
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns)
Definition: XeVMToLLVM.cpp:623
const FrozenRewritePatternSet & patterns
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
LogicalResult applyPartialConversion(ArrayRef< Operation * > ops, const ConversionTarget &target, const FrozenRewritePatternSet &patterns, ConversionConfig config=ConversionConfig())
Below we define several entry points for operation conversion.
void registerConvertXeVMToLLVMInterface(DialectRegistry &registry)
Definition: XeVMToLLVM.cpp:631