//===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
#include "mlir/Conversion/NVGPUToNVVM/NVGPUToNVVM.h"
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
#include "mlir/Dialect/NVGPU/Transforms/Transforms.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/SCF/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/ArrayRef.h"

using namespace mlir;
using namespace mlir::linalg;
using namespace mlir::nvgpu;
using namespace mlir::NVVM;
using namespace mlir::transform;

#define DEBUG_TYPE "nvgpu-transforms"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define DBGSNL() (llvm::dbgs() << "\n")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

//===----------------------------------------------------------------------===//
// Apply...ConversionPatternsOp
//===----------------------------------------------------------------------===//

void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
    TypeConverter &typeConverter, RewritePatternSet &patterns) {
  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
  /// device-side async tokens cannot be materialized in nvvm. We just
  /// convert them to a dummy i32 type in order to easily drop them during
  /// conversion.
  llvmTypeConverter.addConversion(
      [&](nvgpu::DeviceAsyncTokenType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 32));
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
    return llvmTypeConverter.convertType(
        IntegerType::get(type.getContext(), 64));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
        Type elemType = type.getFragmented().getElementType();
        int64_t sizeM = type.getFragmented().getDimSize(0);
        int64_t sizeN = type.getFragmented().getDimSize(1);

        unsigned numMembers;
        if (elemType.isF32() || elemType.isInteger(32))
          numMembers = sizeN / 2;
        else if (elemType.isF16())
          numMembers = sizeN / 4;
        else
          llvm_unreachable("unsupported type for warpgroup accumulator");

        SmallVector<Type> innerStructBody;
        for (unsigned i = 0; i < numMembers; i++)
          innerStructBody.push_back(elemType);
        auto innerStructType = LLVM::LLVMStructType::getLiteral(
            type.getContext(), innerStructBody);

        SmallVector<Type> structBody;
        for (int i = 0; i < sizeM; i += kWgmmaSizeM)
          structBody.push_back(innerStructType);

        auto convertedType =
            LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
        return llvmTypeConverter.convertType(convertedType);
      });
  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
    return llvmTypeConverter.convertType(
        getMBarrierMemrefType(type.getContext(), type));
  });
  llvmTypeConverter.addConversion(
      [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
        return llvmTypeConverter.convertType(
            IntegerType::get(type.getContext(), 64));
      });
  llvmTypeConverter.addConversion(
      [&](nvgpu::TensorMapDescriptorType type) -> Type {
        return LLVM::LLVMPointerType::get(type.getContext());
      });
  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
}
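
// For a concrete sense of the accumulator conversion above: assuming a
// (hypothetical) fragment type of vector<128x128xf32>, numMembers is
// 128 / 2 == 64, so the inner struct holds 64 f32 members, and with
// kWgmmaSizeM == 64 the outer struct holds 128 / 64 == 2 such inner structs.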

LogicalResult
transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
    transform::TypeConverterBuilderOpInterface builder) {
  if (builder.getTypeConverterType() != "LLVMTypeConverter")
    return emitOpError("expected LLVMTypeConverter");
  return success();
}

//===---------------------------------------------------------------------===//
// CreateAsyncGroupsOp
//===---------------------------------------------------------------------===//

void transform::CreateAsyncGroupsOp::getEffects(
    SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
  transform::consumesHandle(getTarget(), effects);
  transform::producesHandle(getResult(), effects);
  transform::modifiesPayload(effects);
}

DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
    TransformRewriter &rewriter, Operation *target,
    ApplyToEachResultList &results, TransformState &state) {
  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
  results.push_back(target);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// PipelineSharedMemoryCopiesOp
//===----------------------------------------------------------------------===//

/// Returns true if the given type has the default memory space.
static bool hasDefaultMemorySpace(BaseMemRefType type) {
  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
}

/// Returns true if the given type has the shared (workgroup) memory space.
static bool hasSharedMemorySpace(BaseMemRefType type) {
  auto space =
      dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
  return space &&
         space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
}

/// Returns the value produced by a load from the default memory space. Returns
/// null if the operation is not such a load.
static Value getValueLoadedFromGlobal(Operation *op) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto load = dyn_cast<vector::TransferReadOp>(op);
  if (!load)
    return nullptr;

  auto loadType = dyn_cast<MemRefType>(load.getSource().getType());
  if (!loadType || !hasDefaultMemorySpace(loadType))
    return nullptr;
  return load;
}

/// Returns true if the operation is storing the given value into shared memory.
static bool isStoreToShared(Operation *op, Value v) {
  // TODO: consider an interface or leveraging the memory effects interface.
  auto store = dyn_cast<vector::TransferWriteOp>(op);
  if (!store || store.getVector() != v)
    return false;

  auto storeType = dyn_cast<MemRefType>(store.getSource().getType());
  return storeType && hasSharedMemorySpace(storeType);
}

/// Returns true if the operation is a load from the default memory space the
/// result of which is only stored into the shared memory space.
static bool isLoadFromGlobalStoredToShared(Operation *op) {
  Value loaded = getValueLoadedFromGlobal(op);
  if (!loaded || !loaded.hasOneUse())
    return false;

  return isStoreToShared(*loaded.getUsers().begin(), loaded);
}
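
// Illustrative pattern that satisfies the predicate above (a sketch, not the
// only possibility): a vector.transfer_read from a memref without a memory
// space attribute (or with space 0) whose single use is a
// vector.transfer_write into a memref in the gpu workgroup (shared) space.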

/// Populate `ops` with the set of operations that belong to the stage 0 of the
/// pipelined version of the given loop when pipelining copies to shared memory.
/// Specifically, this collects:
///
///   1. all loads from global memory, both sync and async;
///   2. the barriers for async loads.
///
/// In particular, barriers are omitted if they do not dominate at least one
/// async load for which there is not yet a barrier.
static LogicalResult
collectStage0PipeliningOps(scf::ForOp forOp,
                           llvm::SmallPtrSet<Operation *, 16> &ops) {

  llvm::SmallPtrSet<Operation *, 16> barriers;
  for (Operation &op : *forOp.getBody()) {
    // Bail on nested ops for now.
    if (op.getNumRegions() > 0)
      return failure();

    if (isa<gpu::BarrierOp>(op)) {
      barriers.insert(&op);
      continue;
    }

    if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
      ops.insert(&op);
      ops.insert(std::make_move_iterator(barriers.begin()),
                 std::make_move_iterator(barriers.end()));
      assert(barriers.empty() &&
             "expected to have moved the barriers into another set");
      continue;
    }

    if (isLoadFromGlobalStoredToShared(&op)) {
      ops.insert(&op);
      continue;
    }
  }

  return success();
}

/// Hook for the loop pipeliner that sets the "num groups in flight" attribute
/// of async wait operations corresponding to pipelined shared memory copies.
// TODO: this currently assumes that there are no groups that could be in flight
// in the existing code.
static void
setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration, unsigned depth) {
  // Based on the order of copies within the loop we need to set the number
  // of copies in flight, unless it is already set.
  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
  if (!waitOp || waitOp.getNumGroups())
    return;

  int numGroupInFlight = 0;
  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
      part == scf::PipeliningOption::PipelinerPart::Prologue) {
    numGroupInFlight = depth - 1;
  } else {
    // By construction there should be no wait op in the prologue as all the
    // wait should be in the last stage.
    assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
    // Based on the schedule we pick we know how many groups are in flight for
    // each iteration of the epilogue.
    numGroupInFlight = depth - 1 - iteration;
  }
  waitOp.setNumGroups(numGroupInFlight);
}
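
// Worked example of the logic above, assuming depth == 3: waits in the kernel
// (and, if any existed, the prologue) are set to 3 - 1 = 2 groups in flight;
// in the peeled epilogue, iteration 0 waits with 2 groups in flight,
// iteration 1 with 1, and iteration 2 with 0, i.e. it waits for everything.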

/// Hook for the loop pipeliner that populates `ops` with the stage information
/// as follows:
///
///   - operations in `stage0Ops` (typically loads from global memory and
///     related barriers) are at stage 0;
///   - operations in the backward slice of any stage0Ops are all at stage 0;
///   - other operations are at stage `depth`;
///   - the internal order of the pipelined loop has ops at stage `depth` first,
///     then those at stage 0, with relative order within each group preserved.
///
static void getPipelineStages(
    scf::ForOp forOp,
    std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
    unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
  SetVector<Operation *> dependencies;
  BackwardSliceOptions options([&](Operation *visited) {
    return visited->getBlock() == forOp.getBody();
  });
  options.inclusive = true;
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (stage0Ops.contains(&op))
      getBackwardSlice(&op, &dependencies, options);
  }

  for (Operation &op : forOp.getBody()->getOperations()) {
    if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
      opsWithPipelineStages.emplace_back(&op, depth);
  }
  for (Operation &op : forOp.getBody()->getOperations()) {
    if (dependencies.contains(&op))
      opsWithPipelineStages.emplace_back(&op, 0);
  }
}
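
// For instance, with depth == 2 the async copies (and the index computations
// in their backward slices) are assigned stage 0, every other op is assigned
// stage 2, and within the pipelined loop body the stage-2 ops are listed
// before the stage-0 ops, so that copies issued for a later iteration can
// overlap with the compute of the current one.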

/// Hook for the loop pipeliner. Replaces op with a predicated version and
/// returns the resulting operation. Returns the original op if the predication
/// isn't necessary for the given op. Returns null if predication is needed but
/// not supported.
static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
                                            Operation *op, Value predicate) {
  // Some operations may be fine to execute "speculatively" more times than the
  // original number of iterations, in particular side-effect free operations
  // and barriers, even if they cannot be predicated.
  if (isMemoryEffectFree(op) ||
      isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
          nvgpu::DeviceAsyncWaitOp>(op)) {
    return op;
  }

  // Otherwise, only async copies can currently be predicated.
  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
  if (!asyncCopyOp)
    return nullptr;

  // Create srcElement Value based on `predicate`. The next lines generate
  // the following code:
  //
  //   srcElement = (pred) ? prevSrcElements : 0;
  //
  Location loc = asyncCopyOp->getLoc();
  Value dstElements =
      rewriter.create<arith::ConstantOp>(loc, asyncCopyOp.getDstElementsAttr());
  Value originalSrcElement =
      asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
  Value c0Index = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  auto srcElements = rewriter.create<arith::SelectOp>(
      loc, predicate, originalSrcElement, c0Index);
  auto asyncCopyZeroFillOp = rewriter.create<nvgpu::DeviceAsyncCopyOp>(
      loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
      asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
      asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
      UnitAttr());
  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
  return asyncCopyZeroFillOp;
}

/// Applies loop pipelining with the given depth to the given loop so that
/// copies into the shared memory are pipelined. Doesn't affect other loops.
/// Returns a pair containing the error state and the pipelined op, the latter
/// being null in case of any failure. The error state contains a definite error
/// if the IR has been modified and a silenceable error otherwise.
static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
                        bool epiloguePeeling) {
  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
        scf::ForOp());
  }
  if (stage0Ops.empty()) {
    return std::make_tuple(
        emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
  }

  scf::PipeliningOption options;
  unsigned maxDepth = depth;
  auto setAnnotation = [&](Operation *op,
                           scf::PipeliningOption::PipelinerPart part,
                           unsigned iteration) {
    return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
  };
  options.getScheduleFn =
      [&](scf::ForOp schedulingFor,
          std::vector<std::pair<Operation *, unsigned>> &ops) {
        if (schedulingFor != forOp)
          return;
        return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
      };
  options.annotateFn = setAnnotation;
  if (!epiloguePeeling) {
    options.peelEpilogue = false;
    options.predicateFn = replaceOpWithPredicatedOp;
  }

  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(forOp);
  bool modifiedIR;
  FailureOr<scf::ForOp> maybePipelined =
      pipelineForLoop(rewriter, forOp, options, &modifiedIR);
  if (succeeded(maybePipelined)) {
    return std::make_tuple(DiagnosedSilenceableFailure::success(),
                           *maybePipelined);
  }
  return std::make_tuple(
      modifiedIR
          ? emitDefiniteFailure(forOp, "irreversible pipelining failure")
          : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
      scf::ForOp());
}

DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
    TransformRewriter &rewriter, scf::ForOp forOp,
    ApplyToEachResultList &results, TransformState &state) {
  auto [diag, pipelined] = pipelineForSharedCopies(
      rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
  if (diag.succeeded()) {
    results.push_back(pipelined);
    return DiagnosedSilenceableFailure::success();
  }
  if (diag.isDefiniteFailure()) {
    auto diag = emitDefiniteFailure("irreversible pipelining failure");
    if (!getPeelEpilogue()) {
      diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
      diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
    }
    return diag;
  }

  return std::move(diag);
}

//===----------------------------------------------------------------------===//
// RewriteMatmulAsMmaSyncOp
//===----------------------------------------------------------------------===//

/// Helper struct to encode a pair of row/column indexings in the form of
/// affine expressions.
struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
  RowColIndexing(AffineExpr row, AffineExpr col)
      : std::pair<AffineExpr, AffineExpr>(row, col) {}

  AffineExpr row() const { return first; };
  AffineExpr col() const { return second; };

  void print(llvm::raw_ostream &os) const {
    os << "- indexing: " << first << ", " << second;
  }
};

/// Helper struct to provide a simple mapping from matmul operations to the
/// corresponding mma.sync operation. This is constrained to the case where the
/// matmul matches the mma.sync operation 1-1.
struct MmaSyncBuilder {
  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
      : b(b), loc(loc), laneId(laneId) {}

  using IndexCalculator =
      std::function<SmallVector<RowColIndexing>(MLIRContext *)>;

  /// Create the mma.sync operation corresponding to `linalgOp` along with all
  /// the supporting load/store and vector operations.
  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);

private:
  struct MmaSyncInfo {
    std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
    std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
        vectorShapes;
    SmallVector<int64_t> mmaShape;
    bool tf32Enabled;
  };

  /// Return the specific index calculator for the given `linalgOp` or failure
  /// if the op is not supported. This is the toplevel switch that should just
  /// be Tablegen'd in the future.
  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
                                             TypeRange elementalTypes);

  //===--------------------------------------------------------------------===//
  // Instruction-specific row, column indexing expression builders.
  // These should all be declaratively specified via Tablegen in the future.
  // The Tablegen specification should be as straightforward as possible to
  // only model the existing size and type combinations.
  //===--------------------------------------------------------------------===//
  //
  // TODO: Tablegen all this.
  //===--------------------------------------------------------------------===//
  // m16n8k4 tf32 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  /// row = groupID for a0
  ///       groupID + 8 for a1
  /// col = threadIDInGroup
  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup},
            RowColIndexing{groupID + 8, threadIDInGroup}};
  }
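
  // Worked example of the mapping above: for %laneid == 5, groupID == 1 and
  // threadIDInGroup == 1, so this lane holds a0 at (row 1, col 1) and a1 at
  // (row 9, col 1) of the 16x4 tf32 LHS fragment.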

  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  /// row = threadIDInGroup
  /// col = groupID
  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{threadIDInGroup, groupID}};
  }

  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  /// row = groupID for c0 and c1
  ///       groupID + 8 for c2 and c3
  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID, threadIDInGroup * 2 + 1},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
            RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
  }

  //===--------------------------------------------------------------------===//
  // m16n8k16 f16 case.
  //===--------------------------------------------------------------------===//
  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  ///
  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
  ///       groupID + 8 otherwise
  ///
  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
  ///       (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},         // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},         // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},     // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1},     // i == 3
      RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8},     // i == 4
      RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8},     // i == 5
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8}  // i == 7
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  ///
  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
  ///       (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
  ///
  /// col = groupID
  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{threadIDInGroup * 2 + 0, groupID},     // i == 0
      RowColIndexing{threadIDInGroup * 2 + 1, groupID},     // i == 1
      RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
      RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID}  // i == 3
    };
    // clang-format on
  }

  /// From the NVIDIA doc:
  /// groupID = %laneid >> 2
  /// threadIDInGroup = %laneid % 4
  ///
  /// row = groupID for ci where i < 2
  ///       groupID + 8 for ci where i >= 2
  ///
  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
    auto dim = getAffineDimExpr(0, ctx);
    AffineExpr groupID = dim.floorDiv(4);
    AffineExpr threadIDInGroup = dim % 4;
    // clang-format off
    return {
      RowColIndexing{groupID, threadIDInGroup * 2 + 0},     // i == 0
      RowColIndexing{groupID, threadIDInGroup * 2 + 1},     // i == 1
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
      RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}  // i == 3
    };
    // clang-format on
  }

  //===--------------------------------------------------------------------===//
  /// Helper functions to create customizable load and stores operations. The
  /// specific shapes of each MMA instruction are passed via the
  /// IndexCalculator callback.
  //===--------------------------------------------------------------------===//
  /// Build a list of memref.load operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn);

  /// Perform a distributed load of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
                                      OpFoldResult laneId, Value memref,
                                      IndexCalculator indexFn,
                                      ArrayRef<int64_t> vectorShape);

  /// Build a list of memref.store operations indexed at `(row, col)` indices
  /// that make sense for a particular MMA instruction and specified via the
  /// IndexCalculator callback.
  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
                                             ValueRange toStore,
                                             OpFoldResult laneId, Value memref,
                                             IndexCalculator indexFn);

  /// Perform a distributed store of a vector operand of `vectorShape` for a
  /// particular MMA instruction whose `(row, col)` indices are specified via
  /// the IndexCalculator callback. Each `laneId` stores the subportion of the
  /// data that makes sense for the particular MMA operation.
  /// The `vectorShape` matches existing NVGPU dialect op specification but
  /// could also be flattened in the future if needed for simplification.
  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
      OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
      Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);

  OpBuilder &b;
  Location loc;
  OpFoldResult laneId;
};

//===--------------------------------------------------------------------===//
/// Helper functions to create customizable load and stores operations. The
/// specific shapes of each MMA instruction are passed via the
/// IndexCalculator callback.
//===--------------------------------------------------------------------===//

template <typename ApplyFn, typename ReduceFn>
static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
                                           ReduceFn reduceFn) {
  VectorType vectorType = vector.getType().cast<VectorType>();
  auto vectorShape = vectorType.getShape();
  auto strides = computeStrides(vectorShape);
  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
    auto indices = delinearize(idx, strides);
    reduceFn(applyFn(vector, idx, indices), idx, indices);
  }
}
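
// Worked example of the iteration above, assuming a vector<2x2xf16> operand:
// strides == {2, 1}, so idx runs over 0..3 and delinearizes to the positions
// (0,0), (0,1), (1,0), (1,1); applyFn produces the element for each position
// and reduceFn consumes it.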

SmallVector<Value> MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
                                                    OpFoldResult laneId,
                                                    Value memref,
                                                    IndexCalculator indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Value> res;
  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
  for (auto indexing : indexings) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    auto load = b.create<memref::LoadOp>(loc, memref, ValueRange{row, col});
    res.push_back(load);
  }
  return res;
}

Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
    OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
    IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  auto loads = buildMemRefLoads(b, loc, laneId, memref, indexFn);

  Type elementType = getElementTypeOrSelf(memref.getType());
  auto vt = VectorType::get(vectorShape, elementType);
  Value res = b.create<vector::SplatOp>(loc, vt, loads[0]);
  foreachIndividualVectorElement(
      res,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return loads[linearIdx];
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        res = b.create<vector::InsertOp>(loc, v, res, indices);
      });

  return res;
}

SmallVector<Operation *>
MmaSyncBuilder::buildMemRefStores(OpBuilder &b, Location loc,
                                  ValueRange toStore, OpFoldResult laneId,
                                  Value memref, IndexCalculator indexFn) {
  auto aff = [&](AffineExpr e) {
    return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
  };
  SmallVector<Operation *> res;
  for (auto [indexing, val] :
       llvm::zip_equal(indexFn(b.getContext()), toStore)) {
    Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
    Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
    Operation *store =
        b.create<memref::StoreOp>(loc, val, memref, ValueRange{row, col});
    res.push_back(store);
  }
  return res;
}

SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
    OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
    Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
  SmallVector<Value> toStore;
  toStore.reserve(32);
  foreachIndividualVectorElement(
      vectorToStore,
      /*applyFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        return b.create<vector::ExtractOp>(loc, vectorToStore, indices);
      },
      /*reduceFn=*/
      [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
        toStore.push_back(v);
      });
  return buildMemRefStores(b, loc, toStore, laneId, memref, indexFn);
}

static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
                  SmallVector<int64_t>>
makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
                 ArrayRef<int64_t> res) {
  SmallVector<int64_t> vlhs{lhs.begin(), lhs.end()};
  SmallVector<int64_t> vrhs{rhs.begin(), rhs.end()};
  SmallVector<int64_t> vres{res.begin(), res.end()};
  return std::make_tuple(vlhs, vrhs, vres);
}

FailureOr<MmaSyncBuilder::MmaSyncInfo>
MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
                                    TypeRange elementalTypes) {
  // TODO: Tablegen all this.
  Type f16 = b.getF16Type();
  Type f32 = b.getF32Type();
  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
      elementalTypes == TypeRange{f32, f32, f32}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
                                       &MmaSyncBuilder::m16n8k4tf32Rhs,
                                       &MmaSyncBuilder::m16n8k4tf32Res),
                       makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/true};
  }
  // This is the version with f16 accumulation.
  // TODO: version with f32 accumulation.
  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
      elementalTypes == TypeRange{f16, f16, f16}) {
    return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
                                       &MmaSyncBuilder::m16n8k16f16Rhs,
                                       &MmaSyncBuilder::m16n8k16f16Res),
                       makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
                       SmallVector<int64_t>{opShape.begin(), opShape.end()},
                       /*tf32Enabled=*/false};
  }
  return failure();
}
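
// To illustrate the dispatch above: a 16x8x4 matmul with f32 operands and
// accumulator maps to the m16n8k4 tf32 path, where each lane holds a 2x1 LHS
// fragment, a 1x1 RHS fragment and a 2x2 accumulator fragment; a 16x8x16
// matmul with f16 everywhere maps to the m16n8k16 path with 4x2, 2x2 and 2x2
// fragments respectively.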

FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
  assert(lhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected lhs to be a 2D memref");
  assert(rhsMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected rhs to be a 2D memref");
  assert(resMemRef.getType().cast<MemRefType>().getRank() == 2 &&
         "expected res to be a 2D memref");

  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
  Type resType = getElementTypeOrSelf(resMemRef.getType());

  FailureOr<MmaSyncInfo> maybeInfo =
      getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
  if (failed(maybeInfo))
    return failure();

  MmaSyncInfo info = *maybeInfo;
  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
                                            lhsIndexFn, lhsShape);
  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
                                            rhsIndexFn, rhsShape);
  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
                                            resIndexFn, resShape);
  res = b.create<nvgpu::MmaSyncOp>(loc, lhs, rhs, res, info.mmaShape,
                                   info.tf32Enabled);
  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
                                 resShape);
  return res.getDefiningOp();
}

DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
    transform::TransformRewriter &rewriter, LinalgOp linalgOp,
    transform::ApplyToEachResultList &results,
    transform::TransformState &state) {
  bool fail = true;
  // TODO: more robust detection of matmulOp, with transposes etc.
  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
    Location loc = linalgOp.getLoc();
    // TODO: more robust computation of laneId, for now assume a single warp.
    Value laneId = rewriter.create<gpu::ThreadIdOp>(
        loc, rewriter.getIndexType(), gpu::Dimension::x);
    if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
      fail = false;
  }

  if (fail) {
    DiagnosedSilenceableFailure diag = emitSilenceableError()
                                       << "unsupported target op: " << linalgOp;
    diag.attachNote(linalgOp->getLoc()) << "target op";
    return diag;
  }

  rewriter.eraseOp(linalgOp);
  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Hopper builders.
//===----------------------------------------------------------------------===//

/// Helper to create the base Hopper-specific operations that are reused in
/// various other places.
struct HopperBuilder {
  HopperBuilder(RewriterBase &rewriter, Location loc)
      : rewriter(rewriter), loc(loc) {}

  TypedValue<nvgpu::MBarrierGroupType>
  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);

  /// Create tma descriptor op to initiate transfer from global to shared
  /// memory. This must be done before the launch op, on the host.
  TypedValue<nvgpu::TensorMapDescriptorType>
  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                              gpu::LaunchOp launchOp);

  /// Build a tma load from global memory to shared memory using `barrier` to
  /// synchronize. Return the number of bytes that will be transferred.
  OpFoldResult
  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
                    TypedValue<MemRefType> sharedMemref,
                    TypedValue<nvgpu::MBarrierGroupType> barrier,
                    SmallVectorImpl<Operation *> &loadOps);
  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
                            ArrayRef<OpFoldResult> sizes);

  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
  /// Return the operation that performs the transfer on thread0.
  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
  SmallVector<Operation *> buildPredicateLoadsOnThread0(
      ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
      ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
      TypedValue<nvgpu::MBarrierGroupType> barrier);

  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);

  RewriterBase &rewriter;
  Location loc;
};

SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
    ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
    ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  SmallVector<Operation *> loadOps;
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Value tidx = rewriter.create<gpu::ThreadIdOp>(loc, gpu::Dimension::x);
  Value cond =
      rewriter.create<arith::CmpIOp>(loc, arith::CmpIPredicate::eq, tidx, zero);
  // clang-format off
  rewriter.create<scf::IfOp>(
    /*location=*/loc,
    /*conditional=*/cond,
    /*thenBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      SmallVector<OpFoldResult> sizes;
      sizes.reserve(globalDescriptors.size());
      for (auto [desc, shmem] : llvm::zip_equal(
              globalDescriptors, sharedMemBuffers)) {
        OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
        sizes.push_back(sz);
      }
      // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
      // This may or may not have perf implications.
      buildBarrierArriveTx(barrier, sizes);
      rewriter.create<scf::YieldOp>(loc);
    },
    /*elseBuilder=*/
    [&](OpBuilder &lb, Location loc) {
      // TODO: is this for no-thread divergence?
      // Should we just yield the size and hoist?
      buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
      rewriter.create<scf::YieldOp>(loc);
    });
  // clang-format on
  return loadOps;
}

static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
  return gpu::AddressSpaceAttr::get(
      b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
}

TypedValue<nvgpu::MBarrierGroupType>
HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value barrier = rewriter.create<nvgpu::MBarrierCreateOp>(
      loc,
      nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierInitOp>(
      loc, barrier, getValueOrCreateConstantIndexOp(rewriter, loc, numThreads),
      zero, Value());
  rewriter.create<gpu::BarrierOp>(loc);
  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
}
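
// Note on the builder above: the mbarrier is initialized to expect
// `numThreads` arrivals (in the TMA copy rewrite below, one per thread in the
// block), and the trailing gpu.barrier makes the initialized mbarrier visible
// to all threads before any of them starts using it.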

TypedValue<nvgpu::TensorMapDescriptorType>
HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
                                           gpu::LaunchOp launchOp) {
  OpBuilder::InsertionGuard guard(rewriter);
  rewriter.setInsertionPoint(launchOp);
  Value unrankedMemRef = rewriter.create<memref::CastOp>(
      loc,
      UnrankedMemRefType::get(memref.getType().getElementType(),
                              memref.getType().getMemorySpace()),
      memref);
  SmallVector<OpFoldResult> mixedSizes =
      memref::getMixedSizes(rewriter, loc, memref);
  SmallVector<Value> sizes =
      getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);

  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
  Value desc = rewriter.create<nvgpu::TmaCreateDescriptorOp>(
      loc,
      nvgpu::TensorMapDescriptorType::get(
          rewriter.getContext(),
          MemRefType::Builder(memref.getType())
              .setMemorySpace(sharedMemorySpace),
          TensorMapSwizzleKind::SWIZZLE_NONE,
          TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
          TensorMapInterleaveKind::INTERLEAVE_NONE),
      unrankedMemRef, sizes);
  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
}

OpFoldResult HopperBuilder::buildTmaAsyncLoad(
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
    TypedValue<MemRefType> sharedMemref,
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    SmallVectorImpl<Operation *> &loadOps) {
  MLIRContext *ctx = rewriter.getContext();
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  Operation *loadOp = rewriter.create<nvgpu::TmaAsyncLoadOp>(
      loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero}, zero,
      Value());
  loadOps.push_back(loadOp);
  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr prodExprInBytes =
      computeProduct(ctx, symbols) *
      (sharedMemref.getType().getElementTypeBitWidth() / 8);
  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
                                                   prodExprInBytes, mixedSizes);
  return res;
}
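
// Worked example of the byte count computed above: for a (hypothetical) shared
// buffer of type memref<64x128xf16, #gpu.address_space<workgroup>>, the folded
// affine product is 64 * 128 * (16 / 8) = 16384 bytes expected by the barrier.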

void HopperBuilder::buildBarrierArriveTx(
    TypedValue<nvgpu::MBarrierGroupType> barrier,
    ArrayRef<OpFoldResult> mixedSizes) {
  assert(!mixedSizes.empty() && "expected non-empty sizes");
  MLIRContext *ctx = rewriter.getContext();
  SmallVector<AffineExpr> symbols(mixedSizes.size());
  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
  AffineExpr sumExpr = computeSum(ctx, symbols);
  OpFoldResult size =
      affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierArriveExpectTxOp>(loc, barrier, sizeVal, zero,
                                                   Value());
}

void HopperBuilder::buildTryWaitParity(
    TypedValue<nvgpu::MBarrierGroupType> barrier) {
  Value parity = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  // 10M is an arbitrary, not too small or too big number to specify the number
  // of ticks before retry.
  // TODO: hoist this in a default dialect constant.
  Value ticksBeforeRetry =
      rewriter.create<arith::ConstantIndexOp>(loc, 10000000);
  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
  rewriter.create<nvgpu::MBarrierTryWaitParityOp>(loc, barrier, parity,
                                                  ticksBeforeRetry, zero);
}

//===----------------------------------------------------------------------===//
// RewriteCopyAsTmaOp
//===----------------------------------------------------------------------===//

/// Helper to create the tma operations corresponding to `linalg::CopyOp`.
struct CopyBuilder : public HopperBuilder {
  CopyBuilder(RewriterBase &rewriter, Location loc)
      : HopperBuilder(rewriter, loc) {}

  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
};

SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
  MLIRContext *ctx = rewriter.getContext();
  if (copyOps.empty())
    return SmallVector<Operation *>();

  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
  assert(launchOp && "expected launch op");

  // 1. Init a barrier object in shared memory.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  AffineExpr bx, by, bz;
  bindSymbols(ctx, bx, by, bz);
  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
      rewriter, loc, prod,
      ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
                             launchOp.getBlockSizeZ()});

  TypedValue<nvgpu::MBarrierGroupType> barrier =
      buildAndInitBarrierInSharedMemory(numThreads);

  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
  SmallVector<TypedValue<MemRefType>> shmems;
  for (Operation *op : copyOps) {
    auto copyOp = cast<linalg::CopyOp>(op);
    auto inMemRef =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
    assert(inMemRef.getType().getRank() == 2 &&
           "expected in to be a 2D memref");

    // 2. Build global memory descriptor.
    TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
        buildGlobalMemRefDescriptor(inMemRef, launchOp);
    globalDescs.push_back(globalDesc);

    // 3. Shared memory and descriptor for the tmp array.
    auto shmem =
        cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
    shmems.push_back(shmem);
  }

  // 4. Load in from global memory to shared memory using tma.
  OpBuilder::InsertionGuard g2(rewriter);
  rewriter.setInsertionPoint(copyOps.front());
  SmallVector<Operation *> results =
      buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);

  // 5. Spin-loop until data is ready.
  buildTryWaitParity(barrier);

  // 6. Erase the ops that have now been rewritten.
  for (Operation *op : copyOps)
    rewriter.eraseOp(op);

  return results;
}

DiagnosedSilenceableFailure
transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
                                     transform::TransformResults &results,
                                     transform::TransformState &state) {
  auto payloadOps = state.getPayloadOps(getTarget());
  gpu::LaunchOp commonLaunchOp;
  Operation *firstOp, *failingOp;
  if (llvm::any_of(payloadOps, [&](Operation *op) {
        if (!commonLaunchOp) {
          commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
          firstOp = op;
        }
        auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
                    commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
                    !isa<linalg::CopyOp>(op);
        if (fail)
          failingOp = op;
        return fail;
      })) {
    DiagnosedSilenceableFailure diag =
        emitSilenceableError()
        << "target ops must be linalg::CopyOp nested under a common "
           "gpu.LaunchOp to be rewritten because the tma descriptors need to "
           "be created on the host.\nBut got: "
        << *firstOp << "\nand " << *failingOp;
    return diag;
  }

  // TODO: more robust detection of copy, with transposes etc.
  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));

  return DiagnosedSilenceableFailure::success();
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//

namespace {
class NVGPUTransformDialectExtension
    : public transform::TransformDialectExtension<
          NVGPUTransformDialectExtension> {
public:
  NVGPUTransformDialectExtension() {
    declareGeneratedDialect<arith::ArithDialect>();
    declareGeneratedDialect<affine::AffineDialect>();
    declareGeneratedDialect<nvgpu::NVGPUDialect>();
    declareGeneratedDialect<NVVM::NVVMDialect>();
    declareGeneratedDialect<vector::VectorDialect>();
    registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
        >();
  }
};
} // namespace

#define GET_OP_CLASSES
#include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"

void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
  registry.addExtensions<NVGPUTransformDialectExtension>();
}