XeGPUToXeVM.cpp
1 //===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
2 //
3 // This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
24 #include "mlir/Pass/Pass.h"
25 #include "mlir/Support/LLVM.h"
26 #include "llvm/Support/FormatVariadic.h"
27 
28 #include "mlir/IR/BuiltinTypes.h"
29 #include "mlir/IR/Types.h"
30 
31 #include "llvm/ADT/TypeSwitch.h"
32 
33 #include <numeric>
34 
35 namespace mlir {
36 #define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
37 #include "mlir/Conversion/Passes.h.inc"
38 } // namespace mlir
39 
40 using namespace mlir;
41 
42 namespace {
43 
44 // TODO: The values below are uArch-dependent and should not be hardcoded.
45 static constexpr int32_t systolicDepth{8};
46 static constexpr int32_t executionSize{16};
47 
48 // Offsets of individual fields within the 8xi32 payload layout of an nd tensor descriptor.
49 enum class NdTdescOffset : uint32_t {
50  BasePtr = 0, // Base pointer (i64)
51  BaseShapeW = 2, // Base shape width (i32)
52  BaseShapeH = 3, // Base shape height (i32)
53  TensorOffsetW = 4, // Tensor offset W (i32)
54  TensorOffsetH = 5 // Tensor offset H (i32)
55 };
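// Illustrative payload layout (a sketch of how the fields above land in the
// 8xi32 payload; the i64 base pointer occupies i32 elements 0-1, and no field
// is defined for elements 6-7):
//   [ basePtr.lo | basePtr.hi | shapeW | shapeH | offsetW | offsetH | - | - ]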
56 
57 static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
58  switch (xeGpuMemspace) {
59  case xegpu::MemorySpace::Global:
60  return static_cast<int>(xevm::AddrSpace::GLOBAL);
61  case xegpu::MemorySpace::SLM:
62  return static_cast<int>(xevm::AddrSpace::SHARED);
63  }
64 }
65 
66 // Get a flat vector type with the new element type and the same total bit width.
67 static VectorType encodeVectorTypeTo(VectorType currentVecType,
68  Type toElemType) {
69  auto elemType = currentVecType.getElementType();
70  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
71  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
72  const int size =
73  currentVecType.getNumElements() * currentBitWidth / newBitWidth;
74  return VectorType::get(size, toElemType);
75 }
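// For example (illustrative): re-encoding a vector<8x16xf16> (128 x 16 bits)
// with an i32 element type yields the flat type vector<64xi32>, preserving the
// total bit width.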
76 
77 static xevm::LoadCacheControl
78 translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
79  std::optional<xegpu::CachePolicy> L3hint) {
80  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
81  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
82  switch (L1hintVal) {
83  case xegpu::CachePolicy::CACHED:
84  if (L3hintVal == xegpu::CachePolicy::CACHED)
85  return xevm::LoadCacheControl::L1C_L2UC_L3C;
86  else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
87  return xevm::LoadCacheControl::L1C_L2UC_L3UC;
88  else
89  llvm_unreachable("Unsupported cache control.");
90  case xegpu::CachePolicy::UNCACHED:
91  if (L3hintVal == xegpu::CachePolicy::CACHED)
92  return xevm::LoadCacheControl::L1UC_L2UC_L3C;
93  else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
94  return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
95  else
96  llvm_unreachable("Unsupported cache control.");
97  case xegpu::CachePolicy::STREAMING:
98  if (L3hintVal == xegpu::CachePolicy::CACHED)
99  return xevm::LoadCacheControl::L1S_L2UC_L3C;
100  else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
101  return xevm::LoadCacheControl::L1S_L2UC_L3UC;
102  else
103  llvm_unreachable("Unsupported cache control.");
104  case xegpu::CachePolicy::READ_INVALIDATE:
105  return xevm::LoadCacheControl::INVALIDATE_READ;
106  default:
107  llvm_unreachable("Unsupported cache control.");
108  }
109 }
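// For example, (L1 = CACHED, L3 = UNCACHED) maps to L1C_L2UC_L3UC; hints that
// are not provided default to UNCACHED, so (nullopt, nullopt) maps to
// L1UC_L2UC_L3UC.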
110 
111 static xevm::StoreCacheControl
112 translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
113  std::optional<xegpu::CachePolicy> L3hint) {
114  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
115  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
116  switch (L1hintVal) {
117  case xegpu::CachePolicy::UNCACHED:
118  if (L3hintVal == xegpu::CachePolicy::UNCACHED)
119  return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
120  else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
121  return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
122  else
123  llvm_unreachable("Unsupported cache control.");
124  case xegpu::CachePolicy::STREAMING:
125  if (L3hintVal == xegpu::CachePolicy::UNCACHED)
126  return xevm::StoreCacheControl::L1S_L2UC_L3UC;
127  else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
128  return xevm::StoreCacheControl::L1S_L2UC_L3WB;
129  else
130  llvm_unreachable("Unsupported cache control.");
131  case xegpu::CachePolicy::WRITE_BACK:
132  if (L3hintVal == xegpu::CachePolicy::UNCACHED)
133  return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
134  else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
135  return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
136  else
137  llvm_unreachable("Unsupported cache control.");
138  case xegpu::CachePolicy::WRITE_THROUGH:
139  if (L3hintVal == xegpu::CachePolicy::UNCACHED)
140  return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
141  else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
142  return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
143  else
144  llvm_unreachable("Unsupported cache control.");
145  default:
146  llvm_unreachable("Unsupported cache control.");
147  }
148 }
149 
150 class CreateNdDescToXeVMPattern
151  : public OpConversionPattern<xegpu::CreateNdDescOp> {
152  using OpConversionPattern::OpConversionPattern;
153  LogicalResult
154  matchAndRewrite(xegpu::CreateNdDescOp op,
155  xegpu::CreateNdDescOp::Adaptor adaptor,
156  ConversionPatternRewriter &rewriter) const override {
157  auto loc = op.getLoc();
158  auto source = op.getSource();
159  // The op is lowered to a code sequence that populates the payload.
160  // The payload is an 8xi32 vector. Offsets of the individual fields are
161  // defined in the NdTdescOffset enum.
162  Type payloadElemTy = rewriter.getI32Type();
163  VectorType payloadTy = VectorType::get(8, payloadElemTy);
164  Type i64Ty = rewriter.getI64Type();
165  // 4xi64 view is used for inserting the base pointer.
166  VectorType payloadI64Ty = VectorType::get(4, i64Ty);
167  // Initialize payload to zero.
168  Value payload = arith::ConstantOp::create(
169  rewriter, loc,
170  DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));
171 
172  Value baseAddr;
173  Value baseShapeW;
174  Value baseShapeH;
175  Value offsetW;
176  Value offsetH;
177 
178  // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
179  SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
180  SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
181  // Descriptor shape is expected to be 2D.
182  int64_t rank = mixedSizes.size();
183  if (rank != 2)
184  return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
185  auto sourceTy = source.getType();
186  auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
187  // If source is a memref, we need to extract the aligned pointer as index.
188  // Pointer type is passed as i32 or i64 by type converter.
189  if (sourceMemrefTy) {
190  if (!sourceMemrefTy.hasStaticShape()) {
191  return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
192  }
193  baseAddr =
194  memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
195  } else {
196  baseAddr = adaptor.getSource();
197  }
198  // Utility for creating offset values from op fold result.
199  auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
200  unsigned idx) -> Value {
201  Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
202  val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
203  return val;
204  };
205  // Offsets can be either 2D or not provided (0 is used).
206  if (mixedOffsets.size() == 2) {
207  offsetW = createOffset(mixedOffsets, 1);
208  offsetH = createOffset(mixedOffsets, 0);
209  } else if (mixedOffsets.size() == 0) {
210  offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
211  offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
212  } else {
213  return rewriter.notifyMatchFailure(op,
214  "Expected 2D offsets or no offsets.");
215  }
216  // Get shape values from op fold results.
217  baseShapeW = createOffset(mixedSizes, 1);
218  baseShapeH = createOffset(mixedSizes, 0);
219  if (sourceMemrefTy) {
220  // Cast index to i64.
221  baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
222  } else if (baseAddr.getType() != i64Ty) {
223  // Pointer type may be i32. Cast to i64 if needed.
224  baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
225  }
226  // Populate payload.
227  Value payLoadAsI64 =
228  vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
229  payLoadAsI64 =
230  vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
231  static_cast<int>(NdTdescOffset::BasePtr));
232  payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
233  payload =
234  vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
235  static_cast<int>(NdTdescOffset::BaseShapeW));
236  payload =
237  vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
238  static_cast<int>(NdTdescOffset::BaseShapeH));
239  payload = vector::InsertOp::create(
240  rewriter, loc, offsetW, payload,
241  static_cast<int>(NdTdescOffset::TensorOffsetW));
242  payload = vector::InsertOp::create(
243  rewriter, loc, offsetH, payload,
244  static_cast<int>(NdTdescOffset::TensorOffsetH));
245  rewriter.replaceOp(op, payload);
246  return success();
247  }
248 };
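// Rough before/after sketch (XeGPU assembly abbreviated, not verbatim): for
//   %d = xegpu.create_nd_tdesc %src[%oh, %ow] : memref<64x64xf16>
//          -> !xegpu.tensor_desc<8x16xf16>
// the pattern extracts the aligned base pointer of %src, casts it to i64, and
// inserts {base, shapeW = 64, shapeH = 64, %ow, %oh} into a zero-initialized
// vector<8xi32> payload at the NdTdescOffset positions.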
249 
250 class UpdateNdOffsetToXeVMPattern
251  : public OpConversionPattern<xegpu::UpdateNdOffsetOp> {
252  using OpConversionPattern::OpConversionPattern;
253  LogicalResult
254  matchAndRewrite(xegpu::UpdateNdOffsetOp op,
255  xegpu::UpdateNdOffsetOp::Adaptor adaptor,
256  ConversionPatternRewriter &rewriter) const override {
257  auto loc = op.getLoc();
258  auto mixedOffsets = op.getMixedOffsets();
259  // Only 2D offsets are supported for now.
260  if (mixedOffsets.size() != 2)
261  return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
262  auto payload = adaptor.getTensorDesc();
263  // Utility for updating payload offset values from op fold result.
264  auto updateOffset = [&](unsigned idx, int payloadPos) -> Value {
265  Value offset =
266  getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[idx]);
267  offset = getValueOrCreateCastToIndexLike(rewriter, loc,
268  rewriter.getI32Type(), offset);
269  Value oldOffset =
270  vector::ExtractOp::create(rewriter, loc, payload, payloadPos);
271  Value newOffset = arith::AddIOp::create(rewriter, loc, oldOffset, offset);
272  return vector::InsertOp::create(rewriter, loc, newOffset, payload,
273  payloadPos);
274  };
275  // Update offsets in the payload.
276  payload = updateOffset(0, static_cast<int>(NdTdescOffset::TensorOffsetH));
277  payload = updateOffset(1, static_cast<int>(NdTdescOffset::TensorOffsetW));
278  rewriter.replaceOp(op, payload);
279  return success();
280  }
281 };
282 
283 template <
284  typename OpType,
285  typename = std::enable_if_t<llvm::is_one_of<
286  OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
287 class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
288  using OpConversionPattern<OpType>::OpConversionPattern;
289  LogicalResult
290  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
291  ConversionPatternRewriter &rewriter) const override {
292  auto loc = op.getLoc();
293  auto ctxt = rewriter.getContext();
294 
295  auto tdesc = adaptor.getTensorDesc();
296  auto tdescTy = op.getTensorDescType();
297  if (tdescTy.getRank() != 2)
298  return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
299  auto elemType = tdescTy.getElementType();
300  auto elemBitSize = elemType.getIntOrFloatBitWidth();
301  if (elemBitSize % 8 != 0)
302  return rewriter.notifyMatchFailure(
303  op, "Expected element type bit width to be multiple of 8.");
304 
305  VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
306  Value payLoadAsI64 =
307  vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
308  Value basePtr = vector::ExtractOp::create(
309  rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
310  Value baseShapeW = vector::ExtractOp::create(
311  rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
312  Value baseShapeH = vector::ExtractOp::create(
313  rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
314  // Offsets are provided in one of two ways:
315  // 1. Extracted from the tensor descriptor payload.
316  // 2. (Mixed) offsets provided directly by the op.
317  Value offsetW;
318  Value offsetH;
319  auto mixedOffsets = op.getMixedOffsets();
320  int64_t opOffsetsSize = mixedOffsets.size();
321  if (opOffsetsSize != 0 && opOffsetsSize != 2)
322  return rewriter.notifyMatchFailure(op,
323  "Expected 2D offsets or no offsets.");
324  if (opOffsetsSize) {
325  // If mixed offsets are provided by the op, convert them to i32.
326  offsetW = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
327  offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
328  rewriter.getI32Type(), offsetW);
329  offsetH = getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
330  offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
331  rewriter.getI32Type(), offsetH);
332  } else {
333  // If offsets are not available, we need to extract them from the tensor
334  // descriptor.
335  offsetW = vector::ExtractOp::create(
336  rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetW));
337  offsetH = vector::ExtractOp::create(
338  rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::TensorOffsetH));
339  }
340  // Get address space from tensor descriptor memory space.
341  auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
342  ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
343  // Convert base pointer (i64) to LLVM pointer type.
344  Value basePtrLLVM =
345  LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
346  // Compute element byte size and surface width in bytes.
347  Value elemByteSize = arith::ConstantIntOp::create(
348  rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
349  Value surfaceW =
350  arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);
351 
352  // Get tile sizes and vblocks from the tensor descriptor type.
353  auto tileW = tdescTy.getDimSize(1);
354  auto tileH = tdescTy.getDimSize(0);
355  int32_t vblocks = tdescTy.getArrayLength();
356  if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
357  Value src = adaptor.getValue();
358  // If the store value is a scalar, get the value from the op instead of
359  // the adaptor; the adaptor may have optimized away a single-element vector.
360  if (src.getType().isIntOrFloat()) {
361  src = op.getValue();
362  }
363  VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
364  if (!srcVecTy)
365  return rewriter.notifyMatchFailure(
366  op, "Expected store value to be a vector type.");
367  // Get flat vector type of integer type with matching element bit size.
368  VectorType newSrcVecTy =
369  encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
370  if (srcVecTy != newSrcVecTy)
371  src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
372  auto storeCacheControl =
373  translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
374  xevm::BlockStore2dOp::create(
375  rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
376  offsetH, elemBitSize, tileW, tileH, src,
377  xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
378  rewriter.eraseOp(op);
379  } else {
380  auto loadCacheControl =
381  translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
382  if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
383  xevm::BlockPrefetch2dOp::create(
384  rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
385  offsetH, elemBitSize, tileW, tileH, vblocks,
386  xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
387  rewriter.eraseOp(op);
388  } else {
389  VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
390  const bool vnni = op.getPacked().value_or(false);
391  auto transposeValue = op.getTranspose();
392  bool transpose =
393  transposeValue.has_value() && transposeValue.value()[0] == 1;
394  VectorType loadedTy = encodeVectorTypeTo(
395  dstVecTy, vnni ? rewriter.getI32Type()
396  : rewriter.getIntegerType(elemBitSize));
397 
398  Value resultFlatVec = xevm::BlockLoad2dOp::create(
399  rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
400  surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
401  transpose, vnni,
402  xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
403  resultFlatVec = vector::BitCastOp::create(
404  rewriter, loc,
405  encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
406  resultFlatVec);
407  rewriter.replaceOp(op, resultFlatVec);
408  }
409  }
410  return success();
411  }
412 };
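// Load path sketch: for a result of type vector<8x16xf16> with packed=false,
// the 2D block load is issued with the flat type vector<128xi16> (elemBitSize
// = 16) and then bitcast back to vector<128xf16>; the 2D result type is
// reconciled by the flattening vector type conversion and materializations
// registered in the pass below.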
413 
414 // Helper that emits code computing:
415 //   baseAddr + offset * elemByteSize
416 static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
417  Value baseAddr, Value offset, int64_t elemByteSize) {
418  Value byteSize = arith::ConstantIntOp::create(
419  rewriter, loc, rewriter.getI64Type(), elemByteSize);
420  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
421  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
422  return newAddr;
423 }
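// For example, addOffset(rewriter, loc, %base, %idx, /*elemByteSize=*/2)
// emits %base + %idx * 2, i.e. an element index scaled to a byte offset.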
424 
425 class CreateDescToXeVMPattern
426  : public OpConversionPattern<xegpu::CreateDescOp> {
427  using OpConversionPattern::OpConversionPattern;
428  LogicalResult
429  matchAndRewrite(xegpu::CreateDescOp op, xegpu::CreateDescOp::Adaptor adaptor,
430  ConversionPatternRewriter &rewriter) const override {
431  auto eTy = op.getTensorDescType().getElementType();
432  auto eBw = eTy.getIntOrFloatBitWidth();
433  if (eBw % 8 != 0)
434  return rewriter.notifyMatchFailure(
435  op, "Expected element type bit width to be multiple of 8.");
436  auto loc = op.getLoc();
437  // Offsets are provided as scalar i64 by type converter.
438  auto offsets = adaptor.getOffsets();
439  // Source type can be a 1D memref or pointer type (ui64, ui32, i64 or i32).
440  // But type converter will convert them to integer types.
441  Value addr = adaptor.getSource();
442  // ui32 and i32 are passed as i32, so they need to be cast to i64.
443  if (addr.getType() != rewriter.getI64Type())
444  addr = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(), addr);
445  auto laneAddr = addOffset(rewriter, loc, addr, offsets, eBw / 8);
446  rewriter.replaceOp(op, laneAddr);
447  return success();
448  }
449 };
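// For example, for an f32 element type (eBw = 32) the scattered descriptor
// lowers to a single i64 lane address computed as source + offsets * 4.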
450 
451 class UpdateOffsetToXeVMPattern
452  : public OpConversionPattern<xegpu::UpdateOffsetOp> {
453  using OpConversionPattern::OpConversionPattern;
454  LogicalResult
455  matchAndRewrite(xegpu::UpdateOffsetOp op,
456  xegpu::UpdateOffsetOp::Adaptor adaptor,
457  ConversionPatternRewriter &rewriter) const override {
458  auto eTy = op.getTensorDescType().getElementType();
459  auto eBw = eTy.getIntOrFloatBitWidth();
460  if (eBw % 8 != 0)
461  return rewriter.notifyMatchFailure(
462  op, "Expected element type bit width to be multiple of 8.");
463  auto loc = op.getLoc();
464  // Scatter descriptor is provided as scalar i64 by type converter.
465  // Offsets are provided as scalar i64 by type converter.
466  Value newOffset = addOffset(rewriter, loc, adaptor.getTensorDesc(),
467  adaptor.getOffsets(), eBw / 8);
468  rewriter.replaceOp(op, newOffset);
469  return success();
470  }
471 };
472 
473 template <typename OpType,
474  typename = std::enable_if_t<llvm::is_one_of<
475  OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
476 class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
477  using OpConversionPattern<OpType>::OpConversionPattern;
478  LogicalResult
479  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
480  ConversionPatternRewriter &rewriter) const override {
481  auto loc = op.getLoc();
482  auto ctxt = rewriter.getContext();
483  auto tdescTy = op.getTensorDescType();
484  Value basePtrI64;
485  // The load result or store value type can be a vector or a scalar.
486  Type valOrResTy;
487  if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
488  valOrResTy = op.getResult().getType();
489  else
490  valOrResTy = adaptor.getValue().getType();
491  VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
492  bool hasScalarVal = !valOrResVecTy;
493  int64_t elemBitWidth =
494  hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
495  : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
496  // Element type must be multiple of 8 bits.
497  if (elemBitWidth % 8 != 0)
498  return rewriter.notifyMatchFailure(
499  op, "Expected element type bit width to be multiple of 8.");
500  int64_t elemByteSize = elemBitWidth / 8;
501  // Default memory space is global.
502  LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
503  ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
504  // If tensor descriptor is available, we use its memory space.
505  if (tdescTy)
506  ptrTypeLLVM = LLVM::LLVMPointerType::get(
507  ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
508  // Base pointer can come from source (load) or dest (store).
509  // If they are memrefs, we use their memory space.
510  if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
511  basePtrI64 = adaptor.getSource();
512  if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
513  auto addrSpace = memRefTy.getMemorySpaceAsInt();
514  if (addrSpace != 0)
515  ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
516  }
517  } else {
518  basePtrI64 = adaptor.getDest();
519  if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
520  auto addrSpace = memRefTy.getMemorySpaceAsInt();
521  if (addrSpace != 0)
522  ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
523  }
524  }
525  // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
526  if (basePtrI64.getType() != rewriter.getI64Type()) {
527  basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
528  basePtrI64);
529  }
530  Value offsets = adaptor.getOffsets();
531  Value mask = adaptor.getMask();
532  if (offsets) {
533  if (dyn_cast<VectorType>(offsets.getType())) {
534  // The offset needs to be a scalar. A single-element vector is converted
535  // to a scalar by the type converter.
536  return rewriter.notifyMatchFailure(op,
537  "Expected offsets to be a scalar.");
538  } else {
539  // If offsets are provided, we add them to the base pointer.
540  // Offsets are given in numbers of elements, so they must be multiplied
541  // by the element byte size.
542  basePtrI64 =
543  addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
544  }
545  }
546  // Convert base pointer (i64) to LLVM pointer type.
547  Value basePtrLLVM =
548  LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
549 
550  Value maskForLane;
551  VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
552  if (maskVecTy) {
553  // The mask needs to be a scalar. A single-element vector is converted to
554  // a scalar by the type converter.
555  return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
556  } else
557  maskForLane = mask;
558  if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
559  scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
560  maskForLane, true, true);
561  // If the mask is true (then clause), load from memory and yield.
562  rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
563  if (!hasScalarVal)
564  valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
565  valOrResVecTy.getElementType());
566  Value loaded =
567  LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
568  // Set cache control attribute on the load operation.
569  loaded.getDefiningOp()->setAttr(
570  "cache_control", xevm::LoadCacheControlAttr::get(
571  ctxt, translateLoadXeGPUCacheHint(
572  op.getL1Hint(), op.getL3Hint())));
573  scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
574  rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
575  // If the mask is false (else clause), yield a zero constant.
576  auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
577  TypedAttr eVal;
578  if (eTy.isFloat())
579  eVal = FloatAttr::get(eTy, 0.0);
580  else
581  eVal = IntegerAttr::get(eTy, 0);
582  if (hasScalarVal)
583  loaded = arith::ConstantOp::create(rewriter, loc, eVal);
584  else
585  loaded = arith::ConstantOp::create(
586  rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
587  scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
588  rewriter.replaceOp(op, ifOp.getResult(0));
589  } else {
590  // If mask is true, perform the store.
591  scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
592  auto body = ifOp.getBody();
593  rewriter.setInsertionPointToStart(body);
594  auto storeOp =
595  LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
596  // Set cache control attribute on the store operation.
597  storeOp.getOperation()->setAttr(
598  "cache_control", xevm::StoreCacheControlAttr::get(
599  ctxt, translateStoreXeGPUCacheHint(
600  op.getL1Hint(), op.getL3Hint())));
601  rewriter.eraseOp(op);
602  }
603  return success();
604  }
605 };
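// Sketch of the emitted structure for a gather load (pseudo-IR, not verbatim):
//   %r = scf.if %mask -> (T) {
//     %v = llvm.load %ptr {cache_control = ...} : T
//     scf.yield %v
//   } else {
//     scf.yield %zero   // constant 0 of type T
//   }
// A scatter store emits the analogous one-armed scf.if around llvm.store.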
606 
607 class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
608  using OpConversionPattern::OpConversionPattern;
609  LogicalResult
610  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
611  ConversionPatternRewriter &rewriter) const override {
612  auto loc = op.getLoc();
613  auto ctxt = rewriter.getContext();
614  auto tdescTy = op.getTensorDescType();
615  Value basePtrI64 = adaptor.getSource();
616  // Base pointer is passed as i32 or i64 by adaptor, cast to i64 if needed.
617  if (basePtrI64.getType() != rewriter.getI64Type())
618  basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
619  basePtrI64);
620  Value offsets = adaptor.getOffsets();
621  if (offsets) {
622  VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
623  if (offsetsVecTy) {
624  // The offset needs to be a scalar.
625  return rewriter.notifyMatchFailure(op,
626  "Expected offsets to be a scalar.");
627  } else {
628  int64_t elemBitWidth{0};
629  int64_t elemByteSize;
630  // Element byte size can come from three sources:
631  if (tdescTy) {
632  // If tensor descriptor is available, we use its element type to
633  // determine element byte size.
634  elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
635  } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
636  // If memref is available, we use its element type to
637  // determine element byte size.
638  elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
639  } else {
640  // Otherwise, we use the provided offset byte alignment.
641  elemByteSize = *op.getOffsetAlignByte();
642  }
643  if (elemBitWidth != 0) {
644  if (elemBitWidth % 8 != 0)
645  return rewriter.notifyMatchFailure(
646  op, "Expected element type bit width to be multiple of 8.");
647  elemByteSize = elemBitWidth / 8;
648  }
649  basePtrI64 =
650  addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
651  }
652  }
653  // Default memory space is global.
654  LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
655  ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
656  // If tensor descriptor is available, we use its memory space.
657  if (tdescTy)
658  ptrTypeLLVM = LLVM::LLVMPointerType::get(
659  ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
660  // If source is a memref, we use its memory space.
661  if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
662  auto addrSpace = memRefTy.getMemorySpaceAsInt();
663  if (addrSpace != 0)
664  ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
665  }
666  // Convert base pointer (i64) to LLVM pointer type.
667  Value ptrLLVM =
668  LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
669  // Create the prefetch op with cache control attribute.
670  xevm::PrefetchOp::create(
671  rewriter, loc, ptrLLVM,
672  xevm::LoadCacheControlAttr::get(
673  ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
674  rewriter.eraseOp(op);
675  return success();
676  }
677 };
678 
679 class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
680  using OpConversionPattern::OpConversionPattern;
681  LogicalResult
682  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
683  ConversionPatternRewriter &rewriter) const override {
684  auto loc = op.getLoc();
685  xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
686  switch (op.getFenceScope()) {
687  case xegpu::FenceScope::Workgroup:
688  memScope = xevm::MemScope::WORKGROUP;
689  break;
690  case xegpu::FenceScope::GPU:
691  memScope = xevm::MemScope::DEVICE;
692  break;
693  }
694  xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
695  switch (op.getMemoryKind()) {
696  case xegpu::MemorySpace::Global:
697  addrSpace = xevm::AddrSpace::GLOBAL;
698  break;
699  case xegpu::MemorySpace::SLM:
700  addrSpace = xevm::AddrSpace::SHARED;
701  break;
702  }
703  xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
704  rewriter.eraseOp(op);
705  return success();
706  }
707 };
708 
709 class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
710  using OpConversionPattern::OpConversionPattern;
711  LogicalResult
712  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
713  ConversionPatternRewriter &rewriter) const override {
714  auto loc = op.getLoc();
715  auto ctxt = rewriter.getContext();
716  auto aTy = cast<VectorType>(op.getLhs().getType());
717  auto bTy = cast<VectorType>(op.getRhs().getType());
718  auto resultType = cast<VectorType>(op.getResultType());
719 
720  auto encodePrecision = [&](Type type) -> xevm::ElemType {
721  if (type == rewriter.getBF16Type())
722  return xevm::ElemType::BF16;
723  else if (type == rewriter.getF16Type())
724  return xevm::ElemType::F16;
725  else if (type == rewriter.getTF32Type())
726  return xevm::ElemType::TF32;
727  else if (type.isInteger(8)) {
728  if (type.isUnsignedInteger())
729  return xevm::ElemType::U8;
730  return xevm::ElemType::S8;
731  } else if (type == rewriter.getF32Type())
732  return xevm::ElemType::F32;
733  else if (type.isInteger(32))
734  return xevm::ElemType::S32;
735  llvm_unreachable("add more support for ElemType");
736  };
737  xevm::ElemType precATy = encodePrecision(aTy.getElementType());
738  xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
739  Value c = op.getAcc();
740  if (!c) {
741  auto elementTy = resultType.getElementType();
742  Attribute initValueAttr;
743  if (isa<FloatType>(elementTy))
744  initValueAttr = FloatAttr::get(elementTy, 0.0);
745  else
746  initValueAttr = IntegerAttr::get(elementTy, 0);
747  c = arith::ConstantOp::create(
748  rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
749  }
750 
751  Value aVec = op.getLhs();
752  Value bVec = op.getRhs();
753  auto cvecty = cast<VectorType>(c.getType());
754  xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
755  xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
756  VectorType cNty =
757  VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
758  if (cvecty != cNty)
759  c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
760  Value dpasRes = xevm::MMAOp::create(
761  rewriter, loc, cNty, aVec, bVec, c,
762  xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
763  systolicDepth *
764  getNumOperandsPerDword(precATy)),
765  xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
766  if (cvecty != cNty)
767  dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
768  rewriter.replaceOp(op, dpasRes);
769  return success();
770  }
771 
772 private:
773  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
774  switch (pTy) {
775  case xevm::ElemType::TF32:
776  return 1;
777  case xevm::ElemType::BF16:
778  case xevm::ElemType::F16:
779  return 2;
780  case xevm::ElemType::U8:
781  case xevm::ElemType::S8:
782  return 4;
783  default:
784  llvm_unreachable("unsupported xevm::ElemType");
785  }
786  }
787 };
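// For example, for f16 A/B operands getNumOperandsPerDword returns 2, so the
// third MMA shape operand becomes systolicDepth * 2 = 16 and the second is
// executionSize = 16; the first is taken from the accumulator's element count.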
788 
789 static std::optional<LLVM::AtomicBinOp>
790 matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
791  switch (arithKind) {
792  case arith::AtomicRMWKind::addf:
793  return LLVM::AtomicBinOp::fadd;
794  case arith::AtomicRMWKind::addi:
795  return LLVM::AtomicBinOp::add;
796  case arith::AtomicRMWKind::assign:
797  return LLVM::AtomicBinOp::xchg;
798  case arith::AtomicRMWKind::maximumf:
799  return LLVM::AtomicBinOp::fmax;
800  case arith::AtomicRMWKind::maxs:
801  return LLVM::AtomicBinOp::max;
802  case arith::AtomicRMWKind::maxu:
803  return LLVM::AtomicBinOp::umax;
804  case arith::AtomicRMWKind::minimumf:
805  return LLVM::AtomicBinOp::fmin;
806  case arith::AtomicRMWKind::mins:
807  return LLVM::AtomicBinOp::min;
808  case arith::AtomicRMWKind::minu:
809  return LLVM::AtomicBinOp::umin;
810  case arith::AtomicRMWKind::ori:
811  return LLVM::AtomicBinOp::_or;
812  case arith::AtomicRMWKind::andi:
813  return LLVM::AtomicBinOp::_and;
814  default:
815  return std::nullopt;
816  }
817 }
818 
819 class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
820  using OpConversionPattern::OpConversionPattern;
821  LogicalResult
822  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
823  ConversionPatternRewriter &rewriter) const override {
824  auto loc = op.getLoc();
825  auto ctxt = rewriter.getContext();
826  auto tdesc = op.getTensorDesc().getType();
827  auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
828  ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
829  Value basePtrI64 = arith::IndexCastOp::create(
830  rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
831  Value basePtrLLVM =
832  LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
833  VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
834  VectorType srcOrDstFlatVecTy = VectorType::get(
835  srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
836  Value srcFlatVec = vector::ShapeCastOp::create(
837  rewriter, loc, srcOrDstFlatVecTy, op.getValue());
838  auto atomicKind = matchSimpleAtomicOp(op.getKind());
839  assert(atomicKind.has_value());
840  Value resVec = srcFlatVec;
841  for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
842  auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
843  Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
844  rewriter.getIndexAttr(i));
845  Value currPtr =
846  LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
847  srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
848  Value newVal =
849  LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
850  val, LLVM::AtomicOrdering::seq_cst);
851  resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
852  }
853  rewriter.replaceOp(op, resVec);
854  return success();
855  }
856 };
857 
858 //===----------------------------------------------------------------------===//
859 // Pass Definition
860 //===----------------------------------------------------------------------===//
861 
862 struct ConvertXeGPUToXeVMPass
863  : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
864  using Base::Base;
865 
866  void runOnOperation() override {
867  LLVMTypeConverter typeConverter(&getContext());
868  typeConverter.addConversion([&](VectorType type) -> Type {
869  unsigned rank = type.getRank();
870  auto elemType = type.getElementType();
871  // If the element type is index, convert it to i64.
872  if (llvm::isa<IndexType>(elemType))
873  elemType = IntegerType::get(&getContext(), 64);
874  // If the vector is rank-0 or has a single element, return the element type.
875  if (rank < 1 || type.getNumElements() == 1)
876  return elemType;
877  // Otherwise, convert the vector to a flat vector type.
878  int64_t sum =
879  std::accumulate(type.getShape().begin(), type.getShape().end(),
880  int64_t{1}, std::multiplies<int64_t>());
881  return VectorType::get(sum, elemType);
882  });
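// Examples of this conversion: vector<8x16xf16> -> vector<128xf16>,
// vector<1xindex> -> i64, vector<2x1xi32> -> vector<2xi32>.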
883  typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
884  if (type.isScattered())
885  return IntegerType::get(&getContext(), 64);
886  auto i32Type = IntegerType::get(&getContext(), 32);
887  return VectorType::get(8, i32Type);
888  });
889  typeConverter.addConversion([&](MemRefType type) -> Type {
890  // Convert MemRefType to i64 type.
891  return IntegerType::get(&getContext(), 64);
892  });
893 
894  // The LLVM type converter inserts unrealized casts for the following
895  // cases; add materialization casts to handle them.
896 
897  // Materialization to convert memref to i64
898  auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
899  ValueRange inputs,
900  Location loc) -> Value {
901  if (inputs.size() != 1)
902  return {};
903  auto input = inputs.front();
904  if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
905 
906  Value addr =
907  memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
908  return arith::IndexCastUIOp::create(builder, loc, type, addr)
909  .getResult();
910  }
911  return {};
912  };
913 
914  // Materialization to convert ui64 to i64
915  auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
916  ValueRange inputs,
917  Location loc) -> Value {
918  if (inputs.size() != 1)
919  return {};
920  auto input = inputs.front();
921  if (input.getType() == builder.getIntegerType(64, false)) {
922  Value cast =
923  index::CastUOp::create(builder, loc, builder.getIndexType(), input)
924  .getResult();
925  return arith::IndexCastUIOp::create(builder, loc, type, cast)
926  .getResult();
927  }
928  return {};
929  };
930 
931  // Materialization to convert ui32 to i32
932  auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
933  ValueRange inputs,
934  Location loc) -> Value {
935  if (inputs.size() != 1)
936  return {};
937  auto input = inputs.front();
938  if (input.getType() == builder.getIntegerType(32, false)) {
939  Value cast =
940  index::CastUOp::create(builder, loc, builder.getIndexType(), input)
941  .getResult();
942  return arith::IndexCastUIOp::create(builder, loc, type, cast)
943  .getResult();
944  }
945  return {};
946  };
947 
948  // Materializations to convert:
949  // - a single-element 1D vector to a scalar
950  // - a vector to one of the same rank via bitcast
951  // - a vector to one of a different rank but the same element type via shape cast
952  auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
953  ValueRange inputs,
954  Location loc) -> Value {
955  if (inputs.size() != 1)
956  return {};
957  auto input = inputs.front();
958  if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
959  if (vecTy.getNumElements() == 1) {
960  // If the vector has a single element, extract it as a scalar.
961  Value cast =
962  vector::ExtractOp::create(builder, loc, input, 0).getResult();
963  if (vecTy.getElementType() == builder.getIndexType())
964  cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
965  .getResult();
966  return cast;
967  } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
968  // If the target type is a vector of same rank,
969  // bitcast to the target type.
970  if (targetVecTy.getRank() == vecTy.getRank())
971  return vector::BitCastOp::create(builder, loc, targetVecTy, input)
972  .getResult();
973  else if (targetVecTy.getElementType() == vecTy.getElementType()) {
974  // If the target type is a vector of different rank but same element
975  // type, reshape to the target type.
976  return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
977  .getResult();
978  }
979  }
980  }
981  return {};
982  };
983  typeConverter.addSourceMaterialization(memrefMaterializationCast);
984  typeConverter.addSourceMaterialization(ui64MaterializationCast);
985  typeConverter.addSourceMaterialization(ui32MaterializationCast);
986  typeConverter.addSourceMaterialization(vectorMaterializationCast);
987  typeConverter.addTargetMaterialization(memrefMaterializationCast);
988  typeConverter.addTargetMaterialization(ui32MaterializationCast);
989  typeConverter.addTargetMaterialization(ui64MaterializationCast);
990  typeConverter.addTargetMaterialization(vectorMaterializationCast);
991  ConversionTarget target(getContext());
992  target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
993  vector::VectorDialect, arith::ArithDialect,
994  memref::MemRefDialect, gpu::GPUDialect,
995  index::IndexDialect>();
996  target.addIllegalDialect<xegpu::XeGPUDialect>();
997 
998  RewritePatternSet patterns(&getContext());
999  populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
1000  scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
1001  patterns, target);
1002  if (failed(applyPartialConversion(getOperation(), target,
1003  std::move(patterns))))
1004  signalPassFailure();
1005  }
1006 };
1007 } // namespace
1008 
1009 //===----------------------------------------------------------------------===//
1010 // Pattern Population
1011 //===----------------------------------------------------------------------===//
1012 void mlir::populateXeGPUToXeVMConversionPatterns(
1013  const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
1014  patterns.add<CreateNdDescToXeVMPattern, UpdateNdOffsetToXeVMPattern,
1015  LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
1016  LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
1017  LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
1018  typeConverter, patterns.getContext());
1019  patterns.add<CreateDescToXeVMPattern, UpdateOffsetToXeVMPattern,
1020  AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
1021  LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
1022  LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
1023  typeConverter, patterns.getContext());
1024  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
1025  patterns.getContext());
1026 }
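// Illustrative use from a downstream pass (a sketch that mirrors the pass
// defined above; the surrounding pass boilerplate and the MLIRContext *ctx
// are assumed):
//   LLVMTypeConverter typeConverter(ctx);
//   RewritePatternSet patterns(ctx);
//   populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
//   ConversionTarget target(*ctx);
//   target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect>();
//   target.addIllegalDialect<xegpu::XeGPUDialect>();
//   if (failed(applyPartialConversion(op, target, std::move(patterns))))
//     signalPassFailure();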