MLIR  19.0.0git
VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include <type_traits>
16 
17 #include "mlir/Analysis/SliceAnalysis.h"
18 #include "mlir/Dialect/Affine/IR/AffineOps.h"
19 #include "mlir/Dialect/Arith/IR/Arith.h"
20 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
21 #include "mlir/Dialect/MemRef/IR/MemRef.h"
22 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
23 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
24 #include "mlir/Dialect/SCF/IR/SCF.h"
25 #include "mlir/Dialect/Utils/StructuredOpsUtils.h"
26 #include "mlir/Dialect/Vector/IR/VectorOps.h"
27 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/Dialect/Vector/Utils/VectorUtils.h"
29 #include "mlir/IR/Builders.h"
30 #include "mlir/IR/BuiltinOps.h"
31 #include "mlir/IR/Region.h"
32 #include "mlir/Pass/Pass.h"
34 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
35 #include "mlir/Transforms/Passes.h"
36 #include "llvm/ADT/STLExtras.h"
37 #include "llvm/ADT/TypeSwitch.h"
38 
39 #define DEBUG_TYPE "vector-to-gpu"
40 #define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
41 #define DBGSNL() (llvm::dbgs() << "\n")
42 
43 namespace mlir {
44 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
45 #include "mlir/Conversion/Passes.h.inc"
46 } // namespace mlir
47 
48 using namespace mlir;
49 
50 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
51 /// AffineMap representing offsets to apply to indices, the function fills
52 /// `indices` with the original indices plus the offsets. The offsets are
53 /// applied by taking into account the permutation map of the transfer op. If
54 /// the `offsetMap` has dimension placeholders, those should be provided in
55 /// `dimValues`.
56 template <typename TransferOpType>
57 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
58  AffineMap offsetMap, ArrayRef<Value> dimValues,
59  SmallVector<Value, 4> &indices) {
60  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
61  Location loc = xferOp.getLoc();
62  unsigned offsetsIdx = 0;
63  for (auto expr : xferOp.getPermutationMap().getResults()) {
64  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
65  Value prevIdx = indices[dim.getPosition()];
66  SmallVector<OpFoldResult, 3> dims(dimValues.begin(), dimValues.end());
67  dims.push_back(prevIdx);
68  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
69  indices[dim.getPosition()] = affine::makeComposedAffineApply(
70  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
71  continue;
72  }
73  }
74 }
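// Illustrative example (not part of the original source): for a transfer op
// with indices [%i, %j], a minor identity permutation map
// (d0, d1) -> (d0, d1), and an `offsetMap` (lane) -> (lane floordiv 4,
// lane mod 4) with dimValues = {%laneId}, the rewritten indices become
// [%i + %laneId floordiv 4, %j + %laneId mod 4], each built as a composed
// affine.apply.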
75 
76 // Return true if the contract op can be converted to MMA matmul.
77 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
78  bool useNvGpu) {
79  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
80  auto infer = [&](MapList m) {
81  return AffineMap::inferFromExprList(m, contract.getContext());
82  };
83  AffineExpr m, n, k;
84  bindDims(contract.getContext(), m, n, k);
85  auto iteratorTypes = contract.getIteratorTypes().getValue();
86  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
87  vector::isParallelIterator(iteratorTypes[1]) &&
88  vector::isReductionIterator(iteratorTypes[2])))
89  return false;
90 
91  // The contract needs to represent a matmul to be able to convert to
92  // MMAMatrix matmul.
93  if (!useNvGpu &&
94  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
95  return false;
96  if (useNvGpu &&
97  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
98  return false;
99 
100  return true;
101 }
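// Illustrative example (assumed 16x16x16 shapes): the wmma path
// (useNvGpu == false) accepts the canonical row-major matmul
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//         %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
// whereas the mma.sync path expects the B operand indexed as (d1, d2), i.e.
// already in transposed (col-major) form.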
102 
103 // Return true if the given map represents a transposed matrix load,
104 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
105 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
106  MLIRContext *ctx = permutationMap.getContext();
107  // Local OpBuilder is fine here, we just build attributes.
108  OpBuilder b(ctx);
109  auto nDim = permutationMap.getNumDims();
110  AffineExpr zero = b.getAffineConstantExpr(0);
111  if (nDim < 2) {
112  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
113  AffineExpr dim0 = b.getAffineDimExpr(0);
114  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
115  }
116 
117  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
118  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
119  // Support both transposed and transposed+broadcasted cases.
120  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
121  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
122 }
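// Illustrative examples: affine_map<(d0, d1) -> (d1, d0)> is a transposed
// load map, affine_map<(d0, d1) -> (d1, 0)> is a transposed + broadcasted
// one, and the minor identity (d0, d1) -> (d0, d1) is neither.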
123 
124 // Return the stride for the second-to-last dimension of |type| if it is a memref
125 // and has a constant stride.
126 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
127  auto memrefType = dyn_cast<MemRefType>(type);
128  if (!memrefType)
129  return std::nullopt;
130  // If the memref is 0 or 1D the horizontal stride is 0.
131  if (memrefType.getRank() < 2)
132  return 0;
133  int64_t offset = 0;
134  SmallVector<int64_t, 2> strides;
135  if (failed(getStridesAndOffset(memrefType, strides, offset)) ||
136  strides.back() != 1)
137  return std::nullopt;
138  int64_t stride = strides[strides.size() - 2];
139  if (stride == ShapedType::kDynamic)
140  return std::nullopt;
141  return stride;
142 }
143 
144 // Return true if the transfer op can be converted to a MMA matrix load.
145 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
146  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
147  readOp.getVectorType().getRank() != 2)
148  return false;
149  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
150  return false;
151 
152  // Only allow integer types if the signedness can be inferred.
153  if (readOp.getVectorType().getElementType().isInteger(8))
154  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
155  !isa<arith::ExtUIOp>(*readOp->user_begin())))
156  return false;
157 
158  AffineMap map = readOp.getPermutationMap();
159  MLIRContext *ctx = readOp.getContext();
160  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
161  AffineExpr zero = getAffineConstantExpr(0, ctx);
162  auto broadcastInnerDim =
163  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
164  return map.isMinorIdentity() || map == broadcastInnerDim ||
165  isTransposeMatrixLoadMap(map);
166 }
167 
168 // Return true if the transfer op can be converted to a MMA matrix store.
169 static bool
170 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
171  // TODO: support 0-d corner case.
172  if (writeOp.getTransferRank() == 0)
173  return false;
174 
175  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
176  writeOp.getVectorType().getRank() != 2)
177  return false;
178  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
179  return false;
180  // TODO: Support transpose once it is added to GPU dialect ops.
181  if (!writeOp.getPermutationMap().isMinorIdentity())
182  return false;
183  return true;
184 }
185 
186 /// Return true if the constant is a splat to a 2D vector so that it can be
187 /// converted to a MMA constant matrix op.
188 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
189  auto vecType = dyn_cast<VectorType>(constantOp.getType());
190  if (!vecType || vecType.getRank() != 2)
191  return false;
192  return isa<SplatElementsAttr>(constantOp.getValue());
193 }
194 
195 /// Return true if this is a broadcast from scalar to a 2D vector.
196 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
197  return broadcastOp.getResultVectorType().getRank() == 2;
198 }
199 
200 /// Return true if this integer extend op can be folded into a contract op.
201 template <typename ExtOpTy>
202 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
203  if (!isa<vector::TransferReadOp>(extOp.getOperand().getDefiningOp()))
204  return false;
205  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
206 }
207 
208 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
209 
210 /// Return the MMA elementwise enum associated with `op` if it is supported.
211 /// Return `std::nullopt` otherwise.
212 static std::optional<gpu::MMAElementwiseOp>
213 convertElementwiseOpToMMA(Operation *op) {
214  if (isa<arith::AddFOp>(op))
215  return gpu::MMAElementwiseOp::ADDF;
216  if (isa<arith::MulFOp>(op))
217  return gpu::MMAElementwiseOp::MULF;
218  if (isa<arith::SubFOp>(op))
219  return gpu::MMAElementwiseOp::SUBF;
220  if (isa<arith::MaximumFOp>(op))
221  return gpu::MMAElementwiseOp::MAXF;
222  if (isa<arith::MinimumFOp>(op))
223  return gpu::MMAElementwiseOp::MINF;
224  if (isa<arith::DivFOp>(op))
225  return gpu::MMAElementwiseOp::DIVF;
226  if (isa<arith::AddIOp>(op))
227  return gpu::MMAElementwiseOp::ADDI;
228  if (isa<arith::MulIOp>(op))
229  return gpu::MMAElementwiseOp::MULI;
230  if (isa<arith::SubIOp>(op))
231  return gpu::MMAElementwiseOp::SUBI;
232  if (isa<arith::DivSIOp>(op))
233  return gpu::MMAElementwiseOp::DIVS;
234  if (isa<arith::DivUIOp>(op))
235  return gpu::MMAElementwiseOp::DIVU;
236  if (isa<arith::NegFOp>(op))
237  return gpu::MMAElementwiseOp::NEGATEF;
238  if (isa<arith::ExtFOp>(op))
239  return gpu::MMAElementwiseOp::EXTF;
240  return std::nullopt;
241 }
242 
243 /// Return true if the op is supported as elementwise op on MMAMatrix type.
244 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
245  return convertElementwiseOpToMMA(op).has_value();
246 }
247 
248 /// Returns true if the extract strided slice op is supported with `mma.sync`
249 /// path.
250 static bool
251 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
252 
253  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
254  nvgpu::getWarpMatrixInfo(op);
255  if (failed(warpMatrixInfo))
256  return false;
257 
258  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
259  if (failed(contractOp))
260  return false;
261 
262  // Handle vector.extract_strided_slice on registers containing
263  // matrixB and matrixC operands. vector.extract_strided_slice op
264  // is not supported on registers containing matrixA operands.
265  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
266  return (cast<VectorType>(op->getResult(0).getType()) ==
267  cast<VectorType>((*contractOp).getRhs().getType()));
268  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
269  return (cast<VectorType>(op->getResult(0).getType()) ==
270  cast<VectorType>((*contractOp).getAcc().getType()));
271 
272  return false;
273 }
274 
275 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
276  if (isa<scf::ForOp, scf::YieldOp>(op))
277  return true;
278  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
279  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
280  : transferReadSupportsMMAMatrixType(transferRead);
281  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
282  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
283  : transferWriteSupportsMMAMatrixType(transferWrite);
284  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
285  return useNvGpu &&
286  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
287  if (auto contract = dyn_cast<vector::ContractionOp>(op))
288  return contractSupportsMMAMatrixType(contract, useNvGpu);
289  if (auto constant = dyn_cast<arith::ConstantOp>(op))
290  return constantSupportsMMAMatrixType(constant);
291  if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
292  return broadcastSupportsMMAMatrixType(broadcast);
293  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
294  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
295  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
296  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
297  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
298  return fpExtendSupportsMMAMatrixType(fpExtend);
299  return elementwiseSupportsMMAMatrixType(op);
300 }
301 
302 /// Return an unsorted slice handling scf.for region differently than
303 /// `getSlice`. In scf.for we only want to include as part of the slice elements
304 /// that are part of the use/def chain.
305 static SetVector<Operation *>
306 getSliceContract(Operation *op,
307  const BackwardSliceOptions &backwardSliceOptions,
308  const ForwardSliceOptions &forwardSliceOptions) {
309  SetVector<Operation *> slice;
310  slice.insert(op);
311  unsigned currentIndex = 0;
312  SetVector<Operation *> backwardSlice;
313  SetVector<Operation *> forwardSlice;
314  while (currentIndex != slice.size()) {
315  auto *currentOp = (slice)[currentIndex];
316  // Compute and insert the backwardSlice starting from currentOp.
317  backwardSlice.clear();
318  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
319  slice.insert(backwardSlice.begin(), backwardSlice.end());
320 
321  // Compute and insert the forwardSlice starting from currentOp.
322  forwardSlice.clear();
323  // Special case for ForOp, we don't want to include the whole region but
324  // only the values using the region arguments.
325  // TODO: We should refine this to only care about the region arguments being
326  // converted to matrix type.
327  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328  for (Value forOpResult : forOp.getResults())
329  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330  for (BlockArgument &arg : forOp.getRegionIterArgs())
331  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332  } else {
333  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334  }
335  slice.insert(forwardSlice.begin(), forwardSlice.end());
336  ++currentIndex;
337  }
338  return slice;
339 }
340 
341 // Analyze slice of operations based on convert op to figure out if the whole
342 // slice can be converted to MMA operations.
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344  bool useNvGpu) {
345  auto hasVectorDest = [](Operation *op) {
346  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347  };
348  BackwardSliceOptions backwardSliceOptions;
349  backwardSliceOptions.filter = hasVectorDest;
350 
351  auto hasVectorSrc = [](Operation *op) {
352  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353  };
354  ForwardSliceOptions forwardSliceOptions;
355  forwardSliceOptions.filter = hasVectorSrc;
356 
357  SetVector<Operation *> opToConvert;
358  op->walk([&](vector::ContractionOp contract) {
359  if (opToConvert.contains(contract.getOperation()))
360  return;
361  SetVector<Operation *> dependentOps =
362  getSliceContract(contract, backwardSliceOptions, forwardSliceOptions);
363  // If any instruction cannot use the MMA matrix type, drop the whole
364  // chain. MMA matrices are stored in an opaque type, so they cannot be
365  // used by all operations.
366  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
367  if (!supportsMMaMatrixType(op, useNvGpu)) {
368  LLVM_DEBUG(DBGS() << "cannot convert op: " << *op << "\n");
369  return true;
370  }
371  return false;
372  }))
373  return;
374 
375  opToConvert.insert(dependentOps.begin(), dependentOps.end());
376  });
377  // Sort the operations so that we can convert them in topological order.
378  return topologicalSort(opToConvert);
379 }
380 
381 namespace {
382 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
383 // to MMA matmul.
384 struct PrepareContractToGPUMMA
385  : public OpRewritePattern<vector::ContractionOp> {
386  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
387 
388  LogicalResult matchAndRewrite(vector::ContractionOp op,
389  PatternRewriter &rewriter) const override {
390  Location loc = op.getLoc();
391  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
392 
393  // Set up the parallel/reduction structure in right form.
394  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
395  auto infer = [&](MapList m) {
396  return AffineMap::inferFromExprList(m, op.getContext());
397  };
398  AffineExpr m, n, k;
399  bindDims(rewriter.getContext(), m, n, k);
400  static constexpr std::array<int64_t, 2> perm = {1, 0};
401  auto iteratorTypes = op.getIteratorTypes().getValue();
402  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
403  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
404  vector::isParallelIterator(iteratorTypes[1]) &&
405  vector::isReductionIterator(iteratorTypes[2])))
406  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
407  //
408  // Two outer parallel, one inner reduction (matmat flavor).
409  //
410  // This is the classical row-major matmul, nothing to do.
411  if (maps == infer({{m, k}, {k, n}, {m, n}}))
412  return rewriter.notifyMatchFailure(op, "contraction already prepared");
413  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
414  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
415  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
416  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
417  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
418  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
419  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
420  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
421  std::swap(rhs, lhs);
422  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
423  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
424  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
425  std::swap(rhs, lhs);
426  rhs = rewriter.create<vector::TransposeOp>(loc, rhs, perm);
427  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
428  std::swap(lhs, rhs);
429  lhs = rewriter.create<vector::TransposeOp>(loc, lhs, perm);
430  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
431  std::swap(lhs, rhs);
432  } else {
433  // TODO: llvm_unreachable ?
434  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
435  }
436  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
437  op, lhs, rhs, res,
438  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
439  op.getIteratorTypes());
440  return success();
441  }
442 };
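// Illustrative effect of the pattern above: a contraction whose maps are
// {(m, k), (n, k), (m, n)} gets its RHS wrapped in a vector.transpose, so the
// rewritten contraction carries the canonical {(m, k), (k, n), (m, n)} maps
// that the MMA conversion below expects.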
443 
444 // Fold transpose op into the transfer read op. Nvgpu mma.sync op only supports
445 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
446 // respectively. We can fold the transpose operation when loading the data from
447 // Shared Memory to registers.
448 struct CombineTransferReadOpTranspose final
449  : public OpRewritePattern<vector::TransposeOp> {
450  using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
451 
452  LogicalResult matchAndRewrite(vector::TransposeOp op,
453  PatternRewriter &rewriter) const override {
454  // Look through integer extend ops.
455  Value source = op.getVector();
456  Type resultType = op.getType();
457  Operation *extOp;
458  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
459  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
460  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
461  source = extOp->getOperand(0);
462  resultType =
463  VectorType::get(cast<VectorType>(resultType).getShape(),
464  cast<VectorType>(source.getType()).getElementType());
465  }
466 
467  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
468  if (!transferReadOp)
469  return rewriter.notifyMatchFailure(op, "no transfer read");
470 
471  // TODO: support 0-d corner case.
472  if (transferReadOp.getTransferRank() == 0)
473  return rewriter.notifyMatchFailure(op, "0-D transfer read");
474 
475  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
476  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
477 
478  AffineMap permutationMap =
479  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
480  AffineMap newMap =
481  permutationMap.compose(transferReadOp.getPermutationMap());
482 
483  auto loc = op.getLoc();
484  Value result =
485  rewriter
486  .create<vector::TransferReadOp>(
487  loc, resultType, transferReadOp.getSource(),
488  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
489  transferReadOp.getPadding(), transferReadOp.getMask(),
490  transferReadOp.getInBoundsAttr())
491  .getResult();
492 
493  // Fuse through the integer extend op.
494  if (extOp) {
495  if (isa<arith::ExtSIOp>(extOp))
496  result = rewriter.create<arith::ExtSIOp>(loc, op.getType(), result)
497  .getResult();
498  else if (isa<arith::ExtUIOp>(extOp))
499  result = rewriter.create<arith::ExtUIOp>(loc, op.getType(), result)
500  .getResult();
501  else
502  result = rewriter.create<arith::ExtFOp>(loc, op.getType(), result)
503  .getResult();
504  }
505 
506  rewriter.replaceOp(op, result);
507  return success();
508  }
509 };
510 
511 } // namespace
512 
513 // MMA types have different layout based on how they are used in matmul ops.
514 // Figure the right layout to use by looking at op uses.
515 // TODO: Change the GPU dialect to abstract the layout at this level and
516 // only care about it during lowering to NVVM.
517 static const char *inferFragType(Operation *op) {
518  for (Operation *users : op->getUsers()) {
519  auto contract = dyn_cast<vector::ContractionOp>(users);
520  if (!contract)
521  continue;
522  assert(op->getNumResults() == 1);
523  if (contract.getLhs() == op->getResult(0))
524  return "AOp";
525  if (contract.getRhs() == op->getResult(0))
526  return "BOp";
527  }
528  return "COp";
529 }
530 
531 static LogicalResult
532 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
533  llvm::DenseMap<Value, Value> &valueMapping) {
534  OpBuilder::InsertionGuard g(rewriter);
535  rewriter.setInsertionPoint(op);
536 
537  assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
538  assert(transferReadSupportsMMAMatrixType(op) &&
539  "expected convertible operation");
540 
541  std::optional<int64_t> stride =
542  getStaticallyKnownRowStride(op.getShapedType());
543  if (!stride.has_value()) {
544  LLVM_DEBUG(DBGS() << "no stride\n");
545  return rewriter.notifyMatchFailure(op, "no stride");
546  }
547 
548  AffineMap map = op.getPermutationMap();
549  bool isTranspose = isTransposeMatrixLoadMap(map);
550 
551  // Handle broadcast by setting the stride to 0.
552  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
553  assert(cstExpr.getValue() == 0);
554  stride = 0;
555  }
556 
557  Value mappingResult = op.getResult();
558  auto elType = op.getVectorType().getElementType();
559  const char *fragType = inferFragType(op);
560  if (op->hasOneUse()) {
561  auto *user = *op->user_begin();
562  // Infer the signedness of the mma type from the integer extend.
563  bool isSignedExtend = isa<arith::ExtSIOp>(user);
564  if (isSignedExtend || isa<arith::ExtUIOp>(user)) {
565  elType = IntegerType::get(
566  op.getContext(), cast<IntegerType>(elType).getWidth(),
567  isSignedExtend ? IntegerType::Signed : IntegerType::Unsigned);
568  mappingResult = user->getResult(0);
569  fragType = inferFragType(user);
570  }
571  }
572  gpu::MMAMatrixType type =
573  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
574  Value load = rewriter.create<gpu::SubgroupMmaLoadMatrixOp>(
575  op.getLoc(), type, op.getSource(), op.getIndices(),
576  rewriter.getIndexAttr(*stride),
577  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
578  valueMapping[mappingResult] = load;
579 
580  LLVM_DEBUG(DBGS() << "transfer read to: " << load << "\n");
581  return success();
582 }
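// Illustrative result (assumed types): a row-major 16x16xf16 transfer_read
// feeding the LHS of a contraction is rewritten to roughly
//   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//        : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
// where leadDimension is the statically known row stride computed above.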
583 
584 static LogicalResult
585 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
586  llvm::DenseMap<Value, Value> &valueMapping) {
587  OpBuilder::InsertionGuard g(rewriter);
588  rewriter.setInsertionPoint(op);
589 
590  assert(transferWriteSupportsMMAMatrixType(op));
591  std::optional<int64_t> stride =
592  getStaticallyKnownRowStride(op.getShapedType());
593  if (!stride.has_value()) {
594  LLVM_DEBUG(DBGS() << "no stride\n");
595  return rewriter.notifyMatchFailure(op, "no stride");
596  }
597 
598  auto it = valueMapping.find(op.getVector());
599  if (it == valueMapping.end()) {
600  LLVM_DEBUG(DBGS() << "no mapping\n");
601  return rewriter.notifyMatchFailure(op, "no mapping");
602  }
603 
604  Value matrix = it->second;
605  auto store = rewriter.create<gpu::SubgroupMmaStoreMatrixOp>(
606  op.getLoc(), matrix, op.getSource(), op.getIndices(),
607  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
608  (void)store;
609 
610  LLVM_DEBUG(DBGS() << "transfer write to: " << store << "\n");
611 
612  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
613  rewriter.eraseOp(op);
614  return success();
615 }
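// Illustrative result (assumed types): the matching store side becomes roughly
//   gpu.subgroup_mma_store_matrix %m, %dst[%i, %j] {leadDimension = 32 : index}
//        : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>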
616 
617 /// Returns the vector type which represents a matrix fragment.
618 static VectorType
619 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
620  SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
621  regInfo.elementsPerRegister};
622  Type elType = regInfo.registerLLVMType;
623  if (auto vecType = dyn_cast<VectorType>(elType))
624  elType = vecType.getElementType();
625  return VectorType::get(shape, elType);
626 }
627 
628 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
629 static LogicalResult
630 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
631  llvm::DenseMap<Value, Value> &valueMapping) {
632  OpBuilder::InsertionGuard g(rewriter);
633  rewriter.setInsertionPoint(op);
634 
635  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
636  nvgpu::getWarpMatrixInfo(op);
637  if (failed(warpMatrixInfo)) {
638  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
639  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
640  }
641 
642  FailureOr<nvgpu::FragmentElementInfo> regInfo =
643  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
644  if (failed(regInfo)) {
645  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
646  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
647  }
648 
649  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
650  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
651  if (!dense) {
652  LLVM_DEBUG(DBGS() << "not a splat\n");
653  return rewriter.notifyMatchFailure(op, "not a splat");
654  }
655 
656  Value result = rewriter.create<arith::ConstantOp>(
657  op.getLoc(), vectorType,
658  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
659  valueMapping[op.getResult()] = result;
660  return success();
661 }
662 
663 /// Check if the loaded matrix operand requires a transposed load.
664 /// Transposed Map Example:
665 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
666 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
667 /// The code below checks if the output 2D is transposed using a generalized
668 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
669 /// Returns true if m > n, false otherwise.
670 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
671  mlir::AffineMap map = op.getPermutationMap();
672 
673  if (map.getNumResults() != 2) {
674  LLVM_DEBUG(DBGS() << "Failed because the result of `vector.transfer_read` "
675  "is not a 2d operand\n");
676  return failure();
677  }
678 
679  // Output 2D matrix dimensions in the order of d0, d1.
680  mlir::AffineExpr dM = map.getResult(0);
681  mlir::AffineExpr dN = map.getResult(1);
682 
683  // Find the position of these expressions in the input.
684  auto exprM = dyn_cast<AffineDimExpr>(dM);
685  auto exprN = dyn_cast<AffineDimExpr>(dN);
686 
687  if (!exprM || !exprN) {
688  LLVM_DEBUG(DBGS() << "Failed because expressions are not affine dim "
689  "expressions, then transpose cannot be determined.\n");
690  return failure();
691  }
692 
693  return exprM.getPosition() > exprN.getPosition();
694 }
695 
696 static LogicalResult
697 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
698  llvm::DenseMap<Value, Value> &valueMapping) {
699  OpBuilder::InsertionGuard g(rewriter);
700  rewriter.setInsertionPoint(op);
701  Location loc = op->getLoc();
702 
703  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
704  nvgpu::getWarpMatrixInfo(op);
705  if (failed(warpMatrixInfo)) {
706  LLVM_DEBUG(DBGS() << "no warpMatrixInfo\n");
707  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
708  }
709 
710  FailureOr<nvgpu::FragmentElementInfo> regInfo =
711  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
712  if (failed(regInfo)) {
713  LLVM_DEBUG(DBGS() << "not mma sync reg info\n");
714  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
715  }
716 
717  FailureOr<bool> transpose = isTransposed(op);
718  if (failed(transpose)) {
719  LLVM_DEBUG(DBGS() << "failed to determine the transpose\n");
720  return rewriter.notifyMatchFailure(
721  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
722  }
723 
724  FailureOr<nvgpu::LdMatrixParams> params =
725  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
726 
727  if (failed(params)) {
728  LLVM_DEBUG(
729  DBGS()
730  << "failed to convert vector.transfer_read to ldmatrix. "
731  << "Op should likely not be converted to a nvgpu.ldmatrix call.\n");
732  return rewriter.notifyMatchFailure(
733  op, "failed to convert vector.transfer_read to ldmatrix; this op "
734  "likely should not be converted to a nvgpu.ldmatrix call.");
735  }
736 
737  // Adjust the load offset.
738  auto laneId = rewriter.create<gpu::LaneIdOp>(loc);
739  FailureOr<AffineMap> offsets =
740  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
741  if (failed(offsets)) {
742  LLVM_DEBUG(DBGS() << "no offsets\n");
743  return rewriter.notifyMatchFailure(op, "no offsets");
744  }
745 
746  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
747 
748  SmallVector<Value, 4> indices;
749  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
750  indices);
751 
752  nvgpu::LdMatrixOp newOp = rewriter.create<nvgpu::LdMatrixOp>(
753  loc, vectorType, op.getSource(), indices, *transpose, params->numTiles);
754  valueMapping[op] = newOp->getResult(0);
755  return success();
756 }
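// Illustrative result (assumed 16x16xf16 operand in workgroup memory): the
// read is rewritten to roughly
//   %frag = nvgpu.ldmatrix %src[%row, %col] {numTiles = 4 : i32, transpose = false}
//           : memref<32x32xf16, 3> -> vector<4x2xf16>
// with per-lane %row/%col derived from gpu.lane_id through the offset map.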
757 
758 static LogicalResult
759 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
760  llvm::DenseMap<Value, Value> &valueMapping) {
761  OpBuilder::InsertionGuard g(rewriter);
762  rewriter.setInsertionPoint(op);
763 
764  Location loc = op.getLoc();
765  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
766  nvgpu::getWarpMatrixInfo(op);
767  if (failed(warpMatrixInfo))
768  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
769  FailureOr<nvgpu::FragmentElementInfo> regInfo =
770  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
771  if (failed(regInfo)) {
772  return rewriter.notifyMatchFailure(
773  op, "Failed to deduce register fragment type during "
774  "conversion to distributed non-ldmatrix compatible load");
775  }
776 
777  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
778  SmallVector<Value, 4> elements;
779 
780  // This is the individual element type.
781  Type loadedElType = regInfo->registerLLVMType;
782  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
783 
784  Value fill = rewriter.create<arith::ConstantOp>(
785  op.getLoc(), vectorType.getElementType(),
786  rewriter.getZeroAttr(vectorType.getElementType()));
787  Value result =
788  rewriter.create<vector::SplatOp>(op.getLoc(), fill, vectorType);
789 
790  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
791 
792  // If we are not transposing, then we can use vectorized loads. Otherwise, we
793  // must load each element individually.
794  if (!isTransposeLoad) {
795  if (!isa<VectorType>(loadedElType)) {
796  loadedElType = VectorType::get({1}, loadedElType);
797  }
798 
799  for (int i = 0; i < vectorType.getShape()[0]; i++) {
800  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
801  rewriter, op.getLoc(), *warpMatrixInfo);
802  if (failed(coords))
803  return rewriter.notifyMatchFailure(op, "no coords");
804 
805  Value logicalValueId = rewriter.create<arith::ConstantOp>(
806  loc, rewriter.getIndexType(),
807  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
808  SmallVector<Value, 4> newIndices;
809  getXferIndices<vector::TransferReadOp>(
810  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
811 
812  Value el = rewriter.create<vector::LoadOp>(loc, loadedElType,
813  op.getSource(), newIndices);
814  result = rewriter.create<vector::InsertOp>(loc, el, result, i);
815  }
816  } else {
817  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
818  loadedElType = vecType.getElementType();
819  }
820  for (int i = 0; i < vectorType.getShape()[0]; i++) {
821  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
822  innerIdx++) {
823 
824  Value logicalValueId = rewriter.create<arith::ConstantOp>(
825  loc, rewriter.getIndexType(),
826  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
827  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
828  rewriter, op.getLoc(), *warpMatrixInfo);
829  if (failed(coords))
830  return rewriter.notifyMatchFailure(op, "no coords");
831 
832  SmallVector<Value, 4> newIndices;
833  getXferIndices<vector::TransferReadOp>(
834  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
835  Value el = rewriter.create<memref::LoadOp>(op.getLoc(), loadedElType,
836  op.getSource(), newIndices);
837  result = rewriter.create<vector::InsertOp>(
838  op.getLoc(), el, result, ArrayRef<int64_t>{i, innerIdx});
839  }
840  }
841  }
842 
843  valueMapping[op.getResult()] = result;
844  return success();
845 }
846 
847 /// Return true if this is a shared memory memref type.
848 static bool isSharedMemory(MemRefType type) {
849  auto addressSpace =
850  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
851  return addressSpace &&
852  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
853 }
854 
855 /// Converts a `vector.transfer_read` operation directly to either a
856 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
857 /// used when converting to `nvgpu.mma.sync` operations.
858 static LogicalResult
859 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
860  llvm::DenseMap<Value, Value> &valueMapping) {
861  OpBuilder::InsertionGuard g(rewriter);
862  rewriter.setInsertionPoint(op);
863 
864  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
865  nvgpu::getWarpMatrixInfo(op);
866  if (failed(warpMatrixInfo))
867  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
868 
869  bool isLdMatrixCompatible =
870  isSharedMemory(cast<MemRefType>(op.getSource().getType())) &&
871  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
872 
873  VectorType vecTy = op.getVectorType();
874  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
875 
876  // When we are transposing the B operand, ldmatrix will only work if we have
877  // at least 8 rows to read and the width to read for the transpose is 128
878  // bits.
879  if (!op.getPermutationMap().isMinorIdentity() &&
880  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
881  vecTy.getDimSize(0) * bitWidth < 128))
882  isLdMatrixCompatible = false;
883 
884  if (!isLdMatrixCompatible)
885  return createNonLdMatrixLoads(rewriter, op, valueMapping);
886 
887  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
888 }
889 
890 static LogicalResult
891 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
892  llvm::DenseMap<Value, Value> &valueMapping) {
893  OpBuilder::InsertionGuard g(rewriter);
894  rewriter.setInsertionPoint(op);
895 
896  Location loc = op->getLoc();
897  auto it = valueMapping.find(op.getVector());
898  if (it == valueMapping.end())
899  return rewriter.notifyMatchFailure(op, "no mapping");
900  Value matrix = it->second;
901 
902  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
903  nvgpu::getWarpMatrixInfo(op);
904  if (failed(warpMatrixInfo))
905  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
906  FailureOr<nvgpu::FragmentElementInfo> regInfo =
907  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
908  if (failed(regInfo))
909  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
910 
911  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
912  Value laneId = rewriter.create<gpu::LaneIdOp>(loc);
913 
914  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
915  Value logicalValueId = rewriter.create<arith::ConstantOp>(
916  loc, rewriter.getIndexType(),
917  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
918  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
919  rewriter, op.getLoc(), *warpMatrixInfo);
920  if (failed(coords))
921  return rewriter.notifyMatchFailure(op, "no coords");
922 
923  Value el =
924  rewriter.create<vector::ExtractOp>(loc, matrix, ArrayRef<int64_t>{i});
925  SmallVector<Value, 4> newIndices;
926  getXferIndices<vector::TransferWriteOp>(
927  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
928  rewriter.create<vector::StoreOp>(loc, el, op.getSource(), newIndices);
929  }
930 
931  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
932  rewriter.eraseOp(op);
933  return success();
934 }
935 
936 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
937  SmallVectorImpl<int64_t> &results) {
938  for (auto attr : arrayAttr)
939  results.push_back(cast<IntegerAttr>(attr).getInt());
940 }
941 
942 static LogicalResult
943 convertExtractStridedSlice(RewriterBase &rewriter,
944  vector::ExtractStridedSliceOp op,
945  llvm::DenseMap<Value, Value> &valueMapping) {
946  OpBuilder::InsertionGuard g(rewriter);
947  rewriter.setInsertionPoint(op);
948 
949  Location loc = op->getLoc();
950 
951  FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
952  nvgpu::getWarpMatrixInfo(op);
953  if (failed(warpMatrixInfo))
954  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
955 
956  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
957  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
958  if (failed(mmaSyncFragmentInfo))
959  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
960 
961  // Find the vector.transfer_read whose result vector is being sliced.
962  auto transferReadOp = op.getVector().getDefiningOp<vector::TransferReadOp>();
963  if (!transferReadOp)
964  return rewriter.notifyMatchFailure(op, "no transfer read");
965 
966  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
967  if (failed(warpMatrixInfo))
968  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
969 
970  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
971  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
972  if (failed(ldFragmentInfo))
973  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
974 
975  assert(
976  (mmaSyncFragmentInfo->elementsPerRegister ==
977  ldFragmentInfo->elementsPerRegister) &&
978  "Number of elements per register should be same for load and mma.sync");
979 
980  // Create vector.extract_strided_slice op for thread-owned fragments.
981  std::array<int64_t, 2> strides = {1,
982  1}; // stride for extract slice is always 1.
983  std::array<int64_t, 2> sliceShape = {
984  mmaSyncFragmentInfo->numRegistersPerFragment,
985  mmaSyncFragmentInfo->elementsPerRegister};
986  auto it = valueMapping.find(transferReadOp);
987  if (it == valueMapping.end())
988  return rewriter.notifyMatchFailure(op, "no mapping");
989  auto sourceVector = it->second;
990 
991  // Offsets and sizes at the warp level of ownership.
992  SmallVector<int64_t> offsets;
993  populateFromInt64AttrArray(op.getOffsets(), offsets);
994 
995  SmallVector<int64_t> sizes;
996  populateFromInt64AttrArray(op.getSizes(), sizes);
997  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
998 
999  // Compute offset in vector registers. Note that the mma.sync vector registers
1000  // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1001  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1002  std::array<int64_t, 2> sliceOffset = {0, 0};
1003 
1004  if (offsets[0] && offsets[1])
1005  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1006  if (offsets[0])
1007  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1008  else if (offsets[1])
1009  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1010 
1011  Value newOp = rewriter.create<vector::ExtractStridedSliceOp>(
1012  loc, sourceVector, sliceOffset, sliceShape, strides);
1013 
1014  valueMapping[op] = newOp;
1015  return success();
1016 }
1017 
1018 static LogicalResult
1019 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1020  llvm::DenseMap<Value, Value> &valueMapping) {
1021  OpBuilder::InsertionGuard g(rewriter);
1022  rewriter.setInsertionPoint(op);
1023 
1024  auto itA = valueMapping.find(op.getLhs());
1025  auto itB = valueMapping.find(op.getRhs());
1026  auto itC = valueMapping.find(op.getAcc());
1027  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1028  itC == valueMapping.end())
1029  return rewriter.notifyMatchFailure(op, "no mapping");
1030  Value opA = itA->second, opB = itB->second, opC = itC->second;
1031  Value matmul = rewriter.create<gpu::SubgroupMmaComputeOp>(
1032  op.getLoc(), opC.getType(), opA, opB, opC, /*a_transpose=*/UnitAttr(),
1033  /*b_transpose=*/UnitAttr());
1034  valueMapping[op.getResult()] = matmul;
1035  return success();
1036 }
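// Illustrative result: once all three operands are mapped, the contraction
// becomes roughly
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//        : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//        -> !gpu.mma_matrix<16x16xf16, "COp">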
1037 
1038 static LogicalResult
1039 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1040  llvm::DenseMap<Value, Value> &valueMapping) {
1041  OpBuilder::InsertionGuard g(rewriter);
1042  rewriter.setInsertionPoint(op);
1043 
1044  auto itA = valueMapping.find(op.getLhs());
1045  auto itB = valueMapping.find(op.getRhs());
1046  auto itC = valueMapping.find(op.getAcc());
1047  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1048  itC == valueMapping.end())
1049  return rewriter.notifyMatchFailure(op, "no mapping");
1050  Value opA = itA->second, opB = itB->second, opC = itC->second;
1051  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1052  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1053  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1054  Value matmul = rewriter.create<nvgpu::MmaSyncOp>(
1055  op.getLoc(), opA, opB, opC, rewriter.getI64ArrayAttr({m, n, k}));
1056  valueMapping[op.getResult()] = matmul;
1057  return success();
1058 }
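// Illustrative result (assumed 16x8x16 f16 tile): the distributed fragments
// feed roughly
//   %d = nvgpu.mma.sync(%a, %b, %c) {mmaShape = [16, 8, 16]}
//        : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
// where m, n, and k are read off the warp-level operand vector shapes above.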
1059 
1060 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1061 static LogicalResult
1062 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1063  llvm::DenseMap<Value, Value> &valueMapping) {
1064  OpBuilder::InsertionGuard g(rewriter);
1065  rewriter.setInsertionPoint(op);
1066 
1067  assert(constantSupportsMMAMatrixType(op));
1068 
1069  auto splat =
1070  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1071  auto scalarConstant =
1072  rewriter.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
1073  const char *fragType = inferFragType(op);
1074  auto vecType = cast<VectorType>(op.getType());
1075  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1076  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1077  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1078  op.getLoc(), type, scalarConstant);
1079  valueMapping[op.getResult()] = matrix;
1080  return success();
1081 }
1082 
1083 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1084 static LogicalResult
1085 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1086  llvm::DenseMap<Value, Value> &valueMapping) {
1087  OpBuilder::InsertionGuard g(rewriter);
1088  rewriter.setInsertionPoint(op);
1089 
1090  assert(broadcastSupportsMMAMatrixType(op));
1091 
1092  const char *fragType = inferFragType(op);
1093  auto vecType = op.getResultVectorType();
1094  gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1095  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1096  auto matrix = rewriter.create<gpu::SubgroupMmaConstantMatrixOp>(
1097  op.getLoc(), type, op.getSource());
1098  valueMapping[op.getResult()] = matrix;
1099  return success();
1100 }
1101 
1102 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1103 // updated and needs to be updated separately for the loop to be correct.
1104 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1105  scf::ForOp loop,
1106  ValueRange newInitArgs) {
1107  OpBuilder::InsertionGuard g(rewriter);
1108  rewriter.setInsertionPoint(loop);
1109 
1110  // Create a new loop before the existing one, with the extra operands.
1111  rewriter.setInsertionPoint(loop);
1112  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1113  llvm::append_range(operands, newInitArgs);
1114  scf::ForOp newLoop = rewriter.create<scf::ForOp>(
1115  loop.getLoc(), loop.getLowerBound(), loop.getUpperBound(), loop.getStep(),
1116  operands);
1117  rewriter.eraseBlock(newLoop.getBody());
1118 
1119  newLoop.getRegion().getBlocks().splice(
1120  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1121  for (Value operand : newInitArgs)
1122  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1123 
1124  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1125  loop.getNumResults())))
1126  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1127 
1128  LLVM_DEBUG(DBGS() << "newLoop now: " << newLoop << "\n");
1129  LLVM_DEBUG(DBGS() << "stripped scf.for: " << loop << "\n");
1130  LLVM_DEBUG(DBGS() << "erase: " << loop);
1131 
1132  rewriter.eraseOp(loop);
1133  return newLoop;
1134 }
1135 
1136 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1137  llvm::DenseMap<Value, Value> &valueMapping) {
1138  OpBuilder::InsertionGuard g(rewriter);
1139  rewriter.setInsertionPoint(op);
1140 
1141  SmallVector<Value> newOperands;
1142  SmallVector<std::pair<size_t, size_t>> argMapping;
1143  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1144  auto it = valueMapping.find(operand.value());
1145  if (it == valueMapping.end()) {
1146  LLVM_DEBUG(DBGS() << "no value mapping for: " << operand.value() << "\n");
1147  continue;
1148  }
1149  argMapping.push_back(std::make_pair(
1150  operand.index(), op.getInitArgs().size() + newOperands.size()));
1151  newOperands.push_back(it->second);
1152  }
1153 
1154  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1155  Block &loopBody = *newForOp.getBody();
1156  for (auto mapping : argMapping) {
1157  valueMapping[newForOp.getResult(mapping.first)] =
1158  newForOp.getResult(mapping.second);
1159  valueMapping[loopBody.getArgument(mapping.first +
1160  newForOp.getNumInductionVars())] =
1161  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1162  }
1163 
1164  LLVM_DEBUG(DBGS() << "scf.for to: " << newForOp << "\n");
1165  return success();
1166 }
1167 
1168 static LogicalResult
1169 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1170  llvm::DenseMap<Value, Value> &valueMapping) {
1171  OpBuilder::InsertionGuard g(rewriter);
1172  rewriter.setInsertionPoint(op);
1173 
1174  auto loop = cast<scf::ForOp>(op->getParentOp());
1175  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1176  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1177  auto it = valueMapping.find(operand.value());
1178  if (it == valueMapping.end())
1179  continue;
1180  // Replace the yield of old value with the for op argument to make it easier
1181  // to remove the dead code.
1182  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1183  yieldOperands.push_back(it->second);
1184  }
1185  rewriter.create<scf::YieldOp>(op.getLoc(), yieldOperands);
1186 
1187  LLVM_DEBUG(DBGS() << "erase: " << op << "\n");
1188  rewriter.eraseOp(op);
1189  return success();
1190 }
1191 
1192 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1193 static LogicalResult
1194 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1195  gpu::MMAElementwiseOp opType,
1196  llvm::DenseMap<Value, Value> &valueMapping) {
1197  OpBuilder::InsertionGuard g(rewriter);
1198  rewriter.setInsertionPoint(op);
1199 
1200  SmallVector<Value> matrixOperands;
1201  for (Value operand : op->getOperands()) {
1202  auto it = valueMapping.find(operand);
1203  if (it == valueMapping.end())
1204  return rewriter.notifyMatchFailure(op, "no mapping");
1205  matrixOperands.push_back(it->second);
1206  }
1207  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1208  if (opType == gpu::MMAElementwiseOp::EXTF) {
1209  // The floating point extension case has a different result type.
1210  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1211  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1212  vectorType.getElementType(),
1213  resultType.getOperand());
1214  }
1215 
1216  Value newOp = rewriter.create<gpu::SubgroupMmaElementwiseOp>(
1217  op->getLoc(), resultType, matrixOperands, opType);
1218  valueMapping[op->getResult(0)] = newOp;
1219  return success();
1220 }
1221 
1222 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1223  bool useNvGpu) {
1224  if (!useNvGpu) {
1225  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1226  patterns.getContext());
1227  return;
1228  }
1229  vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1230  patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1231 }
1232 
1233 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1234  Operation *rootOp) {
1235  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1236  llvm::DenseMap<Value, Value> valueMapping;
1237 
1238  auto globalRes = LogicalResult::success();
1239  for (Operation *op : ops) {
1240  LLVM_DEBUG(DBGS() << "Process op: " << *op << "\n");
1241  // Apparently callers do not want to early exit on failure here.
1242  auto res = LogicalResult::success();
1243  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1244  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1245  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1246  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1247  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1248  res = convertContractOp(rewriter, contractOp, valueMapping);
1249  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1250  res = convertConstantOp(rewriter, constantOp, valueMapping);
1251  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1252  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1253  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1254  res = convertForOp(rewriter, forOp, valueMapping);
1255  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1256  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1257  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1258  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1259  }
1260  if (failed(res))
1261  globalRes = failure();
1262  }
1263  return globalRes;
1264 }
1265 
1266 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1267  Operation *rootOp) {
1268  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1269  llvm::DenseMap<Value, Value> valueMapping;
1270  for (Operation *op : ops) {
1271  if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1272  .Case([&](vector::TransferReadOp transferReadOp) {
1273  return convertTransferReadToLoads(rewriter, transferReadOp,
1274  valueMapping);
1275  })
1276  .Case([&](vector::TransferWriteOp transferWriteOp) {
1277  return convertTransferWriteToStores(rewriter, transferWriteOp,
1278  valueMapping);
1279  })
1280  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1281  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1282  valueMapping);
1283  })
1284  .Case([&](vector::ContractionOp contractionOp) {
1285  return convertContractOpToMmaSync(rewriter, contractionOp,
1286  valueMapping);
1287  })
1288  .Case([&](scf::ForOp forOp) {
1289  return convertForOp(rewriter, forOp, valueMapping);
1290  })
1291  .Case([&](scf::YieldOp yieldOp) {
1292  return convertYieldOp(rewriter, yieldOp, valueMapping);
1293  })
1294  .Case([&](arith::ConstantOp constOp) {
1295  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1296  })
1297  .Default([&](Operation *op) {
1298  return op->emitError() << "unhandled vector to mma type: " << *op;
1299  })
1300  .failed()) {
1301  return op->emitOpError()
1302  << "failed to convert op during vector-to-nvgpu conversion";
1303  }
1304  }
1305  return success();
1306 }
1307 
1308 namespace {
1309 
1310 struct ConvertVectorToGPUPass
1311  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1312 
1313  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1314  useNvGpu.setValue(useNvGpu_);
1315  }
1316 
1317  void runOnOperation() override {
1318  RewritePatternSet patterns(&getContext());
1319  populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1320  if (failed(
1321  applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
1322  return signalPassFailure();
1323 
1324  IRRewriter rewriter(&getContext());
1325  if (useNvGpu) {
1326  if (failed(
1327  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1328  return signalPassFailure();
1329  return;
1330  }
1331  (void)convertVectorToMMAOps(rewriter, getOperation());
1332  }
1333 };
1334 
1335 } // namespace
1336 
1337 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1338  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1339 }
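// Typical driver usage (illustrative): the pass created above is exposed as
//   mlir-opt --convert-vector-to-gpu ...                  (wmma path)
//   mlir-opt --convert-vector-to-gpu=use-nvgpu=true ...   (nvgpu mma.sync path)
// or can be constructed programmatically via createConvertVectorToGPUPass(useNvGpu).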
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:37
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:38
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:35
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, LLVMTypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:118
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:77
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
#define DBGS()
Definition: VectorToGPU.cpp:40
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:57
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:69
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:47
MLIRContext * getContext() const
Definition: AffineMap.cpp:327
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:152
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:378
unsigned getNumResults() const
Definition: AffineMap.cpp:386
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:395
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:248
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:540
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:296
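The AffineMap helpers above are typically used together; the following is a minimal sketch, not code from this file, assuming an MLIRContext *ctx and illustrative variable names.
// Sketch only. Assumes an MLIRContext *ctx is in scope.
AffineExpr m, n, k;
bindDims(ctx, m, n, k);                                    // d0 = m, d1 = n, d2 = k
auto infer = [&](ArrayRef<ArrayRef<AffineExpr>> exprs) {
  return AffineMap::inferFromExprList(exprs, ctx);
};
// One indexing map per matmul operand: A(m, k), B(k, n), C(m, n).
SmallVector<AffineMap, 4> maps = infer({{m, k}, {k, n}, {m, n}});
AffineMap swap = AffineMap::getPermutationMap({1, 0}, ctx); // (d0, d1) -> (d1, d0)
AffineMap identity2d = swap.compose(swap);                  // composes back to the 2-D identity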
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:319
Block represents an ordered list of Operations.
Definition: Block.h:30
BlockArgument getArgument(unsigned i)
Definition: Block.h:126
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:124
UnitAttr getUnitAttr()
Definition: Builders.cpp:114
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:379
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:331
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:371
MLIRContext * getContext() const
Definition: Builders.h:55
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:288
IndexType getIndexType()
Definition: Builders.cpp:71
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:325
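As a hedged illustration of the Builder convenience getters listed above (the builder `b` and the concrete values are assumptions for the example, not code from this file):
// Sketch only; `b` is any mlir::Builder or OpBuilder.
IntegerAttr zero = b.getIndexAttr(0);              // index-typed integer attribute
ArrayAttr strides = b.getI64ArrayAttr({1, 1});     // [1, 1] as an i64 array attribute
AffineExpr d0 = b.getAffineDimExpr(0);             // the d0 dimension expression
TypedAttr zeroF32 = b.getZeroAttr(b.getF32Type()); // 0.0 : f32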
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
This class provides support for representing a failure result, or a valid value of type T.
Definition: LogicalResult.h:78
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:766
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:63
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:60
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:350
This class helps build Operations.
Definition: Builders.h:209
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:400
Operation * create(const OperationState &state)
Creates an operation given the fields represented as an OperationState.
Definition: Builders.cpp:464
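A small sketch of the insertion-point utilities above; `rewriter` and `someOp` are assumed placeholders, not names from this file.
// Sketch only: temporarily redirect where new ops are created.
{
  OpBuilder::InsertionGuard guard(rewriter); // restores the insertion point on scope exit
  rewriter.setInsertionPoint(someOp);        // new ops are created right before someOp
  // ... rewriter.create<...>(...) calls go here ...
}
// The previous insertion point is active again here.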
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:345
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:845
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:402
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:793
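walk is commonly used to collect nested ops of one kind; a hedged sketch (rootOp is an assumed Operation*):
// Sketch only: gather all vector.contract ops nested under rootOp.
SmallVector<vector::ContractionOp> contractOps;
rootOp->walk([&](vector::ContractionOp c) { contractOps.push_back(c); });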
MLIRContext * getContext()
Return the context this operation is associated with.
Definition: Operation.h:216
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition: Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:268
operand_type_range getOperandTypes()
Definition: Operation.h:392
result_type_range getResultTypes()
Definition: Operation.h:423
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:373
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:869
user_iterator user_begin()
Definition: Operation.h:865
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:671
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:399
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:785
MLIRContext * getContext() const
Definition: PatternMatch.h:822
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
Definition: PatternMatch.h:846
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:400
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:718
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:638
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:536
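A hedged sketch of how the rewriter entry points above are typically combined inside a conversion; isSupported(), op, and newValue are hypothetical placeholders.
// Sketch only: bail out with a diagnostic, or rewrite and erase the old op.
if (!isSupported(op)) // isSupported is a hypothetical predicate for this example
  return rewriter.notifyMatchFailure(op, "unsupported operand shape");
rewriter.replaceAllUsesWith(op->getResult(0), newValue);
rewriter.eraseOp(op);
return success();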
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:381
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:129
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:20
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:130
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:122
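For instance, the "A" operand fragment of a 16x16 f16 multiply could be typed as follows; this is a hedged sketch and `rewriter` is an assumed builder, not code from this file.
// Sketch only: build the !gpu.mma_matrix<16x16xf16, "AOp"> type.
gpu::MMAMatrixType aFragType =
    gpu::MMAMatrixType::get({16, 16}, rewriter.getF16Type(), "AOp");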
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1138
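A hedged sketch of the usual call shape; `rewriter`, `loc`, and the index-typed Value `base` are assumptions for illustration.
// Sketch only: add a constant offset of 16 to an index value, folding into
// any affine.apply ops that already define the operand.
AffineExpr d0 = rewriter.getAffineDimExpr(0);
SmallVector<OpFoldResult> operands;
operands.push_back(base); // base : index
Value offset = affine::makeComposedAffineApply(rewriter, loc, d0 + 16, operands);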
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:285
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:87
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:50
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:173
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:58
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:238
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:209
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:100
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:276
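A hedged sketch tying a few of these queries together; `readOp` is an assumed vector.transfer_read on the mma.sync path, and error handling is abbreviated.
// Sketch only: query warp/fragment info for a transfer_read feeding mma.sync.
FailureOr<nvgpu::WarpMatrixInfo> warpInfo =
    nvgpu::getWarpMatrixInfo(readOp.getOperation());
if (failed(warpInfo))
  return failure();
FailureOr<nvgpu::FragmentElementInfo> regInfo =
    nvgpu::getMmaSyncRegisterType(*warpInfo);
int64_t tileBits = nvgpu::inferTileWidthInBits(*warpInfo);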
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:140
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:135
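These predicates are typically used to check that a contraction's iterators form a matmul pattern; a hedged sketch with an assumed vector::ContractionOp `contractOp`:
// Sketch only: (parallel, parallel, reduction) iterators indicate an m/n/k matmul.
auto iters = contractOp.getIteratorTypes().getValue();
bool looksLikeMatmul = iters.size() == 3 &&
                       vector::isParallelIterator(iters[0]) &&
                       vector::isParallelIterator(iters[1]) &&
                       vector::isReductionIterator(iters[2]);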
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
static void transpose(llvm::ArrayRef< int64_t > trans, SmallVector< int64_t > &shape)
Definition: XeGPUOps.cpp:21
Include the generated interface declarations.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:62
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .
Definition: AffineExpr.h:349
void getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
LogicalResult getStridesAndOffset(MemRefType t, SmallVectorImpl< int64_t > &strides, int64_t &offset)
Returns the strides of the MemRef if the layout map is in strided form.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
Definition: LogicalResult.h:56
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
LogicalResult applyPatternsAndFoldGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
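A hedged sketch of the common driver pattern, e.g. inside a pass's runOnOperation; `ctx` and `funcOp` (the isolated-from-above op being processed) are assumptions.
// Sketch only: run the vector-to-MMA preparation patterns to a fixed point.
RewritePatternSet patterns(ctx);
populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
if (failed(applyPatternsAndFoldGreedily(funcOp, std::move(patterns))))
  signalPassFailure();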
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:623
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
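A hedged sketch of wiring the pass into a pipeline; `ctx`, `module`, and the nesting choice are illustrative assumptions.
// Sketch only: nest the conversion on functions and run the pipeline.
PassManager pm(ctx);
pm.addNestedPass<func::FuncOp>(createConvertVectorToGPUPass(/*useNvGpu=*/true));
if (failed(pm.run(module)))
  llvm::errs() << "vector-to-gpu conversion failed\n";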
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:599
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
Definition: LogicalResult.h:72
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
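A hedged sketch of calling the two rewrite entry points directly; the rewriter setup, `rootOp`, and the `useNvGpu` flag are assumptions for illustration.
// Sketch only: apply either the wmma path or the mma.sync path under rootOp.
IRRewriter rewriter(rootOp->getContext());
if (useNvGpu) {
  if (failed(convertVectorToNVVMCompatibleMMASync(rewriter, rootOp)))
    return failure();
} else if (failed(convertVectorToMMAOps(rewriter, rootOp))) {
  return failure();
}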
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Multi-root DAG topological sort.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
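A hedged sketch of combining the slice utilities with topologicalSort; `anchorOp` and the default slice options are assumptions.
// Sketch only: collect ops reachable from anchorOp and order them for rewriting.
SetVector<Operation *> slice;
getForwardSlice(anchorOp, &slice);
getBackwardSlice(anchorOp, &slice);
SetVector<Operation *> ordered = topologicalSort(slice);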
This class represents an efficient way to signal success or failure.
Definition: LogicalResult.h:26
static LogicalResult success(bool isSuccess=true)
If isSuccess is true a success result is generated, otherwise a 'failure' result is generated.
Definition: LogicalResult.h:30
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:358
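The usual skeleton looks roughly like this; the pattern name and the matched op are illustrative and not patterns defined in this file.
// Sketch only: a minimal OpRewritePattern, registered via RewritePatternSet::add.
struct ExampleContractPattern : public OpRewritePattern<vector::ContractionOp> {
  using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
  LogicalResult matchAndRewrite(vector::ContractionOp op,
                                PatternRewriter &rewriter) const override {
    // Inspect `op`; rewrite it with `rewriter` or report why the match failed.
    return rewriter.notifyMatchFailure(op, "example pattern: nothing to do");
  }
};
// Registration: patterns.add<ExampleContractPattern>(ctx);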
TransitiveFilter filter
Definition: SliceAnalysis.h:30
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52