//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only vector as the basic vector store and load ops guarantee
  // XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}

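// Illustrative (non-normative) examples of what the preconditions above
// accept and reject, assuming standard vector dialect syntax:
//
//   // Accepted: unmasked 2D read of a memref contiguous in the last dim.
//   %v = vector.transfer_read %src[%i, %j], %pad {in_bounds = [true, true]}
//        : memref<64x64xf32>, vector<8x16xf32>
//
//   // Rejected: masked transfers are not handled by this lowering.
//   %v = vector.transfer_read %src[%i, %j], %pad, %mask
//        : memref<64x64xf32>, vector<8x16xf32>
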
static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}

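// Rough sketch of the IR emitted for a statically shaped source (the XeGPU
// assembly syntax here is approximate, not normative):
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//           : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32>
//
// For dynamically shaped sources, the shape and strides are additionally
// passed as SSA values computed from memref.dim ops, as in the dynamic branch
// above.
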
// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
// If the permutation map swaps the last two dimensions (e.g., [0, 1] -> [1, 0]),
// then after calling this function, the last two strides will be swapped:
//   Original strides:  [s0, s1, s2, s3]
//   After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides and a memref offset for vector transfer operations,
// handling both static and dynamic memrefs while applying permutation
// transformations for XeGPU lowering.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static std::pair<SmallVector<Value>, Value>
computeMemrefMeta(OpType xferOp, PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  Value offsetVal = nullptr;
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {{}, offsetVal};
    bool hasDynamicStrides = llvm::any_of(intStrides, [](int64_t strideVal) {
      return ShapedType::isDynamic(strideVal);
    });

    if (!hasDynamicStrides)
      for (int64_t s : intStrides)
        strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));

    if (!ShapedType::isDynamic(offset))
      offsetVal = arith::ConstantIndexOp::create(rewriter, loc, offset);
  }

  if (strides.empty() || !offsetVal) {
    // For dynamic shape memref, use memref.extract_strided_metadata to get
    // stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, stride0, stride1, ..., strideN-1,
    // size0, size1, ..., sizeN-1]
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (unranked)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);

    if (strides.empty())
      strides.append(meta.getStrides().begin(), meta.getStrides().end());

    if (!offsetVal)
      offsetVal = meta.getOffset();
  }

  if constexpr (llvm::is_one_of<std::decay_t<OpType>, vector::TransferReadOp,
                                vector::TransferWriteOp>::value) {
    AffineMap permMap = xferOp.getPermutationMap();
    // Adjust strides according to the permutation map (e.g., for transpose).
    adjustStridesForPermutation(permMap, strides);
  }

  return {strides, offsetVal};
}

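// For a dynamically shaped source, the metadata query issued above would look
// roughly like this (illustrative; types abbreviated):
//
//   %base, %offset, %sizes:2, %strides:2 =
//       memref.extract_strided_metadata %src
//       : memref<?x?xf32> -> memref<f32>, index, index, index, index, index
//
// The returned offset and strides feed the offset computations below.
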
// This function computes the vector of localOffsets for scattered load/stores.
// It is used in the lowering of vector.transfer_read/write to
// load_gather/store_scatter.
// Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//        %cst {in_bounds = [true, true, true, true]} :
//        memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * (4 * 2 * 6 * 32) // consider using affine map
//   %offsets = memref_offset + orig_offset + local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter, ArrayRef<Value> strides,
                            Value baseOffset) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply step vectors by the corresponding strides.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute the base offset from the transfer op indices.
  for (size_t i = 0; i < indices.size(); ++i) {
    Value strideVal = strides[i];
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  // Broadcast the base offset to match the vector shape.
  Value bcastBase = vector::BroadcastOp::create(
      rewriter, loc, fullIndexVectorType, baseOffset);
  localOffsets = arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  return localOffsets;
}

// Compute the element-wise offsets for vector.gather or vector.scatter ops.
//
// This function linearizes the base offsets of the gather/scatter operation
// and combines them with the per-element indices to produce a final vector of
// memory offsets.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::GatherOp, vector::ScatterOp>::value>>
static Value computeOffsets(PatternRewriter &rewriter, OpType gatScatOp,
                            ArrayRef<Value> strides, Value baseOffset) {
  Location loc = gatScatOp.getLoc();
  SmallVector<Value> offsets = gatScatOp.getOffsets();
  for (size_t i = 0; i < offsets.size(); ++i) {
    Value offsetContrib =
        arith::MulIOp::create(rewriter, loc, offsets[i], strides[i]);
    baseOffset =
        arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
  }
  Value indices = gatScatOp.getIndices();
  VectorType vecType = cast<VectorType>(indices.getType());

  Value strideVector =
      vector::BroadcastOp::create(rewriter, loc, vecType, strides.back())
          .getResult();
  Value stridedIndices =
      arith::MulIOp::create(rewriter, loc, strideVector, indices).getResult();

  Value baseVector =
      vector::BroadcastOp::create(
          rewriter, loc,
          VectorType::get(vecType.getShape(), rewriter.getIndexType()),
          baseOffset)
          .getResult();
  return arith::AddIOp::create(rewriter, loc, baseVector, stridedIndices)
      .getResult();
}

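// As a concrete (illustrative) example: for a source with strides [s0, s1],
// base offsets [%i, %j], and per-element indices %idx : vector<16xindex>, the
// produced offsets are roughly
//
//   offsets = memref_offset + %i * s0 + %j * s1 + %idx * s1
//
// i.e. the scalar part is accumulated into the base offset and the innermost
// stride scales the index vector.
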
// Convert memref to i64 base pointer.
template <
    typename OpType,
    typename = std::enable_if_t<llvm::is_one_of<
        std::decay_t<OpType>, vector::TransferReadOp, vector::TransferWriteOp,
        vector::GatherOp, vector::ScatterOp>::value>>
static Value memrefToIndexPtr(OpType xferOp, PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();
  auto indexPtr = memref::ExtractAlignedPointerAsIndexOp::create(
                      rewriter, loc, xferOp.getBase())
                      .getResult();
  return arith::IndexCastOp::create(rewriter, loc, rewriter.getI64Type(),
                                    indexPtr)
      .getResult();
}

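// The pointer materialization above corresponds roughly to (illustrative):
//
//   %ptr_idx = memref.extract_aligned_pointer_as_index %src
//              : memref<?x?xf32> -> index
//   %ptr_i64 = arith.index_cast %ptr_idx : index to i64
//
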
static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {
  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  auto meta = computeMemrefMeta(readOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(readOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}

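// Sketch of the scattered-load path (illustrative only; the exact XeGPU
// assembly syntax and attribute spelling are approximate):
//
//   %v = vector.transfer_read %src[%i, %j], %c0_f32 {in_bounds = [true, true]}
//        : memref<?x?xf32>, vector<8x16xf32>
//
// becomes roughly
//
//   %offsets = ... vector<8x16xindex> of linearized element offsets ...
//   %ptr     = ... i64 aligned base pointer of %src ...
//   %mask    = vector.constant_mask [8, 16] : vector<8x16xi1>
//   %v       = xegpu.load %ptr[%offsets], %mask : ... -> vector<8x16xf32>
//
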
static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {
  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  auto meta = computeMemrefMeta(writeOp, rewriter);
  if (meta.first.empty())
    return rewriter.notifyMatchFailure(writeOp, "Failed to compute strides");

  Value localOffsets =
      computeOffsets(writeOp, rewriter, meta.first, meta.second);

  Value flatMemref = memrefToIndexPtr(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: add support for out-of-bounds access.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};

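// On targets with 2D block load support, the pattern above roughly turns
//
//   %v = vector.transfer_read %src[%i, %j], %c0_f32 {in_bounds = [true, true]}
//        : memref<64x64xf32>, vector<8x16xf32>
//
// into (XeGPU syntax approximate):
//
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//           : memref<64x64xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v    = xegpu.load_nd %desc
//           : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
//
// with an additional transpose attribute on the load when the permutation map
// is not a minor identity.
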
struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct GatherLowering : public OpRewritePattern<vector::GatherOp> {
  using OpRewritePattern<vector::GatherOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::GatherOp gatherOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(gatherOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(gatherOp, "Expects memref source");

    Location loc = gatherOp.getLoc();
    VectorType vectorType = gatherOp.getVectorType();

    auto meta = computeMemrefMeta(gatherOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(gatherOp, "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, gatherOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(gatherOp, rewriter);

    auto xeGatherOp = xegpu::LoadGatherOp::create(
        rewriter, loc, vectorType, flatMemref, localOffsets, gatherOp.getMask(),
        /*chunk_size=*/IntegerAttr{},
        /*l1_hint=*/xegpu::CachePolicyAttr{},
        /*l2_hint=*/xegpu::CachePolicyAttr{},
        /*l3_hint=*/xegpu::CachePolicyAttr{});

    auto selectOp =
        arith::SelectOp::create(rewriter, loc, gatherOp.getMask(),
                                xeGatherOp.getResult(), gatherOp.getPassThru());
    rewriter.replaceOp(gatherOp, selectOp.getResult());
    return success();
  }
};

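// Note on masked lanes: the gather is emitted with the original mask, and a
// trailing arith.select with the pass-through operand restores vector.gather
// semantics for lanes where the mask is false, roughly (illustrative):
//
//   %loaded = xegpu.load %ptr[%offsets], %mask ...
//   %result = arith.select %mask, %loaded, %pass_thru
//             : vector<16xi1>, vector<16xf32>
//
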
struct ScatterLowering : public OpRewritePattern<vector::ScatterOp> {
  using OpRewritePattern<vector::ScatterOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ScatterOp scatterOp,
                                PatternRewriter &rewriter) const override {
    auto srcTy = dyn_cast<MemRefType>(scatterOp.getBase().getType());
    if (!srcTy)
      return rewriter.notifyMatchFailure(scatterOp, "Expects memref source");

    Location loc = scatterOp.getLoc();
    auto meta = computeMemrefMeta(scatterOp, rewriter);
    if (meta.first.empty())
      return rewriter.notifyMatchFailure(scatterOp,
                                         "Failed to compute strides");

    Value localOffsets =
        computeOffsets(rewriter, scatterOp, meta.first, meta.second);
    Value flatMemref = memrefToIndexPtr(scatterOp, rewriter);

    xegpu::StoreScatterOp::create(rewriter, loc, scatterOp.getValueToStore(),
                                  flatMemref, localOffsets, scatterOp.getMask(),
                                  /*chunk_size=*/IntegerAttr{},
                                  /*l1_hint=*/xegpu::CachePolicyAttr{},
                                  /*l2_hint=*/xegpu::CachePolicyAttr{},
                                  /*l3_hint=*/xegpu::CachePolicyAttr{});
    rewriter.eraseOp(scatterOp);
    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

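// For reference (illustrative), a plain 1D/2D load such as
//
//   %v = vector.load %src[%i, %j] : memref<64x64xf32>, vector<8x16xf32>
//
// follows the same descriptor-based path as transfer_read: a create_nd_tdesc
// of the source followed by a load_nd producing the vector.
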
struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};

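// Sketch of the contraction lowering (illustrative; DPAS operand shapes follow
// the row-major matmul maps checked above, and the syntax is approximate):
//
//   %d = vector.contract {indexing_maps = [...row-major matmul maps...],
//                         iterator_types = ["parallel", "parallel", "reduction"],
//                         kind = #vector.kind<add>}
//        %a, %b, %c : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
//
// becomes roughly
//
//   %d = xegpu.dpas %a, %b, %c
//        : vector<8x16xf16>, vector<16x16xf16>, vector<8x16xf32>
//          -> vector<8x16xf32>
//
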
struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns
      .add<TransferReadLowering, TransferWriteLowering, LoadLowering,
           ScatterLowering, GatherLowering, StoreLowering, ContractionLowering>(
          patterns.getContext());
}
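
// Typical usage (assuming the standard pass registration generated from the
// pass definition above): the conversion can be exercised from mlir-opt with
// the generated pass flag, e.g.
//
//   mlir-opt input.mlir -convert-vector-to-xegpu
//
// or programmatically by populating a RewritePatternSet with
// populateVectorToXeGPUConversionPatterns and applying it greedily.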