VectorToXeGPU.cpp
//===- VectorToXeGPU.cpp - Convert vector to XeGPU dialect ------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements lowering of vector operations to XeGPU dialect ops.
//
//===----------------------------------------------------------------------===//

#include "mlir/Conversion/VectorToXeGPU/VectorToXeGPU.h"

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
#include "mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"

#include <algorithm>
#include <optional>

namespace mlir {
#define GEN_PASS_DEF_CONVERTVECTORTOXEGPU
#include "mlir/Conversion/Passes.h.inc"
} // namespace mlir

using namespace mlir;

namespace {

// Return true if value represents a zero constant.
static bool isZeroConstant(Value val) {
  auto constant = val.getDefiningOp<arith::ConstantOp>();
  if (!constant)
    return false;

  return TypeSwitch<Attribute, bool>(constant.getValue())
      .Case<FloatAttr>(
          [](auto floatAttr) { return floatAttr.getValue().isZero(); })
      .Case<IntegerAttr>(
          [](auto intAttr) { return intAttr.getValue().isZero(); })
      .Default([](auto) { return false; });
}

static LogicalResult storeLoadPreconditions(PatternRewriter &rewriter,
                                            Operation *op, VectorType vecTy) {
  // Validate only the vector type; the plain vector.store and vector.load ops
  // already guarantee an XeGPU-compatible memref source.
  unsigned vecRank = vecTy.getRank();
  if (!(vecRank == 1 || vecRank == 2))
    return rewriter.notifyMatchFailure(op, "Expects 1D or 2D vector");

  return success();
}

static LogicalResult transferPreconditions(PatternRewriter &rewriter,
                                           VectorTransferOpInterface xferOp) {
  if (xferOp.getMask())
    return rewriter.notifyMatchFailure(xferOp,
                                       "Masked transfer is not supported");

  auto srcTy = dyn_cast<MemRefType>(xferOp.getShapedType());
  if (!srcTy)
    return rewriter.notifyMatchFailure(xferOp, "Expects memref source");

  // Validate further transfer op semantics.
  SmallVector<int64_t> strides;
  int64_t offset;
  if (failed(srcTy.getStridesAndOffset(strides, offset)) || strides.back() != 1)
    return rewriter.notifyMatchFailure(
        xferOp, "Buffer must be contiguous in the innermost dimension");

  VectorType vecTy = xferOp.getVectorType();
  unsigned vecRank = vecTy.getRank();
  if (xferOp.hasOutOfBoundsDim() && vecRank < 2)
    return rewriter.notifyMatchFailure(
        xferOp, "Boundary check is available only for block instructions.");

  AffineMap map = xferOp.getPermutationMap();
  if (!map.isProjectedPermutation(/*allowZeroInResults=*/false))
    return rewriter.notifyMatchFailure(xferOp, "Unsupported permutation map");
  unsigned numInputDims = map.getNumInputs();
  for (AffineExpr expr : map.getResults().take_back(vecRank)) {
    auto dim = dyn_cast<AffineDimExpr>(expr);
    if (dim.getPosition() < (numInputDims - vecRank))
      return rewriter.notifyMatchFailure(
          xferOp, "Only the innermost dimensions can be accessed");
  }

  return success();
}
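
// For example, with a rank-3 memref source and a 2D result vector, permutation
// maps such as (d0, d1, d2) -> (d1, d2) or (d0, d1, d2) -> (d2, d1) touch only
// the two innermost dimensions and satisfy the permutation-map checks above,
// whereas (d0, d1, d2) -> (d0, d1) is rejected because it accesses the
// outermost dimension.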

static xegpu::CreateNdDescOp
createNdDescriptor(PatternRewriter &rewriter, Location loc,
                   xegpu::TensorDescType descType, TypedValue<MemRefType> src,
                   Operation::operand_range offsets) {
  MemRefType srcTy = src.getType();
  auto [strides, offset] = srcTy.getStridesAndOffset();

  xegpu::CreateNdDescOp ndDesc;
  if (srcTy.hasStaticShape()) {
    ndDesc = xegpu::CreateNdDescOp::create(rewriter, loc, descType, src,
                                           getAsOpFoldResult(offsets));
  } else {
    // In case of any dynamic shapes, source's shape and strides have to be
    // explicitly provided.
    SmallVector<Value> sourceDims;
    unsigned srcRank = srcTy.getRank();
    for (unsigned i = 0; i < srcRank; ++i)
      sourceDims.push_back(memref::DimOp::create(rewriter, loc, src, i));

    SmallVector<int64_t> constOffsets;
    SmallVector<Value> dynOffsets;
    for (Value offset : offsets) {
      std::optional<int64_t> staticVal = getConstantIntValue(offset);
      if (!staticVal)
        dynOffsets.push_back(offset);
      constOffsets.push_back(staticVal.value_or(ShapedType::kDynamic));
    }

    SmallVector<Value> dynShapes;
    for (auto [idx, shape] : llvm::enumerate(srcTy.getShape())) {
      if (shape == ShapedType::kDynamic)
        dynShapes.push_back(sourceDims[idx]);
    }

    // Compute strides in reverse order.
    SmallVector<Value> dynStrides;
    Value accStride = arith::ConstantIndexOp::create(rewriter, loc, 1);
    // Last stride is guaranteed to be static and unit.
    for (int i = static_cast<int>(strides.size()) - 2; i >= 0; --i) {
      accStride =
          arith::MulIOp::create(rewriter, loc, accStride, sourceDims[i + 1]);
      if (strides[i] == ShapedType::kDynamic)
        dynStrides.push_back(accStride);
    }
    std::reverse(dynStrides.begin(), dynStrides.end());

    ndDesc = xegpu::CreateNdDescOp::create(
        rewriter, loc, descType, src, dynOffsets, dynShapes, dynStrides,
        DenseI64ArrayAttr::get(rewriter.getContext(), constOffsets),
        DenseI64ArrayAttr::get(rewriter.getContext(), srcTy.getShape()),
        DenseI64ArrayAttr::get(rewriter.getContext(), strides));
  }

  return ndDesc;
}
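
// Rough sketch of the IR built above for a dynamically shaped 2D source
// (pseudo-IR; the exact xegpu.create_nd_tdesc assembly is abridged):
//   %d0 = memref.dim %src, %c0 : memref<?x?xf32>
//   %d1 = memref.dim %src, %c1 : memref<?x?xf32>
//   %desc = xegpu.create_nd_tdesc %src[%i, %j], shape : [%d0, %d1],
//           strides : [%d1, 1] : memref<?x?xf32> -> !xegpu.tensor_desc<8x16xf32>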

// Adjusts the strides of a memref according to a given permutation map for
// vector operations.
//
// This function updates the innermost strides in the `strides` array to
// reflect the permutation specified by `permMap`. The permutation is computed
// using the inverse and broadcasting-aware version of the permutation map,
// and is applied to the relevant strides. This ensures that memory accesses
// are consistent with the logical permutation of vector elements.
//
// Example:
// Suppose we have a memref of rank 4 with strides `[s0, s1, s2, s3]`.
// If the permutation map swaps the last two dimensions (e.g., [0, 1] ->
// [1, 0]), then after calling this function, the last two strides will be
// swapped:
//   Original strides:  [s0, s1, s2, s3]
//   After permutation: [s0, s1, s3, s2]
//
static void adjustStridesForPermutation(AffineMap permMap,
                                        SmallVectorImpl<Value> &strides) {
  AffineMap invMap = inverseAndBroadcastProjectedPermutation(permMap);
  SmallVector<unsigned> perms;
  invMap.isPermutationOfMinorIdentityWithBroadcasting(perms);
  SmallVector<int64_t> perms64(perms.begin(), perms.end());
  strides = applyPermutation(strides, perms64);
}

// Computes memory strides for vector transfer operations, handling both
// static and dynamic memrefs while applying permutation transformations
// for XeGPU lowering.
static SmallVector<Value> computeStrides(VectorTransferOpInterface xferOp,
                                         PatternRewriter &rewriter) {
  SmallVector<Value> strides;
  Value baseMemref = xferOp.getBase();
  AffineMap permMap = xferOp.getPermutationMap();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());

  Location loc = xferOp.getLoc();
  if (memrefType.hasStaticShape()) {
    int64_t offset;
    SmallVector<int64_t> intStrides;
    if (failed(memrefType.getStridesAndOffset(intStrides, offset)))
      return {};
    // Wrap static strides as MLIR values.
    for (int64_t s : intStrides)
      strides.push_back(arith::ConstantIndexOp::create(rewriter, loc, s));
  } else {
    // For a dynamically shaped memref, use memref.extract_strided_metadata to
    // get the stride values.
    unsigned rank = memrefType.getRank();
    Type indexType = rewriter.getIndexType();

    // Result types: [base_memref, offset, size0, ..., sizeN-1,
    // stride0, ..., strideN-1]; sizes and strides are all index-typed.
    SmallVector<Type> resultTypes;
    resultTypes.push_back(MemRefType::get(
        {}, memrefType.getElementType())); // base memref (rank 0)
    resultTypes.push_back(indexType);      // offset

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // sizes

    for (unsigned i = 0; i < rank; ++i)
      resultTypes.push_back(indexType); // strides

    auto meta = memref::ExtractStridedMetadataOp::create(
        rewriter, loc, resultTypes, baseMemref);
    strides.append(meta.getStrides().begin(), meta.getStrides().end());
  }
  // Adjust strides according to the permutation map (e.g., for transpose).
  adjustStridesForPermutation(permMap, strides);
  return strides;
}
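
// For example, a statically shaped memref<8x16xf32> with the default row-major
// layout yields strides [16, 1]; with a transposing permutation map such as
// (d0, d1) -> (d1, d0) the two strides are swapped to [1, 16], mirroring the
// example in the adjustStridesForPermutation comment above.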

// Computes the vector of local offsets for scattered loads/stores. It is used
// in the lowering of vector.transfer_read/write to load_gather/store_scatter.
// Example:
//   %0 = vector.transfer_read %expand_shape[%block_id_y, %c0, %c0, %c0, %c0],
//        %cst {in_bounds = [true, true, true, true]} :
//        memref<8x4x2x6x32xbf16>, vector<4x2x6x32xbf16>
//
//   %6 = vector.step : vector<4xindex>
//   %7 = vector.step : vector<2xindex>
//   %8 = vector.step : vector<6xindex>
//   %9 = vector.step : vector<32xindex>
//   %10 = arith.mul %6, 384
//   %11 = arith.mul %7, 192
//   %12 = arith.mul %8, 32
//   %13 = arith.mul %9, 1
//   %14 = vector.shape_cast %10 : vector<4xindex> -> vector<4x1x1x1xindex>
//   %15 = vector.shape_cast %11 : vector<2xindex> -> vector<1x2x1x1xindex>
//   %16 = vector.shape_cast %12 : vector<6xindex> -> vector<1x1x6x1xindex>
//   %17 = vector.shape_cast %13 : vector<32xindex> -> vector<1x1x1x32xindex>
//   %18 = vector.broadcast %14 : vector<4x1x1x1xindex> -> vector<4x2x6x32xindex>
//   %19 = vector.broadcast %15 : vector<1x2x1x1xindex> -> vector<4x2x6x32xindex>
//   %20 = vector.broadcast %16 : vector<1x1x6x1xindex> -> vector<4x2x6x32xindex>
//   %21 = vector.broadcast %17 : vector<1x1x1x32xindex> -> vector<4x2x6x32xindex>
//   %22 = arith.add %18, %19
//   %23 = arith.add %20, %21
//   %local_offsets = arith.add %22, %23
//   %orig_offset = %block_id_y * 4x2x6x32 // consider using an affine map
//   %offsets = %orig_offset + %local_offsets
static Value computeOffsets(VectorTransferOpInterface xferOp,
                            PatternRewriter &rewriter,
                            ArrayRef<Value> strides) {
  Location loc = xferOp.getLoc();
  VectorType vectorType = xferOp.getVectorType();
  SmallVector<Value> indices(xferOp.getIndices().begin(),
                             xferOp.getIndices().end());
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  // Create vector.step operations for each dimension.
  SmallVector<Value> stepVectors;
  llvm::map_to_vector(vectorShape, [&](int64_t dim) {
    auto stepType = VectorType::get({dim}, rewriter.getIndexType());
    auto stepOp = vector::StepOp::create(rewriter, loc, stepType);
    stepVectors.push_back(stepOp);
    return stepOp;
  });

  // Multiply step vectors by the corresponding strides.
  size_t memrefRank = strides.size();
  size_t vectorRank = vectorShape.size();
  SmallVector<Value> strideMultiplied;
  for (size_t i = 0; i < vectorRank; ++i) {
    size_t memrefDim = memrefRank - vectorRank + i;
    Value strideValue = strides[memrefDim];
    auto mulType = dyn_cast<VectorType>(stepVectors[i].getType());
    auto bcastOp =
        vector::BroadcastOp::create(rewriter, loc, mulType, strideValue);
    auto mulOp = arith::MulIOp::create(rewriter, loc, stepVectors[i], bcastOp);
    strideMultiplied.push_back(mulOp);
  }

  // Shape cast each multiplied vector to add singleton dimensions.
  SmallVector<Value> shapeCasted;
  for (size_t i = 0; i < vectorRank; ++i) {
    SmallVector<int64_t> newShape(vectorRank, 1);
    newShape[i] = vectorShape[i];
    auto newType = VectorType::get(newShape, rewriter.getIndexType());
    auto castOp = vector::ShapeCastOp::create(rewriter, loc, newType,
                                              strideMultiplied[i]);
    shapeCasted.push_back(castOp);
  }

  // Broadcast each shape-casted vector to the full vector shape.
  SmallVector<Value> broadcasted;
  auto fullIndexVectorType =
      VectorType::get(vectorShape, rewriter.getIndexType());
  for (Value shapeCastVal : shapeCasted) {
    auto broadcastOp = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, shapeCastVal);
    broadcasted.push_back(broadcastOp);
  }

  // Add all broadcasted vectors together to compute the local offsets.
  Value localOffsets = broadcasted[0];
  for (size_t i = 1; i < broadcasted.size(); ++i)
    localOffsets =
        arith::AddIOp::create(rewriter, loc, localOffsets, broadcasted[i]);

  // Compute the base offset from the transfer op indices.
  Value baseOffset = nullptr;
  if (!indices.empty()) {
    baseOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
    for (size_t i = 0; i < indices.size(); ++i) {
      Value strideVal = strides[i];
      Value offsetContrib =
          arith::MulIOp::create(rewriter, loc, indices[i], strideVal);
      baseOffset =
          arith::AddIOp::create(rewriter, loc, baseOffset, offsetContrib);
    }
    // Broadcast the base offset to match the vector shape.
    Value bcastBase = vector::BroadcastOp::create(
        rewriter, loc, fullIndexVectorType, baseOffset);
    localOffsets =
        arith::AddIOp::create(rewriter, loc, bcastBase, localOffsets);
  }
  return localOffsets;
}

// Collapse memref shape to 1D.
static Value collapseMemrefTo1D(VectorTransferOpInterface xferOp,
                                PatternRewriter &rewriter) {
  Location loc = xferOp.getLoc();

  Value baseMemref = xferOp.getBase();
  MemRefType memrefType = dyn_cast<MemRefType>(baseMemref.getType());
  Type elementType = memrefType.getElementType();

  // Compute the total number of elements in the memref.
  MemRefType flatMemrefType;
  if (memrefType.hasStaticShape()) {
    auto totalElements = memrefType.getNumElements();
    flatMemrefType = MemRefType::get({totalElements}, elementType);
  } else {
    flatMemrefType = MemRefType::get({ShapedType::kDynamic}, elementType);
  }

  SmallVector<ReassociationIndices> reassociation;
  ReassociationIndices allDims =
      llvm::to_vector(llvm::seq<int64_t>(0, memrefType.getRank()));
  reassociation.push_back(allDims);

  auto collapseOp = memref::CollapseShapeOp::create(
      rewriter, loc, flatMemrefType, baseMemref, reassociation);
  return collapseOp;
}
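
// For example, a statically shaped source is collapsed with a single
// reassociation group covering all dimensions:
//   memref.collapse_shape %src [[0, 1, 2]]
//       : memref<4x2x6xbf16> into memref<48xbf16>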

static LogicalResult lowerToScatteredLoadOp(vector::TransferReadOp readOp,
                                            PatternRewriter &rewriter) {

  Location loc = readOp.getLoc();
  VectorType vectorType = readOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();
  auto memrefType = dyn_cast<MemRefType>(readOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(readOp, "Expected memref source");

  SmallVector<Value> strides = computeStrides(readOp, rewriter);
  if (strides.empty())
    return rewriter.notifyMatchFailure(readOp, "Failed to compute strides");

  Value localOffsets = computeOffsets(readOp, rewriter, strides);

  Value flatMemref = collapseMemrefTo1D(readOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  auto gatherOp = xegpu::LoadGatherOp::create(
      rewriter, loc, vectorType, flatMemref, localOffsets, mask,
      /*chunk_size=*/IntegerAttr{},
      /*l1_hint=*/xegpu::CachePolicyAttr{},
      /*l2_hint=*/xegpu::CachePolicyAttr{},
      /*l3_hint=*/xegpu::CachePolicyAttr{});

  rewriter.replaceOp(readOp, gatherOp.getResult());
  return success();
}
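
// Rough end-to-end sketch of the scattered read path (pseudo-IR; the exact
// xegpu.load_gather assembly may differ): the source is flattened, the
// per-element offsets are materialized as an index vector, and the read
// becomes roughly
//   %flat    = memref.collapse_shape %src [[0, 1]]
//              : memref<8x16xf32> into memref<128xf32>
//   %offsets = <base offset + local offsets> : vector<8x16xindex>
//   %mask    = vector.constant_mask [8, 16] : vector<8x16xi1>
//   %v       = xegpu.load_gather %flat, %offsets, %mask -> vector<8x16xf32>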

static LogicalResult lowerToScatteredStoreOp(vector::TransferWriteOp writeOp,
                                             PatternRewriter &rewriter) {

  Location loc = writeOp.getLoc();
  VectorType vectorType = writeOp.getVectorType();
  ArrayRef<int64_t> vectorShape = vectorType.getShape();

  auto memrefType = dyn_cast<MemRefType>(writeOp.getShapedType());
  if (!memrefType)
    return rewriter.notifyMatchFailure(writeOp, "Expected memref source");

  SmallVector<Value> strides = computeStrides(writeOp, rewriter);

  Value localOffsets = computeOffsets(writeOp, rewriter, strides);

  Value flatMemref = collapseMemrefTo1D(writeOp, rewriter);

  Value mask = vector::ConstantMaskOp::create(
      rewriter, loc, VectorType::get(vectorShape, rewriter.getI1Type()),
      vectorShape);
  xegpu::StoreScatterOp::create(rewriter, loc, writeOp.getVector(), flatMemref,
                                localOffsets, mask,
                                /*chunk_size=*/IntegerAttr{},
                                /*l1_hint=*/xegpu::CachePolicyAttr{},
                                /*l2_hint=*/xegpu::CachePolicyAttr{},
                                /*l3_hint=*/xegpu::CachePolicyAttr{});
  rewriter.eraseOp(writeOp);
  return success();
}

struct TransferReadLowering : public OpRewritePattern<vector::TransferReadOp> {
  using OpRewritePattern<vector::TransferReadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
                                PatternRewriter &rewriter) const override {
    Location loc = readOp.getLoc();

    if (failed(transferPreconditions(rewriter, readOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(readOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered load op if the target HW doesn't have 2D block
      // load support.
      // TODO: add support for out-of-bounds access.
      if (readOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredLoadOp(readOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = readOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, readOp, vecTy)))
      return failure();

    bool isOutOfBounds = readOp.hasOutOfBoundsDim();
    if (isOutOfBounds && !isZeroConstant(readOp.getPadding()))
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported non-zero padded out-of-bounds read");

    AffineMap readMap = readOp.getPermutationMap();
    bool isTransposeLoad = !readMap.isMinorIdentity();

    Type elementType = vecTy.getElementType();
    unsigned minTransposeBitWidth = 32;
    if (isTransposeLoad &&
        elementType.getIntOrFloatBitWidth() < minTransposeBitWidth)
      return rewriter.notifyMatchFailure(
          readOp, "Unsupported data type for transposition");

    // If the load is transposed, get the base shape for the tensor descriptor.
    SmallVector<int64_t> descShape(vecTy.getShape());
    if (isTransposeLoad)
      std::reverse(descShape.begin(), descShape.end());
    auto descType = xegpu::TensorDescType::get(
        descShape, elementType, /*array_length=*/1,
        /*boundary_check=*/isOutOfBounds, xegpu::MemorySpace::Global);

    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(readOp.getBase()),
                           readOp.getIndices());

    DenseI64ArrayAttr transposeAttr =
        !isTransposeLoad ? nullptr
                         : DenseI64ArrayAttr::get(rewriter.getContext(),
                                                  ArrayRef<int64_t>{1, 0});
    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadOp = xegpu::LoadNdOp::create(rewriter, loc, vecTy, ndDesc,
                                          /*packed=*/nullptr, transposeAttr,
                                          /*l1_hint=*/hint,
                                          /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(readOp, loadOp);

    return success();
  }
};
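
// Rough sketch of the block-load path for a transposed f32 read (pseudo-IR,
// abridged):
//   %v = vector.transfer_read %src[%i, %j], %cst
//        {permutation_map = affine_map<(d0, d1) -> (d1, d0)>}
//        : memref<32x64xf32>, vector<16x8xf32>
// becomes roughly
//   %desc = xegpu.create_nd_tdesc %src[%i, %j]
//           : memref<32x64xf32> -> !xegpu.tensor_desc<8x16xf32>
//   %v    = xegpu.load_nd %desc <{transpose = [1, 0]}>
//           : !xegpu.tensor_desc<8x16xf32> -> vector<16x8xf32>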

struct TransferWriteLowering
    : public OpRewritePattern<vector::TransferWriteOp> {
  using OpRewritePattern<vector::TransferWriteOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::TransferWriteOp writeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = writeOp.getLoc();

    if (failed(transferPreconditions(rewriter, writeOp)))
      return failure();

    // TODO: This check needs to be replaced with a proper uArch capability
    // check.
    auto chip = xegpu::getChipStr(writeOp);
    if (chip != "pvc" && chip != "bmg") {
      // Lower to a scattered store op if the target HW doesn't have 2D block
      // store support.
      // TODO: add support for out-of-bounds access.
      if (writeOp.hasOutOfBoundsDim())
        return failure();
      return lowerToScatteredStoreOp(writeOp, rewriter);
    }

    // Perform common data transfer checks.
    VectorType vecTy = writeOp.getVectorType();
    if (failed(storeLoadPreconditions(rewriter, writeOp, vecTy)))
      return failure();

    AffineMap map = writeOp.getPermutationMap();
    if (!map.isMinorIdentity())
      return rewriter.notifyMatchFailure(writeOp, "Expects identity map");

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, /*boundary_check=*/writeOp.hasOutOfBoundsDim(),
        xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc =
        createNdDescriptor(rewriter, loc, descType,
                           dyn_cast<TypedValue<MemRefType>>(writeOp.getBase()),
                           writeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeOp =
        xegpu::StoreNdOp::create(rewriter, loc, writeOp.getVector(), ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(writeOp, storeOp);

    return success();
  }
};

struct LoadLowering : public OpRewritePattern<vector::LoadOp> {
  using OpRewritePattern<vector::LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::LoadOp loadOp,
                                PatternRewriter &rewriter) const override {
    Location loc = loadOp.getLoc();

    VectorType vecTy = loadOp.getResult().getType();
    if (failed(storeLoadPreconditions(rewriter, loadOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(), /*array_length=*/1,
        boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, loadOp.getBase(), loadOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto loadNdOp = xegpu::LoadNdOp::create(
        rewriter, loc, vecTy, ndDesc, /*packed=*/nullptr, /*transpose=*/nullptr,
        /*l1_hint=*/hint,
        /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(loadOp, loadNdOp);

    return success();
  }
};

struct StoreLowering : public OpRewritePattern<vector::StoreOp> {
  using OpRewritePattern<vector::StoreOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::StoreOp storeOp,
                                PatternRewriter &rewriter) const override {
    Location loc = storeOp.getLoc();

    TypedValue<VectorType> vector = storeOp.getValueToStore();
    VectorType vecTy = vector.getType();
    if (failed(storeLoadPreconditions(rewriter, storeOp, vecTy)))
      return failure();

    // Boundary check is available only for block instructions.
    bool boundaryCheck = vecTy.getRank() > 1;

    auto descType = xegpu::TensorDescType::get(
        vecTy.getShape(), vecTy.getElementType(),
        /*array_length=*/1, boundaryCheck, xegpu::MemorySpace::Global);
    xegpu::CreateNdDescOp ndDesc = createNdDescriptor(
        rewriter, loc, descType, storeOp.getBase(), storeOp.getIndices());

    // By default, no specific caching policy is assigned.
    xegpu::CachePolicyAttr hint = nullptr;
    auto storeNdOp =
        xegpu::StoreNdOp::create(rewriter, loc, vector, ndDesc,
                                 /*l1_hint=*/hint,
                                 /*l2_hint=*/hint, /*l3_hint=*/hint);
    rewriter.replaceOp(storeOp, storeNdOp);

    return success();
  }
};

struct ContractionLowering : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(vector::ContractionOp contractOp,
                                PatternRewriter &rewriter) const override {
    Location loc = contractOp.getLoc();

    if (contractOp.getKind() != vector::CombiningKind::ADD)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects add combining kind");

    TypedValue<Type> acc = contractOp.getAcc();
    VectorType accType = dyn_cast<VectorType>(acc.getType());
    if (!accType || accType.getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp, "Expects acc 2D vector");

    // Accept only plain 2D data layout.
    // VNNI packing is applied to DPAS as a separate lowering step.
    TypedValue<VectorType> lhs = contractOp.getLhs();
    TypedValue<VectorType> rhs = contractOp.getRhs();
    if (lhs.getType().getRank() != 2 || rhs.getType().getRank() != 2)
      return rewriter.notifyMatchFailure(contractOp,
                                         "Expects lhs and rhs 2D vectors");

    if (!isRowMajorMatmul(contractOp.getIndexingMapsAttr()))
      return rewriter.notifyMatchFailure(contractOp, "Invalid indexing maps");

    auto dpasOp = xegpu::DpasOp::create(rewriter, loc,
                                        TypeRange{contractOp.getResultType()},
                                        ValueRange{lhs, rhs, acc});
    rewriter.replaceOp(contractOp, dpasOp);

    return success();
  }
};
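
// A matching contraction is the canonical row-major matmul (d0 = m, d1 = n,
// d2 = k; the element types below are illustrative):
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %acc : vector<8x16xf16>, vector<16x16xf16> into vector<8x16xf32>
// which maps directly onto xegpu.dpas %a, %b, %acc.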

struct ConvertVectorToXeGPUPass
    : public impl::ConvertVectorToXeGPUBase<ConvertVectorToXeGPUPass> {
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateVectorToXeGPUConversionPatterns(patterns);
    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

void mlir::populateVectorToXeGPUConversionPatterns(
    RewritePatternSet &patterns) {
  patterns.add<TransferReadLowering, TransferWriteLowering, LoadLowering,
               StoreLowering, ContractionLowering>(patterns.getContext());
}
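
// Usage sketch (not part of this file): the patterns can also be applied
// outside the pass; everything below other than
// populateVectorToXeGPUConversionPatterns and applyPatternsGreedily is an
// illustrative assumption.
//
//   void lowerVectorToXeGPU(Operation *root) {
//     RewritePatternSet patterns(root->getContext());
//     populateVectorToXeGPUConversionPatterns(patterns);
//     if (failed(applyPatternsGreedily(root, std::move(patterns))))
//       llvm::report_fatal_error("vector-to-xegpu conversion failed");
//   }
//
// When the pass is registered, the same lowering is typically reachable from
// mlir-opt as `--convert-vector-to-xegpu`.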