//===-- XeGPUToXeVM.cpp - XeGPU to XeVM dialect conversion ------*- C++ -*-===//
//
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/XeGPUToXeVM/XeGPUToXeVM.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/Index/IR/IndexDialect.h"
#include "mlir/Dialect/Index/IR/IndexOps.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/XeVMDialect.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Patterns.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Support/LLVM.h"
#include "llvm/Support/FormatVariadic.h"

#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Types.h"

#include "llvm/ADT/TypeSwitch.h"

#include <numeric>

namespace mlir {
#define GEN_PASS_DEF_CONVERTXEGPUTOXEVMPASS
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// TODO: The values below are uArch-dependent; they should not be hardcoded.
static constexpr int32_t systolicDepth{8};
static constexpr int32_t executionSize{16};

// Offsets of the individual fields within the 8xi32 nd tensor descriptor
// payload.
enum class NdTdescOffset : uint32_t {
  BasePtr = 0,       // Base pointer (i64)
  BaseShapeW = 2,    // Base shape width (i32)
  BaseShapeH = 3,    // Base shape height (i32)
  TensorOffsetW = 4, // Tensor offset W (i32)
  TensorOffsetH = 5  // Tensor offset H (i32)
};
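
// Worked example (a sketch, not tied to any particular test): for a
// memref<64x128xf16> source at base address 0x1000, the payload produced by
// CreateNdDescToXeVMPattern below would contain
//   i32[0..1] = 0x1000 (base pointer, written through a 4xi64 view)
//   i32[2]    = 128    (base shape width, in elements)
//   i32[3]    = 64     (base shape height, in elements)
//   i32[4]    = 0      (tensor offset W; explicit offsets are not supported)
//   i32[5]    = 0      (tensor offset H)
// with the remaining two lanes unused.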

static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
  switch (xeGpuMemspace) {
  case xegpu::MemorySpace::Global:
    return static_cast<int>(xevm::AddrSpace::GLOBAL);
  case xegpu::MemorySpace::SLM:
    return static_cast<int>(xevm::AddrSpace::SHARED);
  }
}

// Get a flat vector type with the same total bit width but a new element type.
static VectorType encodeVectorTypeTo(VectorType currentVecType,
                                     Type toElemType) {
  auto elemType = currentVecType.getElementType();
  auto currentBitWidth = elemType.getIntOrFloatBitWidth();
  auto newBitWidth = toElemType.getIntOrFloatBitWidth();
  const int size =
      currentVecType.getNumElements() * currentBitWidth / newBitWidth;
  return VectorType::get(size, toElemType);
}
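
// For example, encodeVectorTypeTo(vector<16xf16>, i32) yields vector<8xi32>:
// 16 elements x 16 bits and 8 elements x 32 bits cover the same 256 bits;
// only the element type (and thus the element count) changes.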

static xevm::LoadCacheControl
translateLoadXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                            std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::CACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1C_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1UC_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::CACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3C;
    else if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::LoadCacheControl::L1S_L2UC_L3UC;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::READ_INVALIDATE:
    return xevm::LoadCacheControl::INVALIDATE_READ;
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
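
// For example, an L1 hint of CACHED with an L3 hint of UNCACHED maps to
// L1C_L2UC_L3UC; missing hints default to UNCACHED, and L2 is always left
// uncached by this translation.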

static xevm::StoreCacheControl
translateStoreXeGPUCacheHint(std::optional<xegpu::CachePolicy> L1hint,
                             std::optional<xegpu::CachePolicy> L3hint) {
  auto L1hintVal = L1hint.value_or(xegpu::CachePolicy::UNCACHED);
  auto L3hintVal = L3hint.value_or(xegpu::CachePolicy::UNCACHED);
  switch (L1hintVal) {
  case xegpu::CachePolicy::UNCACHED:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1UC_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1UC_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::STREAMING:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1S_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1S_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_BACK:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WB_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WB_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  case xegpu::CachePolicy::WRITE_THROUGH:
    if (L3hintVal == xegpu::CachePolicy::UNCACHED)
      return xevm::StoreCacheControl::L1WT_L2UC_L3UC;
    else if (L3hintVal == xegpu::CachePolicy::WRITE_BACK)
      return xevm::StoreCacheControl::L1WT_L2UC_L3WB;
    else
      llvm_unreachable("Unsupported cache control.");
  default:
    llvm_unreachable("Unsupported cache control.");
  }
}
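
// Likewise for stores: for example, WRITE_BACK on both L1 and L3 maps to
// L1WB_L2UC_L3WB, and an absent hint again defaults to UNCACHED.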

class CreateNdDescToXeVMPattern
    : public OpConversionPattern<xegpu::CreateNdDescOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::CreateNdDescOp op,
                  xegpu::CreateNdDescOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    SmallVector<OpFoldResult> mixedOffsets = op.getMixedOffsets();
    if (mixedOffsets.size() != 0)
      return rewriter.notifyMatchFailure(op, "Offsets not supported.");
    auto loc = op.getLoc();
    auto source = op.getSource();
    // The op is lowered to a code sequence that populates the payload.
    // The payload is an 8xi32 vector; the offsets of its individual fields are
    // defined in the NdTdescOffset enum.
    Type payloadElemTy = rewriter.getI32Type();
    VectorType payloadTy = VectorType::get(8, payloadElemTy);
    Type i64Ty = rewriter.getI64Type();
    // A 4xi64 view is used for inserting the base pointer.
    VectorType payloadI64Ty = VectorType::get(4, i64Ty);
    // Initialize payload to zero.
    Value payload = arith::ConstantOp::create(
        rewriter, loc,
        DenseElementsAttr::get(payloadTy, IntegerAttr::get(payloadElemTy, 0)));

    Value baseAddr;
    Value baseShapeW;
    Value baseShapeH;
    Value offsetW;
    Value offsetH;

    // Source can be a memref or a pointer (ui64, ui32, i64 or i32).
    SmallVector<OpFoldResult> mixedSizes = op.getMixedSizes();
    // Descriptor shape is expected to be 2D.
    int64_t rank = mixedSizes.size();
    if (rank != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D shape.");
    auto sourceTy = source.getType();
    auto sourceMemrefTy = dyn_cast<MemRefType>(sourceTy);
    // If the source is a memref, extract the aligned pointer as an index.
    // Pointer types are passed as i32 or i64 by the type converter.
    if (sourceMemrefTy) {
      if (!sourceMemrefTy.hasStaticShape()) {
        return rewriter.notifyMatchFailure(op, "Expected static memref shape.");
      }
      baseAddr =
          memref::ExtractAlignedPointerAsIndexOp::create(rewriter, loc, source);
    } else {
      baseAddr = adaptor.getSource();
    }
    // Utility for creating offset values from an op fold result.
    auto createOffset = [&](SmallVector<OpFoldResult> &ofrVec,
                            unsigned idx) -> Value {
      Value val = getValueOrCreateConstantIntOp(rewriter, loc, ofrVec[idx]);
      val = getValueOrCreateCastToIndexLike(rewriter, loc, payloadElemTy, val);
      return val;
    };
    // Offsets are not supported (0 is used).
    offsetW = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    offsetH = arith::ConstantIntOp::create(rewriter, loc, payloadElemTy, 0);
    // Get shape values from op fold results.
    baseShapeW = createOffset(mixedSizes, 1);
    baseShapeH = createOffset(mixedSizes, 0);
    if (sourceMemrefTy) {
      // Cast index to i64.
      baseAddr = arith::IndexCastUIOp::create(rewriter, loc, i64Ty, baseAddr);
    } else if (baseAddr.getType() != i64Ty) {
      // Pointer type may be i32. Cast to i64 if needed.
      baseAddr = arith::ExtUIOp::create(rewriter, loc, i64Ty, baseAddr);
    }
    // Populate payload.
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, payload);
    payLoadAsI64 =
        vector::InsertOp::create(rewriter, loc, baseAddr, payLoadAsI64,
                                 static_cast<int>(NdTdescOffset::BasePtr));
    payload = vector::BitCastOp::create(rewriter, loc, payloadTy, payLoadAsI64);
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeW, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeW));
    payload =
        vector::InsertOp::create(rewriter, loc, baseShapeH, payload,
                                 static_cast<int>(NdTdescOffset::BaseShapeH));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetW, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetW));
    payload = vector::InsertOp::create(
        rewriter, loc, offsetH, payload,
        static_cast<int>(NdTdescOffset::TensorOffsetH));
    rewriter.replaceOp(op, payload);
    return success();
  }
};
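
// Illustrative result (a sketch, not verbatim pass output): for a
// memref<64x128xf16> source the sequence above materializes a zeroed 8xi32
// constant, extracts and casts the aligned base pointer to i64, writes it into
// lanes 0-1 through the 4xi64 view, and then inserts shape width 128, shape
// height 64, and zero W/H offsets into lanes 2-5.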

template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        OpType, xegpu::LoadNdOp, xegpu::StoreNdOp, xegpu::PrefetchNdOp>::value>>
class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto mixedOffsets = op.getMixedOffsets();
    int64_t opOffsetsSize = mixedOffsets.size();
    if (opOffsetsSize != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D offsets.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();

    auto tdesc = adaptor.getTensorDesc();
    auto tdescTy = op.getTensorDescType();
    if (tdescTy.getRank() != 2)
      return rewriter.notifyMatchFailure(op, "Expected 2D tensor descriptor.");
    auto elemType = tdescTy.getElementType();
    auto elemBitSize = elemType.getIntOrFloatBitWidth();
    if (elemBitSize % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");

    VectorType payloadI64Ty = VectorType::get(4, rewriter.getI64Type());
    Value payLoadAsI64 =
        vector::BitCastOp::create(rewriter, loc, payloadI64Ty, tdesc);
    Value basePtr = vector::ExtractOp::create(
        rewriter, loc, payLoadAsI64, static_cast<int>(NdTdescOffset::BasePtr));
    Value baseShapeW = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeW));
    Value baseShapeH = vector::ExtractOp::create(
        rewriter, loc, tdesc, static_cast<int>(NdTdescOffset::BaseShapeH));
    // Offsets are provided by the op; convert them to i32.
    Value offsetW =
        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[1]);
    offsetW = getValueOrCreateCastToIndexLike(rewriter, loc,
                                              rewriter.getI32Type(), offsetW);
    Value offsetH =
        getValueOrCreateConstantIntOp(rewriter, loc, mixedOffsets[0]);
    offsetH = getValueOrCreateCastToIndexLike(rewriter, loc,
                                              rewriter.getI32Type(), offsetH);
    // Get the address space from the tensor descriptor memory space.
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // Convert base pointer (i64) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtr);
    // Compute element byte size and surface width in bytes.
    Value elemByteSize = arith::ConstantIntOp::create(
        rewriter, loc, rewriter.getI32Type(), elemBitSize / 8);
    Value surfaceW =
        arith::MulIOp::create(rewriter, loc, baseShapeW, elemByteSize);

    // Get tile sizes and vblocks from the tensor descriptor type.
    auto tileW = tdescTy.getDimSize(1);
    auto tileH = tdescTy.getDimSize(0);
    int32_t vblocks = tdescTy.getArrayLength();
    if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
      Value src = adaptor.getValue();
      // If the store value is a scalar, get the value from the op instead of
      // the adaptor; the adaptor may have optimized away a single-element
      // vector.
      if (src.getType().isIntOrFloat()) {
        src = op.getValue();
      }
      VectorType srcVecTy = dyn_cast<VectorType>(src.getType());
      if (!srcVecTy)
        return rewriter.notifyMatchFailure(
            op, "Expected store value to be a vector type.");
      // Get a flat integer vector type with matching element bit size.
      VectorType newSrcVecTy =
          encodeVectorTypeTo(srcVecTy, rewriter.getIntegerType(elemBitSize));
      if (srcVecTy != newSrcVecTy)
        src = vector::BitCastOp::create(rewriter, loc, newSrcVecTy, src);
      auto storeCacheControl =
          translateStoreXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      xevm::BlockStore2dOp::create(
          rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
          offsetH, elemBitSize, tileW, tileH, src,
          xevm::StoreCacheControlAttr::get(ctxt, storeCacheControl));
      rewriter.eraseOp(op);
    } else {
      auto loadCacheControl =
          translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint());
      if constexpr (std::is_same_v<OpType, xegpu::PrefetchNdOp>) {
        xevm::BlockPrefetch2dOp::create(
            rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
            offsetH, elemBitSize, tileW, tileH, vblocks,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        rewriter.eraseOp(op);
      } else {
        VectorType dstVecTy = cast<VectorType>(op.getValue().getType());
        const bool vnni = op.getPacked().value_or(false);
        auto transposeValue = op.getTranspose();
        bool transpose =
            transposeValue.has_value() && transposeValue.value()[0] == 1;
        VectorType loadedTy = encodeVectorTypeTo(
            dstVecTy, vnni ? rewriter.getI32Type()
                           : rewriter.getIntegerType(elemBitSize));

        Value resultFlatVec = xevm::BlockLoad2dOp::create(
            rewriter, loc, loadedTy, basePtrLLVM, surfaceW, baseShapeH,
            surfaceW, offsetW, offsetH, elemBitSize, tileW, tileH, vblocks,
            transpose, vnni,
            xevm::LoadCacheControlAttr::get(ctxt, loadCacheControl));
        resultFlatVec = vector::BitCastOp::create(
            rewriter, loc,
            encodeVectorTypeTo(loadedTy, dstVecTy.getElementType()),
            resultFlatVec);
        rewriter.replaceOp(op, resultFlatVec);
      }
    }
    return success();
  }
};
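
// Illustrative lowering (a sketch): an xegpu.load_nd of an 8x16 f16 tile
// becomes a single xevm.blockload2d with an element bit size of 16, tile
// width 16, tile height 8, the surface pitch computed as the base shape width
// times 2 bytes, and the 2D offsets taken from the op; the flat result is
// then bitcast back to the destination element type.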

// Build the address computation: baseAddr + offset * elemByteSize.
static Value addOffset(ConversionPatternRewriter &rewriter, Location loc,
                       Value baseAddr, Value offset, int64_t elemByteSize) {
  Value byteSize = arith::ConstantIntOp::create(
      rewriter, loc, rewriter.getI64Type(), elemByteSize);
  Value byteOffset = arith::MulIOp::create(rewriter, loc, offset, byteSize);
  Value newAddr = arith::AddIOp::create(rewriter, loc, baseAddr, byteOffset);
  return newAddr;
}
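
// For example, with a 4-byte element type (f32) and an element offset of 5,
// addOffset emits baseAddr + 5 * 4 = baseAddr + 20.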

template <typename OpType,
          typename = std::enable_if_t<llvm::is_one_of<
              OpType, xegpu::LoadGatherOp, xegpu::StoreScatterOp>::value>>
class LoadStoreToXeVMPattern : public OpConversionPattern<OpType> {
  using OpConversionPattern<OpType>::OpConversionPattern;
  LogicalResult
  matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    Value offset = adaptor.getOffsets();
    if (!offset)
      return rewriter.notifyMatchFailure(op, "Expected offset to be provided.");
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64;
    // The load result or store value type can be a vector or a scalar.
    Type valOrResTy;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>)
      valOrResTy = op.getResult().getType();
    else
      valOrResTy = adaptor.getValue().getType();
    VectorType valOrResVecTy = dyn_cast<VectorType>(valOrResTy);
    bool hasScalarVal = !valOrResVecTy;
    int64_t elemBitWidth =
        hasScalarVal ? valOrResTy.getIntOrFloatBitWidth()
                     : valOrResVecTy.getElementType().getIntOrFloatBitWidth();
    // Element type must be a multiple of 8 bits.
    if (elemBitWidth % 8 != 0)
      return rewriter.notifyMatchFailure(
          op, "Expected element type bit width to be multiple of 8.");
    int64_t elemByteSize = elemBitWidth / 8;
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // The base pointer can come from the source (load) or the dest (store).
    // If they are memrefs, use their memory space.
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      basePtrI64 = adaptor.getSource();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    } else {
      basePtrI64 = adaptor.getDest();
      if (auto memRefTy = dyn_cast<MemRefType>(op.getDest().getType())) {
        auto addrSpace = memRefTy.getMemorySpaceAsInt();
        if (addrSpace != 0)
          ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
      }
    }
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type()) {
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    }
    Value mask = adaptor.getMask();
    if (dyn_cast<VectorType>(offset.getType())) {
      // The offset must be a scalar; a single-element vector is converted to a
      // scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected offset to be a scalar.");
    } else {
      // The offset is given in number of elements; scale it by the element
      // byte size and add it to the base pointer.
      basePtrI64 = addOffset(rewriter, loc, basePtrI64, offset, elemByteSize);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);

    Value maskForLane;
    VectorType maskVecTy = dyn_cast<VectorType>(mask.getType());
    if (maskVecTy) {
      // The mask must be a scalar; a single-element vector is converted to a
      // scalar by the type converter.
      return rewriter.notifyMatchFailure(op, "Expected mask to be a scalar.");
    } else
      maskForLane = mask;
    if constexpr (std::is_same_v<OpType, xegpu::LoadGatherOp>) {
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, {valOrResTy},
                                         maskForLane, true, true);
      // If the mask is true (then clause), load from memory and yield.
      rewriter.setInsertionPointToStart(&ifOp.getThenRegion().front());
      if (!hasScalarVal)
        valOrResTy = VectorType::get({valOrResVecTy.getNumElements()},
                                     valOrResVecTy.getElementType());
      Value loaded =
          LLVM::LoadOp::create(rewriter, loc, valOrResTy, basePtrLLVM);
      // Set the cache control attribute on the load operation.
      loaded.getDefiningOp()->setAttr(
          "cache_control", xevm::LoadCacheControlAttr::get(
                               ctxt, translateLoadXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.setInsertionPointToStart(&ifOp.getElseRegion().front());
      // If the mask is false (else clause), yield a value of zeros.
      auto eTy = hasScalarVal ? valOrResTy : valOrResVecTy.getElementType();
      TypedAttr eVal;
      if (eTy.isFloat())
        eVal = FloatAttr::get(eTy, 0.0);
      else
        eVal = IntegerAttr::get(eTy, 0);
      if (hasScalarVal)
        loaded = arith::ConstantOp::create(rewriter, loc, eVal);
      else
        loaded = arith::ConstantOp::create(
            rewriter, loc, DenseElementsAttr::get(valOrResVecTy, eVal));
      scf::YieldOp::create(rewriter, loc, ValueRange{loaded});
      rewriter.replaceOp(op, ifOp.getResult(0));
    } else {
      // If the mask is true, perform the store.
      scf::IfOp ifOp = scf::IfOp::create(rewriter, loc, maskForLane, false);
      auto body = ifOp.getBody();
      rewriter.setInsertionPointToStart(body);
      auto storeOp =
          LLVM::StoreOp::create(rewriter, loc, adaptor.getValue(), basePtrLLVM);
      // Set the cache control attribute on the store operation.
      storeOp.getOperation()->setAttr(
          "cache_control", xevm::StoreCacheControlAttr::get(
                               ctxt, translateStoreXeGPUCacheHint(
                                         op.getL1Hint(), op.getL3Hint())));
      rewriter.eraseOp(op);
    }
    return success();
  }
};
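
// Illustrative lowering (a sketch): an xegpu.load with a scalar i1 mask and a
// scalar element offset becomes an addOffset address computation, an
// llvm.inttoptr, and an scf.if on the mask whose then-branch emits an
// llvm.load annotated with a cache_control attribute and whose else-branch
// yields a zero constant of the result type.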

class PrefetchToXeVMPattern : public OpConversionPattern<xegpu::PrefetchOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::PrefetchOp op, xegpu::PrefetchOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdescTy = op.getTensorDescType();
    Value basePtrI64 = adaptor.getSource();
    // The base pointer is passed as i32 or i64 by the adaptor; cast to i64 if
    // needed.
    if (basePtrI64.getType() != rewriter.getI64Type())
      basePtrI64 = arith::ExtUIOp::create(rewriter, loc, rewriter.getI64Type(),
                                          basePtrI64);
    Value offsets = adaptor.getOffsets();
    if (offsets) {
      VectorType offsetsVecTy = dyn_cast<VectorType>(offsets.getType());
      if (offsetsVecTy) {
        // The offset must be a scalar.
        return rewriter.notifyMatchFailure(op,
                                           "Expected offsets to be a scalar.");
      } else {
        int64_t elemBitWidth{0};
        int64_t elemByteSize;
        // The element byte size can come from three sources:
        if (tdescTy) {
          // If a tensor descriptor is available, use its element type to
          // determine the element byte size.
          elemBitWidth = tdescTy.getElementType().getIntOrFloatBitWidth();
        } else if (auto memRefTy = dyn_cast<MemRefType>(op.getSourceType())) {
          // If a memref is available, use its element type to determine the
          // element byte size.
          elemBitWidth = memRefTy.getElementType().getIntOrFloatBitWidth();
        } else {
          // Otherwise, use the provided offset byte alignment.
          elemByteSize = *op.getOffsetAlignByte();
        }
        if (elemBitWidth != 0) {
          if (elemBitWidth % 8 != 0)
            return rewriter.notifyMatchFailure(
                op, "Expected element type bit width to be multiple of 8.");
          elemByteSize = elemBitWidth / 8;
        }
        basePtrI64 =
            addOffset(rewriter, loc, basePtrI64, offsets, elemByteSize);
      }
    }
    // Default memory space is global.
    LLVM::LLVMPointerType ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(xegpu::MemorySpace::Global));
    // If a tensor descriptor is available, use its memory space.
    if (tdescTy)
      ptrTypeLLVM = LLVM::LLVMPointerType::get(
          ctxt, getNumericXeVMAddrSpace(tdescTy.getMemorySpace()));
    // If the source is a memref, use its memory space.
    if (auto memRefTy = dyn_cast<MemRefType>(op.getSource().getType())) {
      auto addrSpace = memRefTy.getMemorySpaceAsInt();
      if (addrSpace != 0)
        ptrTypeLLVM = LLVM::LLVMPointerType::get(ctxt, addrSpace);
    }
    // Convert base pointer (i64) to LLVM pointer type.
    Value ptrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    // Create the prefetch op with a cache control attribute.
    xevm::PrefetchOp::create(
        rewriter, loc, ptrLLVM,
        xevm::LoadCacheControlAttr::get(
            ctxt, translateLoadXeGPUCacheHint(op.getL1Hint(), op.getL3Hint())));
    rewriter.eraseOp(op);
    return success();
  }
};

class FenceToXeVMPattern : public OpConversionPattern<xegpu::FenceOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::FenceOp op, xegpu::FenceOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    xevm::MemScope memScope{xevm::MemScope::WORKGROUP};
    switch (op.getFenceScope()) {
    case xegpu::FenceScope::Workgroup:
      memScope = xevm::MemScope::WORKGROUP;
      break;
    case xegpu::FenceScope::GPU:
      memScope = xevm::MemScope::DEVICE;
      break;
    }
    xevm::AddrSpace addrSpace{xevm::AddrSpace::GLOBAL};
    switch (op.getMemoryKind()) {
    case xegpu::MemorySpace::Global:
      addrSpace = xevm::AddrSpace::GLOBAL;
      break;
    case xegpu::MemorySpace::SLM:
      addrSpace = xevm::AddrSpace::SHARED;
      break;
    }
    xevm::MemfenceOp::create(rewriter, loc, memScope, addrSpace);
    rewriter.eraseOp(op);
    return success();
  }
};

class DpasToXeVMPattern : public OpConversionPattern<xegpu::DpasOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::DpasOp op, xegpu::DpasOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto aTy = cast<VectorType>(op.getLhs().getType());
    auto bTy = cast<VectorType>(op.getRhs().getType());
    auto resultType = cast<VectorType>(op.getResultType());

    auto encodePrecision = [&](Type type) -> xevm::ElemType {
      if (type == rewriter.getBF16Type())
        return xevm::ElemType::BF16;
      else if (type == rewriter.getF16Type())
        return xevm::ElemType::F16;
      else if (type == rewriter.getTF32Type())
        return xevm::ElemType::TF32;
      else if (type.isInteger(8)) {
        if (type.isUnsignedInteger())
          return xevm::ElemType::U8;
        return xevm::ElemType::S8;
      } else if (type == rewriter.getF32Type())
        return xevm::ElemType::F32;
      else if (type.isInteger(32))
        return xevm::ElemType::S32;
      llvm_unreachable("add more support for ElemType");
    };
    xevm::ElemType precATy = encodePrecision(aTy.getElementType());
    xevm::ElemType precBTy = encodePrecision(bTy.getElementType());
    Value c = op.getAcc();
    if (!c) {
      auto elementTy = resultType.getElementType();
      Attribute initValueAttr;
      if (isa<FloatType>(elementTy))
        initValueAttr = FloatAttr::get(elementTy, 0.0);
      else
        initValueAttr = IntegerAttr::get(elementTy, 0);
      c = arith::ConstantOp::create(
          rewriter, loc, DenseElementsAttr::get(resultType, initValueAttr));
    }

    Value aVec = op.getLhs();
    Value bVec = op.getRhs();
    auto cvecty = cast<VectorType>(c.getType());
    xevm::ElemType precCTy = encodePrecision(cvecty.getElementType());
    xevm::ElemType precDTy = encodePrecision(resultType.getElementType());
    VectorType cNty =
        VectorType::get(cvecty.getNumElements(), cvecty.getElementType());
    if (cvecty != cNty)
      c = vector::ShapeCastOp::create(rewriter, loc, cNty, c);
    Value dpasRes = xevm::MMAOp::create(
        rewriter, loc, cNty, aVec, bVec, c,
        xevm::MMAShapeAttr::get(ctxt, cvecty.getNumElements(), executionSize,
                                systolicDepth *
                                    getNumOperandsPerDword(precATy)),
        xevm::MMATypesAttr::get(ctxt, precDTy, precATy, precBTy, precCTy));
    if (cvecty != cNty)
      dpasRes = vector::ShapeCastOp::create(rewriter, loc, resultType, dpasRes);
    rewriter.replaceOp(op, dpasRes);
    return success();
  }

private:
  static unsigned getNumOperandsPerDword(xevm::ElemType pTy) {
    switch (pTy) {
    case xevm::ElemType::TF32:
      return 1;
    case xevm::ElemType::BF16:
    case xevm::ElemType::F16:
      return 2;
    case xevm::ElemType::U8:
    case xevm::ElemType::S8:
      return 4;
    default:
      llvm_unreachable("unsupported xevm::ElemType");
    }
  }
};
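
// Example shape computation (a sketch): for f16 A and B operands,
// getNumOperandsPerDword returns 2, so the third MMA shape value becomes
// systolicDepth * 2 = 16; the other two values are the number of accumulator
// elements per lane and executionSize (16).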

static std::optional<LLVM::AtomicBinOp>
matchSimpleAtomicOp(arith::AtomicRMWKind arithKind) {
  switch (arithKind) {
  case arith::AtomicRMWKind::addf:
    return LLVM::AtomicBinOp::fadd;
  case arith::AtomicRMWKind::addi:
    return LLVM::AtomicBinOp::add;
  case arith::AtomicRMWKind::assign:
    return LLVM::AtomicBinOp::xchg;
  case arith::AtomicRMWKind::maximumf:
    return LLVM::AtomicBinOp::fmax;
  case arith::AtomicRMWKind::maxs:
    return LLVM::AtomicBinOp::max;
  case arith::AtomicRMWKind::maxu:
    return LLVM::AtomicBinOp::umax;
  case arith::AtomicRMWKind::minimumf:
    return LLVM::AtomicBinOp::fmin;
  case arith::AtomicRMWKind::mins:
    return LLVM::AtomicBinOp::min;
  case arith::AtomicRMWKind::minu:
    return LLVM::AtomicBinOp::umin;
  case arith::AtomicRMWKind::ori:
    return LLVM::AtomicBinOp::_or;
  case arith::AtomicRMWKind::andi:
    return LLVM::AtomicBinOp::_and;
  default:
    return std::nullopt;
  }
}

class AtomicRMWToXeVMPattern : public OpConversionPattern<xegpu::AtomicRMWOp> {
  using OpConversionPattern::OpConversionPattern;
  LogicalResult
  matchAndRewrite(xegpu::AtomicRMWOp op, xegpu::AtomicRMWOp::Adaptor adaptor,
                  ConversionPatternRewriter &rewriter) const override {
    auto loc = op.getLoc();
    auto ctxt = rewriter.getContext();
    auto tdesc = op.getTensorDesc().getType();
    auto ptrTypeLLVM = LLVM::LLVMPointerType::get(
        ctxt, getNumericXeVMAddrSpace(tdesc.getMemorySpace()));
    Value basePtrI64 = arith::IndexCastOp::create(
        rewriter, loc, rewriter.getI64Type(), adaptor.getTensorDesc());
    Value basePtrLLVM =
        LLVM::IntToPtrOp::create(rewriter, loc, ptrTypeLLVM, basePtrI64);
    VectorType srcOrDstVecTy = cast<VectorType>(op.getValue().getType());
    VectorType srcOrDstFlatVecTy = VectorType::get(
        srcOrDstVecTy.getNumElements(), srcOrDstVecTy.getElementType());
    Value srcFlatVec = vector::ShapeCastOp::create(
        rewriter, loc, srcOrDstFlatVecTy, op.getValue());
    auto atomicKind = matchSimpleAtomicOp(op.getKind());
    assert(atomicKind.has_value());
    Value resVec = srcFlatVec;
    for (int i = 0; i < srcOrDstVecTy.getNumElements(); i++) {
      auto val = vector::ExtractOp::create(rewriter, loc, resVec, i);
      Value idx = LLVM::ConstantOp::create(rewriter, loc, rewriter.getI64Type(),
                                           rewriter.getIndexAttr(i));
      Value currPtr =
          LLVM::GEPOp::create(rewriter, loc, ptrTypeLLVM,
                              srcOrDstVecTy.getElementType(), basePtrLLVM, idx);
      Value newVal =
          LLVM::AtomicRMWOp::create(rewriter, loc, atomicKind.value(), currPtr,
                                    val, LLVM::AtomicOrdering::seq_cst);
      resVec = vector::InsertOp::create(rewriter, loc, newVal, resVec, i);
    }
    rewriter.replaceOp(op, resVec);
    return success();
  }
};
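
// Illustrative lowering (a sketch): an xegpu.atomic_rmw addf over an
// N-element vector becomes N scalar llvm.atomicrmw fadd ops, each addressed
// through a GEP off the descriptor base pointer, with the value returned by
// each atomicrmw inserted back into the result vector at the same index.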

//===----------------------------------------------------------------------===//
// Pass Definition
//===----------------------------------------------------------------------===//

struct ConvertXeGPUToXeVMPass
    : public impl::ConvertXeGPUToXeVMPassBase<ConvertXeGPUToXeVMPass> {
  using Base::Base;

  void runOnOperation() override {
    LLVMTypeConverter typeConverter(&getContext());
    typeConverter.addConversion([&](VectorType type) -> Type {
      unsigned rank = type.getRank();
      auto elemType = type.getElementType();
      // If the element type is index, convert it to i64.
      if (llvm::isa<IndexType>(elemType))
        elemType = IntegerType::get(&getContext(), 64);
      // If the vector is 0-D (a scalar) or has a single element, return the
      // element type.
      if (rank < 1 || type.getNumElements() == 1)
        return elemType;
      // Otherwise, flatten the vector to 1-D.
      int64_t sum =
          std::accumulate(type.getShape().begin(), type.getShape().end(),
                          int64_t{1}, std::multiplies<int64_t>());
      return VectorType::get(sum, elemType);
    });
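    // For example, vector<8x16xf16> converts to vector<128xf16>,
    // vector<1xindex> converts to i64, and vector<1xf32> converts to f32.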
    typeConverter.addConversion([&](xegpu::TensorDescType type) -> Type {
      if (type.isScattered())
        return IntegerType::get(&getContext(), 64);
      auto i32Type = IntegerType::get(&getContext(), 32);
      return VectorType::get(8, i32Type);
    });
    typeConverter.addConversion([&](MemRefType type) -> Type {
      // Convert MemRefType to i64.
      return IntegerType::get(&getContext(), 64);
    });

    // The LLVM type converter inserts unrealized casts for the cases below;
    // add materialization casts to handle them.

    // Materialization to convert a memref to i64.
    auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto memrefTy = dyn_cast<MemRefType>(input.getType())) {
        Value addr =
            memref::ExtractAlignedPointerAsIndexOp::create(builder, loc, input);
        return arith::IndexCastUIOp::create(builder, loc, type, addr)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui64 to i64.
    auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(64, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert ui32 to i32.
    auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
                                      ValueRange inputs,
                                      Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (input.getType() == builder.getIntegerType(32, false)) {
        Value cast =
            index::CastUOp::create(builder, loc, builder.getIndexType(), input)
                .getResult();
        return arith::IndexCastUIOp::create(builder, loc, type, cast)
            .getResult();
      }
      return {};
    };

    // Materialization to convert:
    // - a single-element 1-D vector to a scalar,
    // - a vector to another vector of the same rank via bitcast,
    // - a vector to a vector of a different rank but the same element type
    //   via shape cast.
    auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
                                        ValueRange inputs,
                                        Location loc) -> Value {
      if (inputs.size() != 1)
        return {};
      auto input = inputs.front();
      if (auto vecTy = dyn_cast<VectorType>(input.getType())) {
        if (vecTy.getNumElements() == 1) {
          // If the vector has a single element, extract it as a scalar.
          Value cast =
              vector::ExtractOp::create(builder, loc, input, 0).getResult();
          if (vecTy.getElementType() == builder.getIndexType())
            cast = arith::IndexCastUIOp::create(builder, loc, type, cast)
                       .getResult();
          return cast;
        } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
          // If the target type is a vector of the same rank, bitcast to the
          // target type.
          if (targetVecTy.getRank() == vecTy.getRank())
            return vector::BitCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          else if (targetVecTy.getElementType() == vecTy.getElementType()) {
            // If the target type is a vector of a different rank but the same
            // element type, reshape to the target type.
            return vector::ShapeCastOp::create(builder, loc, targetVecTy, input)
                .getResult();
          }
        }
      }
      return {};
    };
    typeConverter.addSourceMaterialization(memrefMaterializationCast);
    typeConverter.addSourceMaterialization(ui64MaterializationCast);
    typeConverter.addSourceMaterialization(ui32MaterializationCast);
    typeConverter.addSourceMaterialization(vectorMaterializationCast);
    typeConverter.addTargetMaterialization(memrefMaterializationCast);
    typeConverter.addTargetMaterialization(ui32MaterializationCast);
    typeConverter.addTargetMaterialization(ui64MaterializationCast);
    typeConverter.addTargetMaterialization(vectorMaterializationCast);
    ConversionTarget target(getContext());
    target.addLegalDialect<xevm::XeVMDialect, LLVM::LLVMDialect,
                           vector::VectorDialect, arith::ArithDialect,
                           memref::MemRefDialect, gpu::GPUDialect,
                           index::IndexDialect>();
    target.addIllegalDialect<xegpu::XeGPUDialect>();

    RewritePatternSet patterns(&getContext());
    populateXeGPUToXeVMConversionPatterns(typeConverter, patterns);
    scf::populateSCFStructuralTypeConversionsAndLegality(typeConverter,
                                                         patterns, target);
    if (failed(applyPartialConversion(getOperation(), target,
                                      std::move(patterns))))
      signalPassFailure();
  }
};
} // namespace

//===----------------------------------------------------------------------===//
// Pattern Population
//===----------------------------------------------------------------------===//
void mlir::populateXeGPUToXeVMConversionPatterns(
    const LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
  patterns.add<CreateNdDescToXeVMPattern,
               LoadStorePrefetchNdToXeVMPattern<xegpu::LoadNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::StoreNdOp>,
               LoadStorePrefetchNdToXeVMPattern<xegpu::PrefetchNdOp>>(
      typeConverter, patterns.getContext());
  patterns.add<AtomicRMWToXeVMPattern, PrefetchToXeVMPattern,
               LoadStoreToXeVMPattern<xegpu::LoadGatherOp>,
               LoadStoreToXeVMPattern<xegpu::StoreScatterOp>>(
      typeConverter, patterns.getContext());
  patterns.add<FenceToXeVMPattern, DpasToXeVMPattern>(typeConverter,
                                                      patterns.getContext());
}