MLIR  22.0.0git
VectorToGPU.cpp
1 //===- VectorToGPU.cpp - Convert vector to GPU dialect ----------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements lowering of vector operations to GPU dialect ops.
10 //
11 //===----------------------------------------------------------------------===//
12 
13 #include "mlir/Conversion/VectorToGPU/VectorToGPU.h"
14 
15 #include "mlir/Analysis/SliceAnalysis.h"
16 #include "mlir/Dialect/Affine/IR/AffineOps.h"
17 #include "mlir/Dialect/Arith/IR/Arith.h"
18 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
19 #include "mlir/Dialect/MemRef/IR/MemRef.h"
20 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
21 #include "mlir/Dialect/NVGPU/Utils/MMAUtils.h"
22 #include "mlir/Dialect/SCF/IR/SCF.h"
23 #include "mlir/Dialect/Vector/IR/VectorOps.h"
24 #include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
28 #include "mlir/IR/Builders.h"
29 #include "mlir/IR/Region.h"
30 #include "mlir/Pass/Pass.h"
31 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
32 #include "llvm/ADT/STLExtras.h"
33 #include "llvm/ADT/TypeSwitch.h"
34 #include "llvm/Support/DebugLog.h"
35 
36 #define DEBUG_TYPE "vector-to-gpu"
37 
38 namespace mlir {
39 #define GEN_PASS_DEF_CONVERTVECTORTOGPU
40 #include "mlir/Conversion/Passes.h.inc"
41 } // namespace mlir
42 
43 using namespace mlir;
44 
45 /// For a vector TransferOpType `xferOp`, an empty `indices` vector, and an
46 /// AffineMap representing offsets to apply to indices, the function fills
47 /// `indices` with the original indices plus the offsets. The offsets are
48 /// applied by taking into account the permutation map of the transfer op. If
49 /// the `offsetMap` has dimension placeholders, those should be provided in
50 /// `dimValues`.
51 template <typename TransferOpType>
52 static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp,
53  AffineMap offsetMap, ArrayRef<Value> dimValues,
54  SmallVector<Value, 4> &indices) {
55  indices.append(xferOp.getIndices().begin(), xferOp.getIndices().end());
56  Location loc = xferOp.getLoc();
57  unsigned offsetsIdx = 0;
58  for (auto expr : xferOp.getPermutationMap().getResults()) {
59  if (auto dim = dyn_cast<AffineDimExpr>(expr)) {
60  Value prevIdx = indices[dim.getPosition()];
61  SmallVector<OpFoldResult, 3> dims(dimValues);
62  dims.push_back(prevIdx);
63  AffineExpr d0 = rewriter.getAffineDimExpr(offsetMap.getNumDims());
64  indices[dim.getPosition()] = affine::makeComposedAffineApply(
65  rewriter, loc, d0 + offsetMap.getResult(offsetsIdx++), dims);
66  continue;
67  }
68  }
69 }
70 
71 // Return true if the contract op can be converted to an MMA matmul.
72 static bool contractSupportsMMAMatrixType(vector::ContractionOp contract,
73  bool useNvGpu) {
74  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
75  auto infer = [&](MapList m) {
76  return AffineMap::inferFromExprList(m, contract.getContext());
77  };
78  AffineExpr m, n, k;
79  bindDims(contract.getContext(), m, n, k);
80  auto iteratorTypes = contract.getIteratorTypes().getValue();
81  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
82  vector::isParallelIterator(iteratorTypes[1]) &&
83  vector::isReductionIterator(iteratorTypes[2])))
84  return false;
85 
86  // The contract needs to represent a matmul to be able to convert to
87  // MMAMatrix matmul.
88  if (!useNvGpu &&
89  contract.getIndexingMapsArray() != infer({{m, k}, {k, n}, {m, n}}))
90  return false;
91  if (useNvGpu &&
92  contract.getIndexingMapsArray() != infer({{m, k}, {n, k}, {m, n}}))
93  return false;
94 
95  return true;
96 }
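
// For illustration only: a contraction that satisfies the checks above on the
// WMMA path (useNvGpu = false) looks roughly like the following, with
// (d0, d1, d2) = (m, n, k); the shapes and element types are placeholders,
// not constraints imposed here:
//
//   %d = vector.contract {
//          indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
//                           affine_map<(d0, d1, d2) -> (d2, d1)>,
//                           affine_map<(d0, d1, d2) -> (d0, d1)>],
//          iterator_types = ["parallel", "parallel", "reduction"],
//          kind = #vector.kind<add>}
//        %a, %b, %c : vector<16x16xf16>, vector<16x16xf16> into vector<16x16xf16>
//
// On the mma.sync path the B operand is instead expected in (n, k) form.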
97 
98 // Return true if the given map represents a transposed matrix load,
99 // i.e. (d0, d1, ...) -> (dn-1, dn-2).
100 static bool isTransposeMatrixLoadMap(AffineMap permutationMap) {
101  MLIRContext *ctx = permutationMap.getContext();
102   // A local OpBuilder is fine here; we only build affine expressions.
103  OpBuilder b(ctx);
104  auto nDim = permutationMap.getNumDims();
105  AffineExpr zero = b.getAffineConstantExpr(0);
106  if (nDim < 2) {
107  // Support transposed+broadcasted cases: affine_map<(d0) -> (d0, 0)>.
108  AffineExpr dim0 = b.getAffineDimExpr(0);
109  return permutationMap == AffineMap::get(1, 0, {dim0, zero}, ctx);
110  }
111 
112  AffineExpr innerDim = b.getAffineDimExpr(nDim - 1);
113  AffineExpr outerDim = b.getAffineDimExpr(nDim - 2);
114  // Support both transposed and transposed+broadcasted cases.
115  return permutationMap == AffineMap::get(nDim, 0, {innerDim, outerDim}, ctx) ||
116  permutationMap == AffineMap::get(nDim, 0, {innerDim, zero}, ctx);
117 }
118 
119 // Return the stride for the second-to-last dimension of |type| if it is a memref
120 // and has a constant stride.
121 static std::optional<int64_t> getStaticallyKnownRowStride(ShapedType type) {
122  auto memrefType = dyn_cast<MemRefType>(type);
123  if (!memrefType)
124     return std::nullopt;
125  // If the memref is 0 or 1D the horizontal stride is 0.
126  if (memrefType.getRank() < 2)
127  return 0;
128  int64_t offset = 0;
129  SmallVector<int64_t, 2> strides;
130  if (failed(memrefType.getStridesAndOffset(strides, offset)) ||
131  strides.back() != 1)
132  return std::nullopt;
133  int64_t stride = strides[strides.size() - 2];
134  if (stride == ShapedType::kDynamic)
135  return std::nullopt;
136  return stride;
137 }
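
// For example, for a buffer of type memref<32x64xf16, strided<[64, 1]>> this
// returns 64, while a dynamic row stride such as strided<[?, 1]> (or a
// non-unit innermost stride) yields std::nullopt.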
138 
139 // Return true if the transfer op can be converted to an MMA matrix load.
140 static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp) {
141  if (readOp.getMask() || readOp.hasOutOfBoundsDim() ||
142  readOp.getVectorType().getRank() != 2)
143  return false;
144  if (!getStaticallyKnownRowStride(readOp.getShapedType()))
145  return false;
146 
147  // Only allow integer types if the signedness can be inferred.
148  if (readOp.getVectorType().getElementType().isInteger(8))
149  if (!readOp->hasOneUse() || (!isa<arith::ExtSIOp>(*readOp->user_begin()) &&
150  !isa<arith::ExtUIOp>(*readOp->user_begin())))
151  return false;
152 
153  AffineMap map = readOp.getPermutationMap();
154  MLIRContext *ctx = readOp.getContext();
155  AffineExpr innerDim = getAffineDimExpr(map.getNumDims() - 1, ctx);
156  AffineExpr zero = getAffineConstantExpr(0, ctx);
157  auto broadcastInnerDim =
158  AffineMap::get(map.getNumDims(), 0, {zero, innerDim}, ctx);
159   return map.isMinorIdentity() || map == broadcastInnerDim ||
160          isTransposeMatrixLoadMap(map);
161 }
162 
163 // Return true if the transfer op can be converted to an MMA matrix store.
164 static bool
165 transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp) {
166  // TODO: support 0-d corner case.
167  if (writeOp.getTransferRank() == 0)
168  return false;
169 
170  if (writeOp.getMask() || writeOp.hasOutOfBoundsDim() ||
171  writeOp.getVectorType().getRank() != 2)
172  return false;
173  if (!getStaticallyKnownRowStride(writeOp.getShapedType()))
174  return false;
175  // TODO: Support transpose once it is added to GPU dialect ops.
176  if (!writeOp.getPermutationMap().isMinorIdentity())
177  return false;
178  return true;
179 }
180 
181 /// Return true if the constant is a 2D vector splat so that it can be
182 /// converted to an MMA constant matrix op.
183 static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp) {
184  auto vecType = dyn_cast<VectorType>(constantOp.getType());
185  if (!vecType || vecType.getRank() != 2)
186  return false;
187  return isa<SplatElementsAttr>(constantOp.getValue());
188 }
189 
190 /// Return true if this is a broadcast from scalar to a 2D vector.
191 static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp) {
192  return broadcastOp.getResultVectorType().getRank() == 2;
193 }
194 
195 /// Return true if this integer extend op can be folded into a contract op.
196 template <typename ExtOpTy>
197 static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp) {
198  auto transferReadOp =
199  extOp.getOperand().template getDefiningOp<vector::TransferReadOp>();
200  if (!transferReadOp)
201  return false;
202  return llvm::all_of(extOp->getUsers(), llvm::IsaPred<vector::ContractionOp>);
203 }
204 
205 static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp) { return true; }
206 
207 /// Return the MMA elementwise enum associated with `op` if it is supported.
208 /// Return `std::nullopt` otherwise.
209 static std::optional<gpu::MMAElementwiseOp>
210 convertElementwiseOpToMMA(Operation *op) {
211  if (isa<arith::AddFOp>(op))
212  return gpu::MMAElementwiseOp::ADDF;
213  if (isa<arith::MulFOp>(op))
214  return gpu::MMAElementwiseOp::MULF;
215  if (isa<arith::SubFOp>(op))
216  return gpu::MMAElementwiseOp::SUBF;
217  if (isa<arith::MaximumFOp>(op))
218  return gpu::MMAElementwiseOp::MAXF;
219  if (isa<arith::MinimumFOp>(op))
220  return gpu::MMAElementwiseOp::MINF;
221  if (isa<arith::DivFOp>(op))
222  return gpu::MMAElementwiseOp::DIVF;
223   if (isa<arith::AddIOp>(op))
224     return gpu::MMAElementwiseOp::ADDI;
225   if (isa<arith::MulIOp>(op))
226     return gpu::MMAElementwiseOp::MULI;
227   if (isa<arith::SubIOp>(op))
228     return gpu::MMAElementwiseOp::SUBI;
229  if (isa<arith::DivSIOp>(op))
230  return gpu::MMAElementwiseOp::DIVS;
231  if (isa<arith::DivUIOp>(op))
232  return gpu::MMAElementwiseOp::DIVU;
233  if (isa<arith::NegFOp>(op))
234  return gpu::MMAElementwiseOp::NEGATEF;
235  if (isa<arith::ExtFOp>(op))
236  return gpu::MMAElementwiseOp::EXTF;
237  return std::nullopt;
238 }
239 
240 /// Return true if the op is supported as an elementwise op on MMAMatrix type.
241 static bool elementwiseSupportsMMAMatrixType(Operation *op) {
242   return convertElementwiseOpToMMA(op).has_value();
243 }
244 
245 /// Returns true if the extract strided slice op is supported on the
246 /// `mma.sync` path.
247 static bool
248 extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op) {
249 
250   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
251       nvgpu::getWarpMatrixInfo(op);
252  if (failed(warpMatrixInfo))
253  return false;
254 
255  FailureOr<vector::ContractionOp> contractOp = nvgpu::getUserContract(op);
256  if (failed(contractOp))
257  return false;
258 
259  // Handle vector.extract_strided_slice on registers containing
260  // matrixB and matrixC operands. vector.extract_strided_slice op
261  // is not supported on registers containing matrixA operands.
262  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::B)
263  return (cast<VectorType>(op->getResult(0).getType()) ==
264  cast<VectorType>((*contractOp).getRhs().getType()));
265  if (warpMatrixInfo->operandRole == nvgpu::MatMulOperandRole::C)
266  return (cast<VectorType>(op->getResult(0).getType()) ==
267  cast<VectorType>((*contractOp).getAcc().getType()));
268 
269  return false;
270 }
271 
272 static bool supportsMMaMatrixType(Operation *op, bool useNvGpu) {
273  if (isa<scf::ForOp, scf::YieldOp>(op))
274  return true;
275  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op))
276  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferRead)
277  : transferReadSupportsMMAMatrixType(transferRead);
278  if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op))
279  return useNvGpu ? nvgpu::canLowerToWarpMatrixOperation(transferWrite)
280  : transferWriteSupportsMMAMatrixType(transferWrite);
281  if (auto extractStridedSlice = dyn_cast<vector::ExtractStridedSliceOp>(op))
282  return useNvGpu &&
283  extractStridedSliceSupportsMMAMatrixType(extractStridedSlice);
284  if (auto contract = dyn_cast<vector::ContractionOp>(op))
285  return contractSupportsMMAMatrixType(contract, useNvGpu);
286  if (auto constant = dyn_cast<arith::ConstantOp>(op))
287  return constantSupportsMMAMatrixType(constant);
288   if (auto broadcast = dyn_cast<vector::BroadcastOp>(op))
289     return broadcastSupportsMMAMatrixType(broadcast);
290  if (auto signedExtend = dyn_cast<arith::ExtSIOp>(op))
291  return integerExtendSupportsMMAMatrixType<arith::ExtSIOp>(signedExtend);
292  if (auto unsignedExtend = dyn_cast<arith::ExtUIOp>(op))
293  return integerExtendSupportsMMAMatrixType<arith::ExtUIOp>(unsignedExtend);
294  if (auto fpExtend = dyn_cast<arith::ExtFOp>(op))
295     return fpExtendSupportsMMAMatrixType(fpExtend);
296   return elementwiseSupportsMMAMatrixType(op);
297 }
298 
299 /// Return an unsorted slice handling scf.for region differently than
300 /// `getSlice`. In scf.for we only want to include as part of the slice elements
301 /// that are part of the use/def chain.
302 static SetVector<Operation *>
303 getSliceContract(Operation *op,
304                  const BackwardSliceOptions &backwardSliceOptions,
305                  const ForwardSliceOptions &forwardSliceOptions) {
306   SetVector<Operation *> slice;
307   slice.insert(op);
308  unsigned currentIndex = 0;
309  SetVector<Operation *> backwardSlice;
310  SetVector<Operation *> forwardSlice;
311  while (currentIndex != slice.size()) {
312  auto *currentOp = (slice)[currentIndex];
313  // Compute and insert the backwardSlice starting from currentOp.
314  backwardSlice.clear();
315  LogicalResult result =
316  getBackwardSlice(currentOp, &backwardSlice, backwardSliceOptions);
317  assert(result.succeeded() && "expected a backward slice");
318  (void)result;
319  slice.insert_range(backwardSlice);
320 
321  // Compute and insert the forwardSlice starting from currentOp.
322  forwardSlice.clear();
323     // Special case for ForOp: we don't want to include the whole region, but
324     // only the values using the region arguments.
325  // TODO: We should refine this to only care about the region arguments being
326  // converted to matrix type.
327  if (auto forOp = dyn_cast<scf::ForOp>(currentOp)) {
328  for (Value forOpResult : forOp.getResults())
329  getForwardSlice(forOpResult, &forwardSlice, forwardSliceOptions);
330  for (BlockArgument &arg : forOp.getRegionIterArgs())
331  getForwardSlice(arg, &forwardSlice, forwardSliceOptions);
332  } else {
333  getForwardSlice(currentOp, &forwardSlice, forwardSliceOptions);
334  }
335  slice.insert_range(forwardSlice);
336  ++currentIndex;
337  }
338  return slice;
339 }
340 
341 // Analyze slice of operations based on convert op to figure out if the whole
342 // slice can be converted to MMA operations.
343 static SetVector<Operation *> getOpToConvert(mlir::Operation *op,
344                                              bool useNvGpu) {
345  auto hasVectorDest = [](Operation *op) {
346  return llvm::any_of(op->getResultTypes(), llvm::IsaPred<VectorType>);
347  };
348  BackwardSliceOptions backwardSliceOptions;
349  backwardSliceOptions.filter = hasVectorDest;
350 
351  auto hasVectorSrc = [](Operation *op) {
352  return llvm::any_of(op->getOperandTypes(), llvm::IsaPred<VectorType>);
353  };
354  ForwardSliceOptions forwardSliceOptions;
355  forwardSliceOptions.filter = hasVectorSrc;
356 
357  SetVector<Operation *> opToConvert;
358  op->walk([&](Operation *nestedOp) {
359  if (!isa<vector::ContractionOp>(nestedOp) &&
361  return;
362  if (opToConvert.contains(nestedOp))
363  return;
364  SetVector<Operation *> dependentOps =
365  getSliceContract(nestedOp, backwardSliceOptions, forwardSliceOptions);
366   // If any instruction cannot use the MMA matrix type, drop the whole
367   // chain. MMA matrices are stored in an opaque type, so they cannot be used
368  // by all operations.
369  if (llvm::any_of(dependentOps, [useNvGpu](Operation *op) {
370  if (!supportsMMaMatrixType(op, useNvGpu)) {
371  LDBG() << "cannot convert op: " << *op;
372  return true;
373  }
374  return false;
375  }))
376  return;
377 
378  opToConvert.insert_range(dependentOps);
379  });
380  // Sort the operations so that we can convert them in topological order.
381  return topologicalSort(opToConvert);
382 }
383 
384 namespace {
385 // Transform contract into (m, k)x(k, n)x(m, n) form so that it can be converted
386 // to MMA matmul.
387 struct PrepareContractToGPUMMA
388     : public OpRewritePattern<vector::ContractionOp> {
389   using OpRewritePattern<vector::ContractionOp>::OpRewritePattern;
390 
391  LogicalResult matchAndRewrite(vector::ContractionOp op,
392  PatternRewriter &rewriter) const override {
393  Location loc = op.getLoc();
394  Value lhs = op.getLhs(), rhs = op.getRhs(), res = op.getAcc();
395 
396     // Set up the parallel/reduction structure in the right form.
397  using MapList = ArrayRef<ArrayRef<AffineExpr>>;
398  auto infer = [&](MapList m) {
399  return AffineMap::inferFromExprList(m, op.getContext());
400  };
401  AffineExpr m, n, k;
402  bindDims(rewriter.getContext(), m, n, k);
403  static constexpr std::array<int64_t, 2> perm = {1, 0};
404  auto iteratorTypes = op.getIteratorTypes().getValue();
405  SmallVector<AffineMap, 4> maps = op.getIndexingMapsArray();
406  if (!(vector::isParallelIterator(iteratorTypes[0]) &&
407  vector::isParallelIterator(iteratorTypes[1]) &&
408  vector::isReductionIterator(iteratorTypes[2])))
409  return rewriter.notifyMatchFailure(op, "not a gemm contraction");
410  //
411  // Two outer parallel, one inner reduction (matmat flavor).
412  //
413  // This is the classical row-major matmul, nothing to do.
414  if (maps == infer({{m, k}, {k, n}, {m, n}}))
415  return rewriter.notifyMatchFailure(op, "contraction already prepared");
416  if (maps == infer({{m, k}, {n, k}, {m, n}})) {
417  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
418  } else if (maps == infer({{k, m}, {k, n}, {m, n}})) {
419  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
420  } else if (maps == infer({{k, m}, {n, k}, {m, n}})) {
421  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
422  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
423  } else if (maps == infer({{m, k}, {k, n}, {n, m}})) {
424  std::swap(rhs, lhs);
425  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
426  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
427  } else if (maps == infer({{m, k}, {n, k}, {n, m}})) {
428  std::swap(rhs, lhs);
429  rhs = vector::TransposeOp::create(rewriter, loc, rhs, perm);
430  } else if (maps == infer({{k, m}, {k, n}, {n, m}})) {
431  std::swap(lhs, rhs);
432  lhs = vector::TransposeOp::create(rewriter, loc, lhs, perm);
433  } else if (maps == infer({{k, m}, {n, k}, {n, m}})) {
434  std::swap(lhs, rhs);
435  } else {
436  // TODO: llvm_unreachable ?
437  return rewriter.notifyMatchFailure(op, "unexpected contraction case");
438  }
439  rewriter.replaceOpWithNewOp<vector::ContractionOp>(
440  op, lhs, rhs, res,
441  rewriter.getAffineMapArrayAttr(infer({{m, k}, {k, n}, {m, n}})),
442  op.getIteratorTypes());
443  return success();
444  }
445 };
446 
447 // Fold transpose op into the transfer read op. NVGPU mma.sync op only supports
448 // row-, column-, and row-major layout for matrixA, matrixB, and matrixC,
449 // respectively. We can fold the transpose operation when loading the data from
450 // Shared Memory to registers.
451 struct CombineTransferReadOpTranspose final
452     : public OpRewritePattern<vector::TransposeOp> {
453   using OpRewritePattern<vector::TransposeOp>::OpRewritePattern;
454 
455  LogicalResult matchAndRewrite(vector::TransposeOp op,
456  PatternRewriter &rewriter) const override {
457     // Look through integer and floating-point extend ops.
458  Value source = op.getVector();
459  Type resultType = op.getType();
460  Operation *extOp;
461  if ((extOp = source.getDefiningOp<arith::ExtSIOp>()) ||
462  (extOp = source.getDefiningOp<arith::ExtUIOp>()) ||
463  (extOp = source.getDefiningOp<arith::ExtFOp>())) {
464  source = extOp->getOperand(0);
465  resultType =
466  VectorType::get(cast<VectorType>(resultType).getShape(),
467  cast<VectorType>(source.getType()).getElementType());
468  }
469 
470  auto transferReadOp = source.getDefiningOp<vector::TransferReadOp>();
471  if (!transferReadOp)
472  return rewriter.notifyMatchFailure(op, "no transfer read");
473 
474  // TODO: support 0-d corner case.
475  if (transferReadOp.getTransferRank() == 0)
476  return rewriter.notifyMatchFailure(op, "0-D transfer read");
477 
478  if (transferReadOp.getMask() || transferReadOp.hasOutOfBoundsDim())
479  return rewriter.notifyMatchFailure(op, "not inbounds transfer read");
480 
481  AffineMap permutationMap =
482  AffineMap::getPermutationMap(op.getPermutation(), op.getContext());
483  AffineMap newMap =
484  permutationMap.compose(transferReadOp.getPermutationMap());
485 
486  auto loc = op.getLoc();
487  Value result = vector::TransferReadOp::create(
488  rewriter, loc, resultType, transferReadOp.getBase(),
489  transferReadOp.getIndices(), AffineMapAttr::get(newMap),
490  transferReadOp.getPadding(), transferReadOp.getMask(),
491  transferReadOp.getInBoundsAttr())
492  .getResult();
493 
494     // Fuse through the integer or floating-point extend op.
495  if (extOp) {
496  if (isa<arith::ExtSIOp>(extOp))
497  result = arith::ExtSIOp::create(rewriter, loc, op.getType(), result)
498  .getResult();
499  else if (isa<arith::ExtUIOp>(extOp))
500  result = arith::ExtUIOp::create(rewriter, loc, op.getType(), result)
501  .getResult();
502  else
503  result = arith::ExtFOp::create(rewriter, loc, op.getType(), result)
504  .getResult();
505  }
506 
507  rewriter.replaceOp(op, result);
508  return success();
509  }
510 };
511 
512 } // namespace
513 
514 // MMA types have different layout based on how they are used in matmul ops.
515 // Figure out the right layout to use by looking at op uses.
516 // TODO: Change the GPU dialect to abstract the layout at this level and
517 // only care about it during lowering to NVVM.
518 static const char *inferFragType(Operation *op) {
519  // We can have arith.ext ops before reaching contract ops. See through them
520  // and other kinds of elementwise ops.
521  if (op->hasOneUse()) {
522  Operation *userOp = *op->user_begin();
523  if (userOp->hasTrait<OpTrait::Elementwise>())
524  return inferFragType(userOp);
525  }
526 
527  for (Operation *users : op->getUsers()) {
528  auto contract = dyn_cast<vector::ContractionOp>(users);
529  if (!contract)
530  continue;
531  assert(op->getNumResults() == 1);
532  if (contract.getLhs() == op->getResult(0))
533  return "AOp";
534  if (contract.getRhs() == op->getResult(0))
535  return "BOp";
536  }
537  return "COp";
538 }
539 
540 static LogicalResult
541 convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op,
542  llvm::DenseMap<Value, Value> &valueMapping) {
543  OpBuilder::InsertionGuard g(rewriter);
544  rewriter.setInsertionPoint(op);
545 
546   assert(op.getTransferRank() > 0 && "unexpected 0-d transfer");
547   assert(transferReadSupportsMMAMatrixType(op) &&
548  "expected convertible operation");
549 
550  std::optional<int64_t> stride =
551  getStaticallyKnownRowStride(op.getShapedType());
552  if (!stride.has_value()) {
553  LDBG() << "no stride";
554  return rewriter.notifyMatchFailure(op, "no stride");
555  }
556 
557  AffineMap map = op.getPermutationMap();
558  bool isTranspose = isTransposeMatrixLoadMap(map);
559 
560  // Handle broadcast by setting the stride to 0.
561  if (auto cstExpr = dyn_cast<AffineConstantExpr>(map.getResult(isTranspose))) {
562  assert(cstExpr.getValue() == 0);
563  stride = 0;
564  }
565 
566  Value mappingResult = op.getResult();
567  auto elType = op.getVectorType().getElementType();
568  const char *fragType = inferFragType(op);
569  if (op->hasOneUse()) {
570  auto *user = *op->user_begin();
571  // Infer the signedness of the mma type from the integer extend.
572  if (isa<arith::ExtSIOp, arith::ExtUIOp>(user)) {
573  elType = IntegerType::get(
574  op.getContext(), cast<IntegerType>(elType).getWidth(),
575  isa<arith::ExtSIOp>(user) ? IntegerType::Signed
576  : IntegerType::Unsigned);
577  mappingResult = user->getResult(0);
578  }
579  }
580  gpu::MMAMatrixType type =
581  gpu::MMAMatrixType::get(op.getVectorType().getShape(), elType, fragType);
582  Value load = gpu::SubgroupMmaLoadMatrixOp::create(
583  rewriter, op.getLoc(), type, op.getBase(), op.getIndices(),
584  rewriter.getIndexAttr(*stride),
585  isTranspose ? rewriter.getUnitAttr() : UnitAttr());
586  valueMapping[mappingResult] = load;
587 
588  LDBG() << "transfer read to: " << load;
589  return success();
590 }
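
// For illustration only (indices, leadDimension, and types are placeholders),
// the emitted load has the form:
//
//   %m = gpu.subgroup_mma_load_matrix %src[%i, %j] {leadDimension = 32 : index}
//          : memref<32x32xf16> -> !gpu.mma_matrix<16x16xf16, "AOp">
//
// with the optional `transpose` unit attribute set for transposed maps.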
591 
592 static LogicalResult
593 convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op,
594  llvm::DenseMap<Value, Value> &valueMapping) {
595  OpBuilder::InsertionGuard g(rewriter);
596  rewriter.setInsertionPoint(op);
597 
598   assert(transferWriteSupportsMMAMatrixType(op));
599   std::optional<int64_t> stride =
600  getStaticallyKnownRowStride(op.getShapedType());
601  if (!stride.has_value()) {
602  LDBG() << "no stride";
603  return rewriter.notifyMatchFailure(op, "no stride");
604  }
605 
606  auto it = valueMapping.find(op.getVector());
607  if (it == valueMapping.end()) {
608  LDBG() << "no mapping";
609  return rewriter.notifyMatchFailure(op, "no mapping");
610  }
611 
612  Value matrix = it->second;
613  auto store = gpu::SubgroupMmaStoreMatrixOp::create(
614  rewriter, op.getLoc(), matrix, op.getBase(), op.getIndices(),
615  rewriter.getIndexAttr(*stride), /*transpose=*/UnitAttr());
616  (void)store;
617 
618  LDBG() << "transfer write to: " << store;
619 
620  LDBG() << "erase: " << op;
621  rewriter.eraseOp(op);
622  return success();
623 }
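
// For illustration only (placeholder types and indices), the emitted store has
// the form:
//
//   gpu.subgroup_mma_store_matrix %m, %dst[%i, %j] {leadDimension = 32 : index}
//     : !gpu.mma_matrix<16x16xf16, "COp">, memref<32x32xf16>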
624 
625 /// Returns the vector type which represents a matrix fragment.
626 static VectorType
627 getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo) {
628   SmallVector<int64_t> shape{regInfo.numRegistersPerFragment,
629                              regInfo.elementsPerRegister};
630  Type elType = regInfo.registerLLVMType;
631  if (auto vecType = dyn_cast<VectorType>(elType))
632  elType = vecType.getElementType();
633  return VectorType::get(shape, elType);
634 }
635 
636 /// Convert a 2D splat ConstantOp to the per-lane vector constant used by mma.sync.
637 static LogicalResult
638 convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op,
639  llvm::DenseMap<Value, Value> &valueMapping) {
640  OpBuilder::InsertionGuard g(rewriter);
641  rewriter.setInsertionPoint(op);
642 
643   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
644       nvgpu::getWarpMatrixInfo(op);
645  if (failed(warpMatrixInfo)) {
646  LDBG() << "no warpMatrixInfo";
647  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
648  }
649 
650  FailureOr<nvgpu::FragmentElementInfo> regInfo =
651  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
652  if (failed(regInfo)) {
653  LDBG() << "not mma sync reg info";
654  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
655  }
656 
657  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
658  auto dense = dyn_cast<SplatElementsAttr>(op.getValue());
659  if (!dense) {
660  LDBG() << "not a splat";
661  return rewriter.notifyMatchFailure(op, "not a splat");
662  }
663 
664  Value result = arith::ConstantOp::create(
665  rewriter, op.getLoc(), vectorType,
666  DenseElementsAttr::get(vectorType, dense.getSplatValue<Attribute>()));
667  valueMapping[op.getResult()] = result;
668  return success();
669 }
670 
671 /// Check if the loaded matrix operand requires a transpose.
672 /// Transposed Map Example:
673 /// Example 1 : (..., d0, d1) -> (d1 * 1, d0 * 2)
674 /// Example 2 : (d0, d1, d2, d3) -> (d3, d2)
675 /// The code below checks if the output 2D is transposed using a generalized
676 /// version : (d0, d1, dn, ..., dm, ...) -> (dm, dn)
677 /// Returns true if m > n, false otherwise.
678 static FailureOr<bool> isTransposed(vector::TransferReadOp op) {
679   mlir::AffineMap map = op.getPermutationMap();
680 
681  if (map.getNumResults() != 2) {
682  LDBG() << "Failed because the result of `vector.transfer_read` "
683  "is not a 2d operand";
684  return failure();
685  }
686 
687  // Output 2D matrix dimensions in the order of d0, d1.
688  mlir::AffineExpr dM = map.getResult(0);
689  mlir::AffineExpr dN = map.getResult(1);
690 
691  // Find the position of these expressions in the input.
692  auto exprM = dyn_cast<AffineDimExpr>(dM);
693  auto exprN = dyn_cast<AffineDimExpr>(dN);
694 
695  if (!exprM || !exprN) {
696     LDBG() << "Failed because expressions are not affine dim "
697                "expressions, so the transpose cannot be determined.";
698  return failure();
699  }
700 
701  return exprM.getPosition() > exprN.getPosition();
702 }
703 
704 static LogicalResult
705 creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op,
706  llvm::DenseMap<Value, Value> &valueMapping) {
707  OpBuilder::InsertionGuard g(rewriter);
708  rewriter.setInsertionPoint(op);
709  Location loc = op->getLoc();
710 
711   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
712       nvgpu::getWarpMatrixInfo(op);
713  if (failed(warpMatrixInfo)) {
714  LDBG() << "no warpMatrixInfo";
715  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
716  }
717 
718  FailureOr<nvgpu::FragmentElementInfo> regInfo =
719  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
720  if (failed(regInfo)) {
721  LDBG() << "not mma sync reg info";
722  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
723  }
724 
725  FailureOr<bool> transpose = isTransposed(op);
726  if (failed(transpose)) {
727  LDBG() << "failed to determine the transpose";
728  return rewriter.notifyMatchFailure(
729  op, "Op should likely not be converted to a nvgpu.ldmatrix call.");
730  }
731 
732  FailureOr<nvgpu::LdMatrixParams> params =
733  nvgpu::getLdMatrixParams(*warpMatrixInfo, *transpose);
734 
735  if (failed(params)) {
736  LDBG() << "failed to convert vector.transfer_read to ldmatrix. "
737  << "Op should likely not be converted to a nvgpu.ldmatrix call.";
738  return rewriter.notifyMatchFailure(
739  op, "failed to convert vector.transfer_read to ldmatrix; this op "
740  "likely should not be converted to a nvgpu.ldmatrix call.");
741  }
742 
743  // Adjust the load offset.
744  auto laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
745  FailureOr<AffineMap> offsets =
746  nvgpu::getLaneIdToLdMatrixMatrixCoord(rewriter, loc, *params);
747  if (failed(offsets)) {
748  LDBG() << "no offsets";
749  return rewriter.notifyMatchFailure(op, "no offsets");
750  }
751 
752  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
753 
754  SmallVector<Value, 4> indices;
755  getXferIndices<vector::TransferReadOp>(rewriter, op, *offsets, {laneId},
756  indices);
757 
758  nvgpu::LdMatrixOp newOp =
759  nvgpu::LdMatrixOp::create(rewriter, loc, vectorType, op.getBase(),
760  indices, *transpose, params->numTiles);
761  valueMapping[op] = newOp->getResult(0);
762  return success();
763 }
764 
765 static LogicalResult
766 createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op,
767  llvm::DenseMap<Value, Value> &valueMapping) {
768  OpBuilder::InsertionGuard g(rewriter);
769  rewriter.setInsertionPoint(op);
770 
771  Location loc = op.getLoc();
772   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
773       nvgpu::getWarpMatrixInfo(op);
774  if (failed(warpMatrixInfo))
775  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
776  FailureOr<nvgpu::FragmentElementInfo> regInfo =
777  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
778  if (failed(regInfo)) {
779  return rewriter.notifyMatchFailure(
780  op, "Failed to deduce register fragment type during "
781  "conversion to distributed non-ldmatrix compatible load");
782  }
783 
784  Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
785 
786  // This is the individual element type.
787  Type loadedElType = regInfo->registerLLVMType;
788  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
789 
790  Value fill = arith::ConstantOp::create(
791  rewriter, op.getLoc(), vectorType.getElementType(),
792  rewriter.getZeroAttr(vectorType.getElementType()));
793  Value result =
794  vector::BroadcastOp::create(rewriter, op.getLoc(), vectorType, fill);
795 
796  bool isTransposeLoad = !op.getPermutationMap().isMinorIdentity();
797 
798  // If we are not transposing, then we can use vectorized loads. Otherwise, we
799  // must load each element individually.
800  if (!isTransposeLoad) {
801  if (!isa<VectorType>(loadedElType)) {
802  loadedElType = VectorType::get({1}, loadedElType);
803  }
804 
805  for (int i = 0; i < vectorType.getShape()[0]; i++) {
806  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
807  rewriter, op.getLoc(), *warpMatrixInfo);
808  if (failed(coords))
809  return rewriter.notifyMatchFailure(op, "no coords");
810 
811  Value logicalValueId = arith::ConstantOp::create(
812  rewriter, loc, rewriter.getIndexType(),
813  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
814  SmallVector<Value, 4> newIndices;
815  getXferIndices<vector::TransferReadOp>(
816  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
817 
818  Value el = vector::LoadOp::create(rewriter, loc, loadedElType,
819  op.getBase(), newIndices);
820  result = vector::InsertOp::create(rewriter, loc, el, result, i);
821  }
822  } else {
823  if (auto vecType = dyn_cast<VectorType>(loadedElType)) {
824  loadedElType = vecType.getElementType();
825  }
826  for (int i = 0; i < vectorType.getShape()[0]; i++) {
827  for (unsigned innerIdx = 0; innerIdx < vectorType.getShape()[1];
828  innerIdx++) {
829 
830  Value logicalValueId = arith::ConstantOp::create(
831  rewriter, loc, rewriter.getIndexType(),
832  rewriter.getIndexAttr(i * regInfo->elementsPerRegister + innerIdx));
833  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
834  rewriter, op.getLoc(), *warpMatrixInfo);
835  if (failed(coords))
836  return rewriter.notifyMatchFailure(op, "no coords");
837 
838  SmallVector<Value, 4> newIndices;
839  getXferIndices<vector::TransferReadOp>(
840  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
841  Value el = memref::LoadOp::create(rewriter, op.getLoc(), loadedElType,
842  op.getBase(), newIndices);
843  result = vector::InsertOp::create(rewriter, op.getLoc(), el, result,
844  ArrayRef<int64_t>{i, innerIdx});
845  }
846  }
847  }
848 
849  valueMapping[op.getResult()] = result;
850  return success();
851 }
852 
853 /// Return true if this is a shared memory memref type.
854 static bool isSharedMemory(MemRefType type) {
855  auto addressSpace =
856  dyn_cast_or_null<gpu::AddressSpaceAttr>(type.getMemorySpace());
857  return addressSpace &&
858  addressSpace.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
859 }
860 
861 /// Converts a `vector.transfer_read` operation directly to either a
862 /// `vector.load` or a `nvgpu.ldmatrix` operation. This function should only be
863 /// used when converting to `nvgpu.mma.sync` operations.
864 static LogicalResult
865 convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op,
866  llvm::DenseMap<Value, Value> &valueMapping) {
867  OpBuilder::InsertionGuard g(rewriter);
868  rewriter.setInsertionPoint(op);
869 
870   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
871       nvgpu::getWarpMatrixInfo(op);
872  if (failed(warpMatrixInfo))
873  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
874 
875  bool isLdMatrixCompatible =
876  isSharedMemory(cast<MemRefType>(op.getBase().getType())) &&
877  nvgpu::inferTileWidthInBits(*warpMatrixInfo) == 128;
878 
879  VectorType vecTy = op.getVectorType();
880  int64_t bitWidth = vecTy.getElementType().getIntOrFloatBitWidth();
881 
882  // When we are transposing the B operand, ldmatrix will only work if we have
883  // at least 8 rows to read and the width to read for the transpose is 128
884  // bits.
885  if (!op.getPermutationMap().isMinorIdentity() &&
886  (bitWidth != 16 || vecTy.getDimSize(1) < 8 ||
887  vecTy.getDimSize(0) * bitWidth < 128))
888  isLdMatrixCompatible = false;
889 
890  if (!isLdMatrixCompatible)
891  return createNonLdMatrixLoads(rewriter, op, valueMapping);
892 
893  return creatLdMatrixCompatibleLoads(rewriter, op, valueMapping);
894 }
895 
896 static LogicalResult
897 convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op,
898  llvm::DenseMap<Value, Value> &valueMapping) {
899  OpBuilder::InsertionGuard g(rewriter);
900  rewriter.setInsertionPoint(op);
901 
902  Location loc = op->getLoc();
903  auto it = valueMapping.find(op.getVector());
904  if (it == valueMapping.end())
905  return rewriter.notifyMatchFailure(op, "no mapping");
906  Value matrix = it->second;
907 
908   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
909       nvgpu::getWarpMatrixInfo(op);
910  if (failed(warpMatrixInfo))
911  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
912  FailureOr<nvgpu::FragmentElementInfo> regInfo =
913  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
914  if (failed(regInfo))
915  return rewriter.notifyMatchFailure(op, "not mma sync reg info");
916 
917  VectorType vectorType = getMmaSyncVectorOperandType(*regInfo);
918  Value laneId = gpu::LaneIdOp::create(rewriter, loc, /*upperBound=*/nullptr);
919 
920  for (unsigned i = 0; i < vectorType.getShape()[0]; i++) {
921  Value logicalValueId = arith::ConstantOp::create(
922  rewriter, loc, rewriter.getIndexType(),
923  rewriter.getIndexAttr(i * regInfo->elementsPerRegister));
924  FailureOr<AffineMap> coords = nvgpu::getLaneIdAndValueIdToOperandCoord(
925  rewriter, op.getLoc(), *warpMatrixInfo);
926  if (failed(coords))
927  return rewriter.notifyMatchFailure(op, "no coords");
928 
929  Value el =
930  vector::ExtractOp::create(rewriter, loc, matrix, ArrayRef<int64_t>{i});
931  SmallVector<Value, 4> newIndices;
932  getXferIndices<vector::TransferWriteOp>(
933  rewriter, op, *coords, {laneId, logicalValueId}, newIndices);
934  vector::StoreOp::create(rewriter, loc, el, op.getBase(), newIndices);
935  }
936 
937  LDBG() << "erase: " << op;
938  rewriter.eraseOp(op);
939  return success();
940 }
941 
942 static void populateFromInt64AttrArray(ArrayAttr arrayAttr,
943  SmallVectorImpl<int64_t> &results) {
944  for (auto attr : arrayAttr)
945  results.push_back(cast<IntegerAttr>(attr).getInt());
946 }
947 
948 static LogicalResult
949 convertExtractStridedSlice(RewriterBase &rewriter,
950  vector::ExtractStridedSliceOp op,
951  llvm::DenseMap<Value, Value> &valueMapping) {
952  OpBuilder::InsertionGuard g(rewriter);
953  rewriter.setInsertionPoint(op);
954 
955  Location loc = op->getLoc();
956 
957   FailureOr<nvgpu::WarpMatrixInfo> warpMatrixInfo =
958       nvgpu::getWarpMatrixInfo(op);
959  if (failed(warpMatrixInfo))
960  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
961 
962  FailureOr<nvgpu::FragmentElementInfo> mmaSyncFragmentInfo =
963  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
964  if (failed(mmaSyncFragmentInfo))
965  return rewriter.notifyMatchFailure(op, "no mmaSyncFragmentInfo");
966 
967   // Find the vector.transfer_read whose result vector is being sliced.
968  auto transferReadOp = op.getSource().getDefiningOp<vector::TransferReadOp>();
969  if (!transferReadOp)
970  return rewriter.notifyMatchFailure(op, "no transfer read");
971 
972  warpMatrixInfo = nvgpu::getWarpMatrixInfo(transferReadOp);
973  if (failed(warpMatrixInfo))
974  return rewriter.notifyMatchFailure(op, "no warpMatrixInfo");
975 
976  FailureOr<nvgpu::FragmentElementInfo> ldFragmentInfo =
977  nvgpu::getMmaSyncRegisterType(*warpMatrixInfo);
978  if (failed(ldFragmentInfo))
979  return rewriter.notifyMatchFailure(op, "no ldFragmentInfo");
980 
981  assert(
982  (mmaSyncFragmentInfo->elementsPerRegister ==
983  ldFragmentInfo->elementsPerRegister) &&
984  "Number of elements per register should be same for load and mma.sync");
985 
986  // Create vector.extract_strided_slice op for thread-owned fragments.
987  std::array<int64_t, 2> strides = {1,
988  1}; // stride for extract slice is always 1.
989  std::array<int64_t, 2> sliceShape = {
990  mmaSyncFragmentInfo->numRegistersPerFragment,
991  mmaSyncFragmentInfo->elementsPerRegister};
992  auto it = valueMapping.find(transferReadOp);
993  if (it == valueMapping.end())
994  return rewriter.notifyMatchFailure(op, "no mapping");
995  auto sourceVector = it->second;
996 
997   // Offsets and sizes at the warp level of ownership.
998  SmallVector<int64_t> offsets;
999  populateFromInt64AttrArray(op.getOffsets(), offsets);
1000 
1001  SmallVector<int64_t> sizes;
1002  populateFromInt64AttrArray(op.getSizes(), sizes);
1003  ArrayRef<int64_t> warpVectorShape = op.getSourceVectorType().getShape();
1004 
1005  // Compute offset in vector registers. Note that the mma.sync vector registers
1006   // are shaped as numberOfFragments x numberOfRegistersPerFragment. The vector
1007  // registers can only be sliced along numberOfFragments, i.e., sliceOffset[0].
1008  std::array<int64_t, 2> sliceOffset = {0, 0};
1009 
1010  if (offsets[0] && offsets[1])
1011  return op->emitError() << "Slicing fragments in 2D is not supported. ";
1012  if (offsets[0])
1013  sliceOffset[0] = (warpVectorShape[0] / offsets[0]);
1014  else if (offsets[1])
1015  sliceOffset[0] = (warpVectorShape[1] / offsets[1]);
1016 
1017  Value newOp = vector::ExtractStridedSliceOp::create(
1018  rewriter, loc, sourceVector, sliceOffset, sliceShape, strides);
1019 
1020  valueMapping[op] = newOp;
1021  return success();
1022 }
1023 
1024 static LogicalResult
1025 convertContractOp(RewriterBase &rewriter, vector::ContractionOp op,
1026  llvm::DenseMap<Value, Value> &valueMapping) {
1027  OpBuilder::InsertionGuard g(rewriter);
1028  rewriter.setInsertionPoint(op);
1029 
1030  auto itA = valueMapping.find(op.getLhs());
1031  auto itB = valueMapping.find(op.getRhs());
1032  auto itC = valueMapping.find(op.getAcc());
1033  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1034  itC == valueMapping.end())
1035  return rewriter.notifyMatchFailure(op, "no mapping");
1036  Value opA = itA->second, opB = itB->second, opC = itC->second;
1037  Value matmul = gpu::SubgroupMmaComputeOp::create(rewriter, op.getLoc(),
1038  opC.getType(), opA, opB, opC,
1039  /*a_transpose=*/UnitAttr(),
1040  /*b_transpose=*/UnitAttr());
1041  valueMapping[op.getResult()] = matmul;
1042  return success();
1043 }
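
// For illustration only (placeholder shapes), the contraction above becomes:
//
//   %d = gpu.subgroup_mma_compute %a, %b, %c
//          : !gpu.mma_matrix<16x16xf16, "AOp">, !gpu.mma_matrix<16x16xf16, "BOp">
//          -> !gpu.mma_matrix<16x16xf16, "COp">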
1044 
1045 static LogicalResult
1046 convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op,
1047  llvm::DenseMap<Value, Value> &valueMapping) {
1048  OpBuilder::InsertionGuard g(rewriter);
1049  rewriter.setInsertionPoint(op);
1050 
1051  auto itA = valueMapping.find(op.getLhs());
1052  auto itB = valueMapping.find(op.getRhs());
1053  auto itC = valueMapping.find(op.getAcc());
1054  if (itA == valueMapping.end() || itB == valueMapping.end() ||
1055  itC == valueMapping.end())
1056  return rewriter.notifyMatchFailure(op, "no mapping");
1057  Value opA = itA->second, opB = itB->second, opC = itC->second;
1058  int64_t m = cast<VectorType>(op.getLhs().getType()).getShape()[0];
1059  int64_t n = cast<VectorType>(op.getRhs().getType()).getShape()[0];
1060  int64_t k = cast<VectorType>(op.getLhs().getType()).getShape()[1];
1061  Value matmul = nvgpu::MmaSyncOp::create(rewriter, op.getLoc(), opA, opB, opC,
1062  rewriter.getI64ArrayAttr({m, n, k}));
1063  valueMapping[op.getResult()] = matmul;
1064  return success();
1065 }
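
// For illustration only, using the 16x8x16 f16 tile from the NVGPU dialect
// documentation (the distributed operand shapes are placeholders here):
//
//   %d = nvgpu.mma.sync (%a, %b, %c) {mmaShape = [16, 8, 16]}
//          : (vector<4x2xf16>, vector<2x2xf16>, vector<2x2xf16>) -> vector<2x2xf16>
//
// where [m, n, k] is derived from the per-warp operand shapes as above.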
1066 
1067 /// Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
1068 static LogicalResult
1069 convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op,
1070  llvm::DenseMap<Value, Value> &valueMapping) {
1071  OpBuilder::InsertionGuard g(rewriter);
1072  rewriter.setInsertionPoint(op);
1073 
1074  assert(constantSupportsMMAMatrixType(op));
1075 
1076  auto splat =
1077  cast<SplatElementsAttr>(op.getValue()).getSplatValue<TypedAttr>();
1078  auto scalarConstant =
1079  arith::ConstantOp::create(rewriter, op.getLoc(), splat.getType(), splat);
1080  const char *fragType = inferFragType(op);
1081   auto vecType = cast<VectorType>(op.getType());
1082   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1083  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1084  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1085  type, scalarConstant);
1086  valueMapping[op.getResult()] = matrix;
1087  return success();
1088 }
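
// For illustration only (placeholder type), a 2D splat constant becomes:
//
//   %m = gpu.subgroup_mma_constant_matrix %cst : !gpu.mma_matrix<16x16xf16, "COp">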
1089 
1090 /// Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
1091 static LogicalResult
1092 convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op,
1093  llvm::DenseMap<Value, Value> &valueMapping) {
1094  OpBuilder::InsertionGuard g(rewriter);
1095  rewriter.setInsertionPoint(op);
1096 
1097  assert(broadcastSupportsMMAMatrixType(op));
1098 
1099  const char *fragType = inferFragType(op);
1100   auto vecType = op.getResultVectorType();
1101   gpu::MMAMatrixType type = gpu::MMAMatrixType::get(
1102  vecType.getShape(), vecType.getElementType(), llvm::StringRef(fragType));
1103  auto matrix = gpu::SubgroupMmaConstantMatrixOp::create(rewriter, op.getLoc(),
1104  type, op.getSource());
1105  valueMapping[op.getResult()] = matrix;
1106  return success();
1107 }
1108 
1109 // Replace ForOp with a new ForOp with extra operands. The YieldOp is not
1110 // updated and needs to be updated separately for the loop to be correct.
1111 static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter,
1112  scf::ForOp loop,
1113  ValueRange newInitArgs) {
1114  OpBuilder::InsertionGuard g(rewriter);
1115  rewriter.setInsertionPoint(loop);
1116 
1117  // Create a new loop before the existing one, with the extra operands.
1118  rewriter.setInsertionPoint(loop);
1119  auto operands = llvm::to_vector<4>(loop.getInitArgs());
1120  llvm::append_range(operands, newInitArgs);
1121  scf::ForOp newLoop =
1122  scf::ForOp::create(rewriter, loop.getLoc(), loop.getLowerBound(),
1123  loop.getUpperBound(), loop.getStep(), operands);
1124  rewriter.eraseBlock(newLoop.getBody());
1125 
1126  newLoop.getRegion().getBlocks().splice(
1127  newLoop.getRegion().getBlocks().begin(), loop.getRegion().getBlocks());
1128  for (Value operand : newInitArgs)
1129  newLoop.getBody()->addArgument(operand.getType(), operand.getLoc());
1130 
1131  for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1132  loop.getNumResults())))
1133  rewriter.replaceAllUsesWith(std::get<0>(it), std::get<1>(it));
1134 
1135  LDBG() << "newLoop now: " << newLoop;
1136  LDBG() << "stripped scf.for: " << loop;
1137  LDBG() << "erase: " << loop;
1138 
1139  rewriter.eraseOp(loop);
1140  return newLoop;
1141 }
1142 
1143 static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op,
1144  llvm::DenseMap<Value, Value> &valueMapping) {
1145  OpBuilder::InsertionGuard g(rewriter);
1146  rewriter.setInsertionPoint(op);
1147 
1148   SmallVector<Value> newOperands;
1149   SmallVector<std::pair<size_t, size_t>> argMapping;
1150  for (const auto &operand : llvm::enumerate(op.getInitArgs())) {
1151  auto it = valueMapping.find(operand.value());
1152  if (it == valueMapping.end()) {
1153  LDBG() << "no value mapping for: " << operand.value();
1154  continue;
1155  }
1156  argMapping.push_back(std::make_pair(
1157  operand.index(), op.getInitArgs().size() + newOperands.size()));
1158  newOperands.push_back(it->second);
1159  }
1160 
1161  scf::ForOp newForOp = replaceForOpWithNewSignature(rewriter, op, newOperands);
1162  Block &loopBody = *newForOp.getBody();
1163  for (auto mapping : argMapping) {
1164  valueMapping[newForOp.getResult(mapping.first)] =
1165  newForOp.getResult(mapping.second);
1166  valueMapping[loopBody.getArgument(mapping.first +
1167  newForOp.getNumInductionVars())] =
1168  loopBody.getArgument(mapping.second + newForOp.getNumInductionVars());
1169  }
1170 
1171  LDBG() << "scf.for to: " << newForOp;
1172  return success();
1173 }
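
// Sketch of the effect (types are placeholders): a loop iter_arg of type
// vector<16x16xf16> whose init value has a mapped MMA counterpart gets a second
// iter_arg of type !gpu.mma_matrix<16x16xf16, "COp"> appended; the original
// argument and result are kept and become dead once convertYieldOp rewires the
// terminator.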
1174 
1175 static LogicalResult
1176 convertYieldOp(RewriterBase &rewriter, scf::YieldOp op,
1177  llvm::DenseMap<Value, Value> &valueMapping) {
1178  OpBuilder::InsertionGuard g(rewriter);
1179  rewriter.setInsertionPoint(op);
1180 
1181  auto loop = cast<scf::ForOp>(op->getParentOp());
1182  auto yieldOperands = llvm::to_vector<4>(op.getOperands());
1183  for (const auto &operand : llvm::enumerate(op.getOperands())) {
1184  auto it = valueMapping.find(operand.value());
1185  if (it == valueMapping.end())
1186  continue;
1187  // Replace the yield of old value with the for op argument to make it easier
1188  // to remove the dead code.
1189  yieldOperands[operand.index()] = loop.getInitArgs()[operand.index()];
1190  yieldOperands.push_back(it->second);
1191  }
1192  scf::YieldOp::create(rewriter, op.getLoc(), yieldOperands);
1193 
1194  LDBG() << "erase: " << op;
1195  rewriter.eraseOp(op);
1196  return success();
1197 }
1198 
1199 /// Convert an elementwise op to the equivalent elementwise op on MMA matrix.
1200 static LogicalResult
1201 convertElementwiseOp(RewriterBase &rewriter, Operation *op,
1202  gpu::MMAElementwiseOp opType,
1203  llvm::DenseMap<Value, Value> &valueMapping) {
1204  OpBuilder::InsertionGuard g(rewriter);
1205  rewriter.setInsertionPoint(op);
1206 
1207  SmallVector<Value> matrixOperands;
1208  for (Value operand : op->getOperands()) {
1209  auto it = valueMapping.find(operand);
1210  if (it == valueMapping.end())
1211  return rewriter.notifyMatchFailure(op, "no mapping");
1212  matrixOperands.push_back(it->second);
1213  }
1214  auto resultType = cast<gpu::MMAMatrixType>(matrixOperands[0].getType());
1215  if (opType == gpu::MMAElementwiseOp::EXTF) {
1216  // The floating point extension case has a different result type.
1217  auto vectorType = cast<VectorType>(op->getResultTypes()[0]);
1218  resultType = gpu::MMAMatrixType::get(resultType.getShape(),
1219  vectorType.getElementType(),
1220  resultType.getOperand());
1221  }
1222 
1223  Value newOp = gpu::SubgroupMmaElementwiseOp::create(
1224  rewriter, op->getLoc(), resultType, matrixOperands, opType);
1225  valueMapping[op->getResult(0)] = newOp;
1226  return success();
1227 }
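
// For illustration only (placeholder types), an arith.addf on mapped operands
// becomes:
//
//   %r = gpu.subgroup_mma_elementwise addf %a, %b
//          : (!gpu.mma_matrix<16x16xf16, "COp">, !gpu.mma_matrix<16x16xf16, "COp">)
//          -> !gpu.mma_matrix<16x16xf16, "COp">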
1228 
1229 void mlir::populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns,
1230                                               bool useNvGpu) {
1231  if (!useNvGpu) {
1232  patterns.add<PrepareContractToGPUMMA, CombineTransferReadOpTranspose>(
1233  patterns.getContext());
1234  return;
1235  }
1236   vector::populateVectorContractCanonicalizeMatmulToMMT(patterns);
1237   patterns.add<CombineTransferReadOpTranspose>(patterns.getContext());
1238 }
1239 
1240 LogicalResult mlir::convertVectorToMMAOps(RewriterBase &rewriter,
1241                                           Operation *rootOp) {
1242  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/false);
1243  llvm::DenseMap<Value, Value> valueMapping;
1244 
1245  auto globalRes = LogicalResult::success();
1246  for (Operation *op : ops) {
1247  LDBG() << "Process op: " << *op;
1248  // Apparently callers do not want to early exit on failure here.
1249  auto res = LogicalResult::success();
1250  if (auto transferRead = dyn_cast<vector::TransferReadOp>(op)) {
1251  res = convertTransferReadOp(rewriter, transferRead, valueMapping);
1252  } else if (auto transferWrite = dyn_cast<vector::TransferWriteOp>(op)) {
1253  res = convertTransferWriteOp(rewriter, transferWrite, valueMapping);
1254  } else if (auto contractOp = dyn_cast<vector::ContractionOp>(op)) {
1255  res = convertContractOp(rewriter, contractOp, valueMapping);
1256  } else if (auto constantOp = dyn_cast<arith::ConstantOp>(op)) {
1257  res = convertConstantOp(rewriter, constantOp, valueMapping);
1258  } else if (auto broadcastOp = dyn_cast<vector::BroadcastOp>(op)) {
1259  res = convertBroadcastOp(rewriter, broadcastOp, valueMapping);
1260  } else if (auto forOp = dyn_cast<scf::ForOp>(op)) {
1261  res = convertForOp(rewriter, forOp, valueMapping);
1262  } else if (auto yieldOp = dyn_cast<scf::YieldOp>(op)) {
1263  res = convertYieldOp(rewriter, yieldOp, valueMapping);
1264  } else if (auto elementwiseType = convertElementwiseOpToMMA(op)) {
1265  res = convertElementwiseOp(rewriter, op, *elementwiseType, valueMapping);
1266  }
1267  if (failed(res))
1268  globalRes = failure();
1269  }
1270  return globalRes;
1271 }
1272 
1273 LogicalResult mlir::convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter,
1274                                                           Operation *rootOp) {
1275  SetVector<Operation *> ops = getOpToConvert(rootOp, /*useNvGpu=*/true);
1276  llvm::DenseMap<Value, Value> valueMapping;
1277   for (Operation *op : ops) {
1278     if (llvm::TypeSwitch<Operation *, LogicalResult>(op)
1279  .Case([&](vector::TransferReadOp transferReadOp) {
1280  return convertTransferReadToLoads(rewriter, transferReadOp,
1281  valueMapping);
1282  })
1283  .Case([&](vector::TransferWriteOp transferWriteOp) {
1284  return convertTransferWriteToStores(rewriter, transferWriteOp,
1285  valueMapping);
1286  })
1287  .Case([&](vector::ExtractStridedSliceOp extractStridedSliceOp) {
1288  return convertExtractStridedSlice(rewriter, extractStridedSliceOp,
1289  valueMapping);
1290  })
1291  .Case([&](vector::ContractionOp contractionOp) {
1292  return convertContractOpToMmaSync(rewriter, contractionOp,
1293  valueMapping);
1294  })
1295  .Case([&](scf::ForOp forOp) {
1296  return convertForOp(rewriter, forOp, valueMapping);
1297  })
1298  .Case([&](scf::YieldOp yieldOp) {
1299  return convertYieldOp(rewriter, yieldOp, valueMapping);
1300  })
1301  .Case([&](arith::ConstantOp constOp) {
1302  return convertConstantOpMmaSync(rewriter, constOp, valueMapping);
1303  })
1304  .Default([&](Operation *op) {
1305  return op->emitError() << "unhandled vector to mma type: " << *op;
1306  })
1307  .failed()) {
1308  return op->emitOpError()
1309  << "failed to convert op during vector-to-nvgpu conversion";
1310  }
1311  }
1312  return success();
1313 }
1314 
1315 namespace {
1316 
1317 struct ConvertVectorToGPUPass
1318  : public impl::ConvertVectorToGPUBase<ConvertVectorToGPUPass> {
1319 
1320  explicit ConvertVectorToGPUPass(bool useNvGpu_) {
1321  useNvGpu.setValue(useNvGpu_);
1322  }
1323 
1324  void runOnOperation() override {
1325     RewritePatternSet patterns(&getContext());
1326     populatePrepareVectorToMMAPatterns(patterns, useNvGpu.getValue());
1327  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns))))
1328  return signalPassFailure();
1329 
1330  IRRewriter rewriter(&getContext());
1331  if (useNvGpu) {
1332  if (failed(
1333  convertVectorToNVVMCompatibleMMASync(rewriter, getOperation())))
1334  return signalPassFailure();
1335  return;
1336  }
1337  (void)convertVectorToMMAOps(rewriter, getOperation());
1338  }
1339 };
1340 
1341 } // namespace
1342 
1343 std::unique_ptr<Pass> mlir::createConvertVectorToGPUPass(bool useNvGpu) {
1344  return std::make_unique<ConvertVectorToGPUPass>(useNvGpu);
1345 }
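
// Note: from the command line this pass is typically exercised as
// `mlir-opt --convert-vector-to-gpu`; the exact pass argument and the
// `use-nvgpu` option name are defined in the conversion Passes.td.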
static MLIRContext * getContext(OpFoldResult val)
#define SUBI(lhs, rhs)
Definition: LoopEmitter.cpp:35
#define MULI(lhs, rhs)
Definition: LoopEmitter.cpp:36
#define ADDI(lhs, rhs)
Definition: LoopEmitter.cpp:33
static void contract(RootOrderingGraph &graph, ArrayRef< Value > cycle, const DenseMap< Value, unsigned > &parentDepths, DenseMap< Value, Value > &actualSource, DenseMap< Value, Value > &actualTarget)
Contracts the specified cycle in the given graph in-place.
static Value broadcast(Location loc, Value toBroadcast, unsigned numElements, const TypeConverter &typeConverter, ConversionPatternRewriter &rewriter)
Broadcasts the value to vector with numElements number of elements.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition: Traits.cpp:117
static LogicalResult convertTransferWriteOp(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertForOp(RewriterBase &rewriter, scf::ForOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOp(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertContractOpToMmaSync(RewriterBase &rewriter, vector::ContractionOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool isSharedMemory(MemRefType type)
Return true if this is a shared memory memref type.
static VectorType getMmaSyncVectorOperandType(const nvgpu::FragmentElementInfo &regInfo)
Returns the vector type which represents a matrix fragment.
static const char * inferFragType(Operation *op)
static bool fpExtendSupportsMMAMatrixType(arith::ExtFOp extOp)
static bool supportsMMaMatrixType(Operation *op, bool useNvGpu)
static bool constantSupportsMMAMatrixType(arith::ConstantOp constantOp)
Return true if the constant is a splat to a 2D vector so that it can be converted to a MMA constant m...
static bool contractSupportsMMAMatrixType(vector::ContractionOp contract, bool useNvGpu)
Definition: VectorToGPU.cpp:72
static void populateFromInt64AttrArray(ArrayAttr arrayAttr, SmallVectorImpl< int64_t > &results)
static bool isTransposeMatrixLoadMap(AffineMap permutationMap)
static SetVector< Operation * > getSliceContract(Operation *op, const BackwardSliceOptions &backwardSliceOptions, const ForwardSliceOptions &forwardSliceOptions)
Return an unsorted slice handling scf.for region differently than getSlice.
static LogicalResult convertTransferReadOp(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool integerExtendSupportsMMAMatrixType(ExtOpTy extOp)
Return true if this integer extend op can be folded into a contract op.
static LogicalResult convertTransferReadToLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Converts a vector.transfer_read operation directly to either a vector.load or a nvgpu....
static LogicalResult convertBroadcastOp(RewriterBase &rewriter, vector::BroadcastOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a vector.broadcast from scalar to a SubgroupMmaConstantMatrix op.
static LogicalResult creatLdMatrixCompatibleLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
static FailureOr< bool > isTransposed(vector::TransferReadOp op)
Check if the loaded matrix operand requires transposed.
static scf::ForOp replaceForOpWithNewSignature(RewriterBase &rewriter, scf::ForOp loop, ValueRange newInitArgs)
static bool transferWriteSupportsMMAMatrixType(vector::TransferWriteOp writeOp)
static std::optional< gpu::MMAElementwiseOp > convertElementwiseOpToMMA(Operation *op)
Return the MMA elementwise enum associated with op if it is supported.
static LogicalResult convertElementwiseOp(RewriterBase &rewriter, Operation *op, gpu::MMAElementwiseOp opType, llvm::DenseMap< Value, Value > &valueMapping)
Convert an elementwise op to the equivalent elementwise op on MMA matrix.
static bool transferReadSupportsMMAMatrixType(vector::TransferReadOp readOp)
static bool extractStridedSliceSupportsMMAMatrixType(vector::ExtractStridedSliceOp op)
Returns true if the extract strided slice op is supported with mma.sync path.
static LogicalResult convertConstantOpMmaSync(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static bool elementwiseSupportsMMAMatrixType(Operation *op)
Return true if the op is supported as elementwise op on MMAMatrix type.
static SetVector< Operation * > getOpToConvert(mlir::Operation *op, bool useNvGpu)
static LogicalResult convertConstantOp(RewriterBase &rewriter, arith::ConstantOp op, llvm::DenseMap< Value, Value > &valueMapping)
Convert a 2D splat ConstantOp to a SubgroupMmaConstantMatrix op.
static LogicalResult convertYieldOp(RewriterBase &rewriter, scf::YieldOp op, llvm::DenseMap< Value, Value > &valueMapping)
static LogicalResult convertTransferWriteToStores(RewriterBase &rewriter, vector::TransferWriteOp op, llvm::DenseMap< Value, Value > &valueMapping)
static std::optional< int64_t > getStaticallyKnownRowStride(ShapedType type)
static void getXferIndices(RewriterBase &rewriter, TransferOpType xferOp, AffineMap offsetMap, ArrayRef< Value > dimValues, SmallVector< Value, 4 > &indices)
For a vector TransferOpType xferOp, an empty indices vector, and an AffineMap representing offsets to...
Definition: VectorToGPU.cpp:52
static LogicalResult convertExtractStridedSlice(RewriterBase &rewriter, vector::ExtractStridedSliceOp op, llvm::DenseMap< Value, Value > &valueMapping)
static bool broadcastSupportsMMAMatrixType(vector::BroadcastOp broadcastOp)
Return true if this is a broadcast from scalar to a 2D vector.
static LogicalResult createNonLdMatrixLoads(RewriterBase &rewriter, vector::TransferReadOp op, llvm::DenseMap< Value, Value > &valueMapping)
Base type for affine expression.
Definition: AffineExpr.h:68
A multi-dimensional affine map. Affine maps are immutable like Types, and they are uniqued.
Definition: AffineMap.h:46
MLIRContext * getContext() const
Definition: AffineMap.cpp:339
bool isMinorIdentity() const
Returns true if this affine map is a minor identity, i.e.
Definition: AffineMap.cpp:151
static AffineMap get(MLIRContext *context)
Returns a zero result affine map with no dimensions or symbols: () -> ().
unsigned getNumDims() const
Definition: AffineMap.cpp:390
unsigned getNumResults() const
Definition: AffineMap.cpp:398
AffineExpr getResult(unsigned idx) const
Definition: AffineMap.cpp:407
static AffineMap getPermutationMap(ArrayRef< unsigned > permutation, MLIRContext *context)
Returns an AffineMap representing a permutation.
Definition: AffineMap.cpp:260
AffineMap compose(AffineMap map) const
Returns the AffineMap resulting from composing this with map.
Definition: AffineMap.cpp:552
static SmallVector< AffineMap, 4 > inferFromExprList(ArrayRef< ArrayRef< AffineExpr >> exprsList, MLIRContext *context)
Returns a vector of AffineMaps; each with as many results as exprs.size(), as many dims as the larges...
Definition: AffineMap.cpp:308
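For illustration, a minimal sketch (not taken from this file; `ctx` is an assumed MLIRContext*) of building a transpose map through inferFromExprList:
AffineExpr d0 = getAffineDimExpr(0, ctx);
AffineExpr d1 = getAffineDimExpr(1, ctx);
// One expression list yields one map; the dim count is inferred from the
// largest dimension position used, here 2.
auto maps = AffineMap::inferFromExprList(ArrayRef<ArrayRef<AffineExpr>>{{d1, d0}}, ctx);
// maps[0] is (d0, d1) -> (d1, d0), i.e. a 2-D transpose.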
Attributes are known-constant values of operations.
Definition: Attributes.h:25
This class represents an argument of a Block.
Definition: Value.h:309
Block represents an ordered list of Operations.
Definition: Block.h:33
BlockArgument getArgument(unsigned i)
Definition: Block.h:129
IntegerAttr getIndexAttr(int64_t value)
Definition: Builders.cpp:107
UnitAttr getUnitAttr()
Definition: Builders.cpp:97
AffineExpr getAffineConstantExpr(int64_t constant)
Definition: Builders.cpp:371
Ty getType(Args &&...args)
Get or construct an instance of the type Ty with provided arguments.
Definition: Builders.h:91
TypedAttr getZeroAttr(Type type)
Definition: Builders.cpp:323
AffineExpr getAffineDimExpr(unsigned position)
Definition: Builders.cpp:363
MLIRContext * getContext() const
Definition: Builders.h:56
ArrayAttr getI64ArrayAttr(ArrayRef< int64_t > values)
Definition: Builders.cpp:280
IndexType getIndexType()
Definition: Builders.cpp:50
ArrayAttr getAffineMapArrayAttr(ArrayRef< AffineMap > values)
Definition: Builders.cpp:317
static DenseElementsAttr get(ShapedType type, ArrayRef< Attribute > values)
Constructs a dense elements attribute from an array of element values.
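As a hedged illustration (the helper name is hypothetical, not part of this file): passing a single element value for a multi-element shaped type yields a splat attribute, the form a 2-D splat arith.constant carries.
// Hypothetical helper: build a 16x16 f16 splat-zero attribute.
static DenseElementsAttr buildSplatZeroAttr(OpBuilder &builder) {
  auto f16 = builder.getF16Type();
  auto vecTy = VectorType::get({16, 16}, f16);
  Attribute zero = builder.getZeroAttr(f16);
  // A one-element value list is splatted across the whole vector shape.
  return DenseElementsAttr::get(vecTy, ArrayRef<Attribute>{zero});
}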
This class coordinates rewriting a piece of IR outside of a pattern rewrite, providing a way to keep ...
Definition: PatternMatch.h:774
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition: Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition: MLIRContext.h:63
RAII guard to reset the insertion point of the builder when destroyed.
Definition: Builders.h:348
This class helps build Operations.
Definition: Builders.h:207
void setInsertionPoint(Block *block, Block::iterator insertPoint)
Set the insertion point to the specified location.
Definition: Builders.h:398
Operation is the basic unit of execution within MLIR.
Definition: Operation.h:88
Value getOperand(unsigned idx)
Definition: Operation.h:350
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition: Operation.h:749
bool hasOneUse()
Returns true if this operation has exactly one use.
Definition: Operation.h:849
OpResult getResult(unsigned idx)
Get the 'idx'th result of this operation.
Definition: Operation.h:407
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition: Operation.h:797
Location getLoc()
The source location the operation was defined or derived from.
Definition: Operation.h:223
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Definition: Operation.cpp:267
operand_type_range getOperandTypes()
Definition: Operation.h:397
result_type_range getResultTypes()
Definition: Operation.h:428
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition: Operation.h:378
user_range getUsers()
Returns a range of all users.
Definition: Operation.h:873
user_iterator user_begin()
Definition: Operation.h:869
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Definition: Operation.cpp:672
unsigned getNumResults()
Return the number of results held by this operation.
Definition: Operation.h:404
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
Definition: PatternMatch.h:793
This class coordinates the application of a rewrite on a set of IR, providing a way for clients to tr...
Definition: PatternMatch.h:368
std::enable_if_t<!std::is_convertible< CallbackT, Twine >::value, LogicalResult > notifyMatchFailure(Location loc, CallbackT &&reasonCallback)
Used to notify the listener that the IR failed to be rewritten because of a match failure,...
Definition: PatternMatch.h:726
virtual void eraseBlock(Block *block)
This method erases all operations in a block.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void replaceAllUsesWith(Value from, Value to)
Find uses of from and replace them with to.
Definition: PatternMatch.h:646
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Definition: PatternMatch.h:529
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition: Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition: ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition: Value.h:96
Type getType() const
Return the type of this value.
Definition: Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition: Value.cpp:18
MMAMatrix represents a matrix held by a subgroup for matrix-matrix multiply accumulate operations.
Definition: GPUDialect.h:131
static MMAMatrixType get(ArrayRef< int64_t > shape, Type elementType, StringRef operand)
Get MMAMatrixType and verify construction Invariants.
Definition: GPUDialect.cpp:187
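A brief, hedged usage sketch (`ctx` is an assumed MLIRContext*): the operand string names which matmul operand the fragment feeds, one of "AOp", "BOp" or "COp".
Type f16 = Float16Type::get(ctx);
Type f32 = Float32Type::get(ctx);
// 16x16 f16 fragment feeding the A operand, and an f32 accumulator fragment.
gpu::MMAMatrixType aFrag = gpu::MMAMatrixType::get({16, 16}, f16, "AOp");
gpu::MMAMatrixType cFrag = gpu::MMAMatrixType::get({16, 16}, f32, "COp");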
AffineApplyOp makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, ArrayRef< OpFoldResult > operands, bool composeAffineMin=false)
Returns a composed AffineApplyOp by composing map and operands with other AffineApplyOps supplying th...
Definition: AffineOps.cpp:1276
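A minimal sketch of the usual call pattern, assuming `b` is an OpBuilder, `loc` a Location and `idx` an index-typed Value; the helper folds the new apply with any affine.apply producing its operands.
AffineExpr d0 = b.getAffineDimExpr(0);
AffineMap addEight = AffineMap::get(/*dimCount=*/1, /*symbolCount=*/0, d0 + 8);
// Produces `idx + 8`, composed with the map of any affine.apply defining `idx`.
auto apply = affine::makeComposedAffineApply(b, loc, addEight, {OpFoldResult(idx)});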
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Definition: Matchers.h:344
int64_t inferTileWidthInBits(const WarpMatrixInfo &type)
Returns the number of bits in a single tile row.
Definition: MMAUtils.cpp:85
FailureOr< vector::ContractionOp > getUserContract(Operation *op)
Returns the first user of the op that is vector.contract.
Definition: MMAUtils.cpp:48
FailureOr< AffineMap > getLaneIdAndValueIdToOperandCoord(OpBuilder &builder, Location loc, const WarpMatrixInfo &fragmentType)
Returns an AffineMap which maps a two dimensions representing (laneId, logicalValueId) and returns tw...
Definition: MMAUtils.cpp:167
FailureOr< WarpMatrixInfo > getWarpMatrixInfo(Operation *op)
If op is a vector.transfer_write, return the WarpMatrixInfo for the vector operand.
Definition: MMAUtils.cpp:56
FailureOr< AffineMap > getLaneIdToLdMatrixMatrixCoord(OpBuilder &builder, Location loc, const LdMatrixParams &params)
Returns an AffineMap which maps a single dimension representing the laneId to two results representin...
Definition: MMAUtils.cpp:232
FailureOr< LdMatrixParams > getLdMatrixParams(const WarpMatrixInfo &type, bool transpose)
Given type that contains info for a warp-matrix operand and whether or not the load is a transposed l...
Definition: MMAUtils.cpp:203
FailureOr< FragmentElementInfo > getMmaSyncRegisterType(const WarpMatrixInfo &type)
Returns a FragmentElementInfo struct describing the register types for the given matrix fragment type...
Definition: MMAUtils.cpp:98
bool canLowerToWarpMatrixOperation(vector::TransferReadOp op)
Returns whether the vector.transfer_read instruction can be interpreted as a warp-level cooperative m...
Definition: MMAUtils.cpp:270
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition: Remarks.h:491
bool isReductionIterator(Attribute attr)
Returns true if attr has "reduction" iterator type semantics.
Definition: VectorOps.h:154
bool isParallelIterator(Attribute attr)
Returns true if attr has "parallel" iterator type semantics.
Definition: VectorOps.h:149
void populateVectorContractCanonicalizeMatmulToMMT(RewritePatternSet &patterns, std::function< LogicalResult(vector::ContractionOp)> constraint=[](vector::ContractionOp) { return success();}, PatternBenefit=1)
Canonicalization of a vector.contraction a, b, c with row-major matmul semantics to a contraction wit...
Include the generated interface declarations.
void populatePrepareVectorToMMAPatterns(RewritePatternSet &patterns, bool useNvGpu=false)
Patterns to transform vector ops into a canonical form to convert to MMA matrix operations.
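For orientation, a hedged sketch of the usual two-step flow (the wrapper name is hypothetical): greedily apply the preparation patterns, then convert the canonicalized vector ops; `rootOp` is assumed to be isolated from above, e.g. a func.func.
static LogicalResult lowerVectorToWmma(RewriterBase &rewriter, Operation *rootOp) {
  RewritePatternSet patterns(rootOp->getContext());
  populatePrepareVectorToMMAPatterns(patterns, /*useNvGpu=*/false);
  if (failed(applyPatternsGreedily(rootOp->getRegion(0), std::move(patterns))))
    return failure();
  // With the vector ops in canonical form, map them onto GPU subgroup MMA ops.
  return convertVectorToMMAOps(rewriter, rootOp);
}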
LogicalResult getBackwardSlice(Operation *op, SetVector< Operation * > *backwardSlice, const BackwardSliceOptions &options={})
Fills backwardSlice with the computed backward slice (i.e.
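A short, hedged sketch of the call shape (the helper name is hypothetical); the options struct controls, among other things, whether the root op itself is kept in the slice.
static LogicalResult collectBackwardSlice(Operation *op,
                                          SetVector<Operation *> &slice) {
  BackwardSliceOptions options;
  options.inclusive = true; // also keep `op` itself in the slice
  return getBackwardSlice(op, &slice, options);
}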
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition: Utils.cpp:304
void bindDims(MLIRContext *ctx, AffineExprTy &...exprs)
Bind a list of AffineExpr references to DimExpr at positions: [0 .. sizeof...(exprs)].
Definition: AffineExpr.h:311
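A minimal sketch, with `ctx` an assumed MLIRContext*: bindDims assigns consecutive dimension positions to the given expressions, which keeps hand-written indexing maps readable.
AffineExpr m, n, k;
bindDims(ctx, m, n, k); // m -> d0, n -> d1, k -> d2
// Accumulator map of a row-major matmul: (d0, d1, d2) -> (d0, d1).
AffineMap accMap = AffineMap::get(/*dimCount=*/3, /*symbolCount=*/0, {m, n}, ctx);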
LogicalResult applyPatternsGreedily(Region &region, const FrozenRewritePatternSet &patterns, GreedyRewriteConfig config=GreedyRewriteConfig(), bool *changed=nullptr)
Rewrite ops in the given region, which must be isolated from above, by repeatedly applying the highes...
const FrozenRewritePatternSet & patterns
LogicalResult convertVectorToNVVMCompatibleMMASync(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops nested under rootOp to vector and GPU operations compatible with the nvvm....
AffineExpr getAffineConstantExpr(int64_t constant, MLIRContext *context)
Definition: AffineExpr.cpp:643
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
std::unique_ptr< Pass > createConvertVectorToGPUPass(bool useNvGpu=false)
Convert from vector to GPU ops.
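A hedged sketch of wiring the pass into a pipeline (the wrapper name is hypothetical); useNvGpu=true selects the nvgpu/mma.sync path rather than the WMMA one.
static LogicalResult runVectorToGpu(ModuleOp module) {
  PassManager pm(module.getContext());
  pm.addNestedPass<func::FuncOp>(createConvertVectorToGPUPass(/*useNvGpu=*/true));
  return pm.run(module);
}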
AffineExpr getAffineDimExpr(unsigned position, MLIRContext *context)
These free functions allow clients of the API to not use classes in detail.
Definition: AffineExpr.cpp:619
LogicalResult convertVectorToMMAOps(RewriterBase &rewriter, Operation *rootOp)
Convert vector ops to MMA matrix operations nested under rootOp.
SetVector< Operation * > topologicalSort(const SetVector< Operation * > &toSort)
Sorts all operations in toSort topologically while also considering region semantics.
void getForwardSlice(Operation *op, SetVector< Operation * > *forwardSlice, const ForwardSliceOptions &options={})
Fills forwardSlice with the computed forward slice (i.e.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
Definition: PatternMatch.h:314
This trait tags element-wise ops on vectors or tensors.
TransitiveFilter filter
Definition: SliceAnalysis.h:29
Specifies information about the registers which compose a matrix fragment according to the PTX docume...
Definition: MMAUtils.h:52