NVGPUTransformOps.cpp
1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 
42 //===----------------------------------------------------------------------===//
43 // Apply...ConversionPatternsOp
44 //===----------------------------------------------------------------------===//
45 
46 void ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
47  TypeConverter &typeConverter, RewritePatternSet &patterns) {
48  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
49  /// device-side async tokens cannot be materialized in nvvm. We just
50  /// convert them to a dummy i32 type in order to easily drop them during
51  /// conversion.
52  populateGpuMemorySpaceAttributeConversions(
53  llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
54  switch (space) {
55  case gpu::AddressSpace::Global:
56  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
57  case gpu::AddressSpace::Workgroup:
58  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
59  case gpu::AddressSpace::Private:
60  return 0;
61  }
62  llvm_unreachable("unknown address space enum value");
63  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
64  });
65  llvmTypeConverter.addConversion([&](DeviceAsyncTokenType type) -> Type {
66  return llvmTypeConverter.convertType(
67  IntegerType::get(type.getContext(), 32));
68  });
69  llvmTypeConverter.addConversion([&](MBarrierTokenType type) -> Type {
70  return llvmTypeConverter.convertType(
71  IntegerType::get(type.getContext(), 64));
72  });
73  llvmTypeConverter.addConversion([&](WarpgroupAccumulatorType type) -> Type {
74  Type elemType = type.getFragmented().getElementType();
75  int64_t sizeM = type.getFragmented().getDimSize(0);
76  int64_t sizeN = type.getFragmented().getDimSize(1);
77 
78  unsigned numMembers;
79  if (elemType.isF32() || elemType.isInteger(32))
80  numMembers = sizeN / 2;
81  else if (elemType.isF16())
82  numMembers = sizeN / 4;
83  else
84  llvm_unreachable("unsupported type for warpgroup accumulator");
85 
86  SmallVector<Type> innerStructBody;
87  for (unsigned i = 0; i < numMembers; i++)
88  innerStructBody.push_back(elemType);
89  auto innerStructType =
90  LLVM::LLVMStructType::getLiteral(type.getContext(), innerStructBody);
91 
92  SmallVector<Type> structBody;
93  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
94  structBody.push_back(innerStructType);
95 
96  auto convertedType =
97  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
98  return llvmTypeConverter.convertType(convertedType);
99  });
100  llvmTypeConverter.addConversion([&](MBarrierGroupType type) -> Type {
101  return llvmTypeConverter.convertType(
102  getMBarrierMemrefType(type.getContext(), type));
103  });
104  llvmTypeConverter.addConversion(
105  [&](WarpgroupMatrixDescriptorType type) -> Type {
106  return llvmTypeConverter.convertType(
107  IntegerType::get(type.getContext(), 64));
108  });
109  llvmTypeConverter.addConversion([&](TensorMapDescriptorType type) -> Type {
110  return LLVM::LLVMPointerType::get(type.getContext());
111  });
112  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
113 }
114 
115 LogicalResult ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
116  TypeConverterBuilderOpInterface builder) {
117  if (builder.getTypeConverterType() != "LLVMTypeConverter")
118  return emitOpError("expected LLVMTypeConverter");
119  return success();
120 }
121 
122 //===---------------------------------------------------------------------===//
123 // CreateAsyncGroupsOp
124 //===---------------------------------------------------------------------===//
125 
126 void CreateAsyncGroupsOp::getEffects(
127  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
128  consumesHandle(getTargetMutable(), effects);
129  producesHandle(getOperation()->getOpResults(), effects);
130  modifiesPayload(effects);
131 }
132 
133 DiagnosedSilenceableFailure
134 CreateAsyncGroupsOp::applyToOne(TransformRewriter &rewriter, Operation *target,
135  ApplyToEachResultList &results,
136  TransformState &state) {
137  createAsyncGroups(rewriter, target, getBypassL1());
138  results.push_back(target);
139  return DiagnosedSilenceableFailure::success();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 // PipelineSharedMemoryCopiesOp
144 //===----------------------------------------------------------------------===//
145 
146 /// Returns true if the given type has the default memory space.
147 static bool hasDefaultMemorySpace(BaseMemRefType type) {
148  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
149 }
150 
151 /// Returns true if the given type has the shared (workgroup) memory space.
152 static bool hasSharedMemorySpace(BaseMemRefType type) {
153  auto space =
154  dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
155  return space &&
156  space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
157 }
158 
159 /// Returns the value produced by a load from the default memory space. Returns
160 /// null if the operation is not such a load.
161 static Value getValueLoadedFromGlobal(Operation *op) {
162  // TODO: consider an interface or leveraging the memory effects interface.
163  auto load = dyn_cast<vector::TransferReadOp>(op);
164  if (!load)
165  return nullptr;
166 
167  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
168  if (!loadType || !hasDefaultMemorySpace(loadType))
169  return nullptr;
170  return load;
171 }
172 
173 /// Returns true if the operation is storing the given value into shared memory.
174 static bool isStoreToShared(Operation *op, Value v) {
175  // TODO: consider an interface or leveraging the memory effects interface.
176  auto store = dyn_cast<vector::TransferWriteOp>(op);
177  if (!store || store.getVector() != v)
178  return false;
179 
180  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
181  return storeType && hasSharedMemorySpace(storeType);
182 }
183 
184 /// Returns true if the operation is a load from the default memory space the
185 /// result of which is only stored into the shared memory space.
186 static bool isLoadFromGlobalStoredToShared(Operation *op) {
187  Value loaded = getValueLoadedFromGlobal(op);
188  if (!loaded || !loaded.hasOneUse())
189  return false;
190 
191  return isStoreToShared(*loaded.getUsers().begin(), loaded);
192 }
193 
194 /// Populate `ops` with the set of operations that belong to the stage 0 of the
195 /// pipelined version of the given loop when pipelining copies to shared memory.
196 /// Specifically, this collects:
197 ///
198 /// 1. all loads from global memory, both sync and async;
199 /// 2. the barriers for async loads.
200 ///
201 /// In particular, barriers are omitted if they do not dominate at least one
202 /// async load for which there is not yet a barrier.
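/// As an illustrative example (operands elided), for a loop body containing
///   %v = vector.transfer_read %global ...
///   vector.transfer_write %v, %shared ...
///   gpu.barrier
///   nvgpu.device_async_copy ...
/// the collected stage 0 ops are the transfer_read, the barrier and the async
/// copy; the transfer_write itself is not added to the set.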
203 static LogicalResult
204 collectStage0PipeliningOps(scf::ForOp forOp,
205  llvm::SmallPtrSet<Operation *, 16> &ops) {
206 
207  llvm::SmallPtrSet<Operation *, 16> barriers;
208  for (Operation &op : *forOp.getBody()) {
209  // Bail on nested ops for now.
210  if (op.getNumRegions() > 0)
211  return failure();
212 
213  if (isa<gpu::BarrierOp>(op)) {
214  barriers.insert(&op);
215  continue;
216  }
217 
218  if (isa<DeviceAsyncCopyOp, DeviceAsyncCreateGroupOp>(op)) {
219  ops.insert(&op);
220  ops.insert(std::make_move_iterator(barriers.begin()),
221  std::make_move_iterator(barriers.end()));
222  assert(barriers.empty() &&
223  "expected to have moved the barriers into another set");
224  continue;
225  }
226 
227  if (isLoadFromGlobalStoredToShared(&op)) {
228  ops.insert(&op);
229  continue;
230  }
231  }
232 
233  return success();
234 }
235 
236 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute
237 /// of async wait operations corresponding to pipelined shared memory copies.
238 // TODO: this currently assumes that there are no groups that could be in flight
239 // in the existing code.
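// For example, with depth == 3 the waits emitted in the pipelined kernel are
// set to allow depth - 1 == 2 copy groups to remain in flight, and the wait
// of epilogue iteration i allows 2 - i groups, draining the pipeline.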
240 static void
241 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
242  scf::PipeliningOption::PipelinerPart part,
243  unsigned iteration, unsigned depth) {
244  // Based on the order of copies within the loop we need to set the number
245  // of copies in flight, unless it is already set.
246  auto waitOp = dyn_cast<DeviceAsyncWaitOp>(op);
247  if (!waitOp || waitOp.getNumGroups())
248  return;
249 
250  int numGroupInFlight = 0;
251  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
252  part == scf::PipeliningOption::PipelinerPart::Prologue) {
253  numGroupInFlight = depth - 1;
254  } else {
255  // By construction there should be no wait op in the prologue as all the
256  // wait should be in the last stage.
257  assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
258  // Based on the schedule we pick we know how many groups are in flight for
259  // each iteration of the epilogue.
260  numGroupInFlight = depth - 1 - iteration;
261  }
262  waitOp.setNumGroups(numGroupInFlight);
263 }
264 
265 /// Hook for the loop pipeliner that populates `ops` with the stage information
266 /// as follows:
267 ///
268 /// - operations in `stage0Ops` (typically loads from global memory and
269 /// related barriers) are at stage 0;
270 /// - operations in the backward slice of any stage0Ops are all at stage 0;
271 /// - other operations are at stage `depth`;
272 /// - the internal order of the pipelined loop has ops at stage `depth` first,
273 /// then those at stage 0, with relative order within each group preserved.
274 ///
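/// For example, with depth == 1 the compute ops of the body are emitted first
/// at stage 1 and the global-to-shared copies (plus their backward slices)
/// follow at stage 0, so that, roughly, the copy for a later iteration is
/// issued while the compute of the current iteration runs.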
275 static void getPipelineStages(
276  scf::ForOp forOp,
277  std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
278  unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
279  SetVector<Operation *> dependencies;
280  BackwardSliceOptions options([&](Operation *visited) {
281  return visited->getBlock() == forOp.getBody();
282  });
283  options.inclusive = true;
284  for (Operation &op : forOp.getBody()->getOperations()) {
285  if (stage0Ops.contains(&op)) {
286  LogicalResult result = getBackwardSlice(&op, &dependencies, options);
287  assert(result.succeeded() && "expected a backward slice");
288  (void)result;
289  }
290  }
291 
292  for (Operation &op : forOp.getBody()->getOperations()) {
293  if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
294  opsWithPipelineStages.emplace_back(&op, depth);
295  }
296  for (Operation &op : forOp.getBody()->getOperations()) {
297  if (dependencies.contains(&op))
298  opsWithPipelineStages.emplace_back(&op, 0);
299  }
300 }
301 
302 /// Hook for the loop pipeliner. Replaces op with a predicated version and
303 /// returns the resulting operation. Returns the original op if the predication
304 /// isn't necessary for the given op. Returns null if predication is needed but
305 /// not supported.
306 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
307  Operation *op, Value predicate) {
308  // Some operations may be fine to execute "speculatively" more times than the
309  // original number of iterations, in particular side-effect free operations
310  // and barriers, even if they cannot be predicated.
311  if (isMemoryEffectFree(op) ||
312  isa<gpu::BarrierOp, DeviceAsyncCreateGroupOp, DeviceAsyncWaitOp>(op)) {
313  return op;
314  }
315 
316  // Otherwise, only async copies can currently be predicated.
317  auto asyncCopyOp = dyn_cast<DeviceAsyncCopyOp>(op);
318  if (!asyncCopyOp)
319  return nullptr;
320 
321  // Create srcElement Value based on `predicate`. The next lines generate
322  // the following code:
323  //
324  // srcElement = (pred) ? prevSrcElements : 0;
325  //
326  Location loc = asyncCopyOp->getLoc();
327  Value dstElements = arith::ConstantOp::create(
328  rewriter, loc, asyncCopyOp.getDstElementsAttr());
329  Value originalSrcElement =
330  asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
331  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
332  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
333  originalSrcElement, c0Index);
334  auto asyncCopyZeroFillOp = DeviceAsyncCopyOp::create(
335  rewriter, loc, DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
336  asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
337  asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
338  UnitAttr());
339  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
340  return asyncCopyZeroFillOp;
341 }
342 
343 /// Applies loop pipelining with the given depth to the given loop so that
344 /// copies into the shared memory are pipelined. Doesn't affect other loops.
345 /// Returns a tuple containing the error state and the pipelined op, the latter
346 /// being null in case of any failure. The error state contains a definite error
347 /// if the IR has been modified and a silenceable error otherwise.
348 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
349 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
350  bool epiloguePeeling) {
351  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
352  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
353  return std::make_tuple(
354  emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
355  scf::ForOp());
356  }
357  if (stage0Ops.empty()) {
358  return std::make_tuple(
359  emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
360  }
361 
362  scf::PipeliningOption options;
363  unsigned maxDepth = depth;
364  auto setAnnotation = [&](Operation *op,
365  scf::PipeliningOption::PipelinerPart part,
366  unsigned iteration) {
367  return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
368  };
369  options.getScheduleFn =
370  [&](scf::ForOp schedulingFor,
371  std::vector<std::pair<Operation *, unsigned>> &ops) {
372  if (schedulingFor != forOp)
373  return;
374  return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
375  };
376  options.annotateFn = setAnnotation;
377  if (!epiloguePeeling) {
378  options.peelEpilogue = false;
379  options.predicateFn = replaceOpWithPredicatedOp;
380  }
381 
382  OpBuilder::InsertionGuard guard(rewriter);
383  rewriter.setInsertionPoint(forOp);
384  bool modifiedIR;
385  FailureOr<scf::ForOp> maybePipelined =
386  pipelineForLoop(rewriter, forOp, options, &modifiedIR);
387  if (succeeded(maybePipelined)) {
388  return std::make_tuple(DiagnosedSilenceableFailure::success(),
389  *maybePipelined);
390  }
391  return std::make_tuple(
392  modifiedIR
393  ? DiagnosedSilenceableFailure::definiteFailure()
394  : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
395  scf::ForOp());
396 }
397 
398 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
399  TransformRewriter &rewriter, scf::ForOp forOp,
400  ApplyToEachResultList &results, TransformState &state) {
401  auto [diag, pipelined] = pipelineForSharedCopies(
402  rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
403  if (diag.succeeded()) {
404  results.push_back(pipelined);
405  return DiagnosedSilenceableFailure::success();
406  }
407  if (diag.isDefiniteFailure()) {
408  auto diag = emitDefiniteFailure("irreversible pipelining failure");
409  if (!getPeelEpilogue()) {
410  diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
411  diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
412  }
413  return diag;
414  }
415 
416  return std::move(diag);
417 }
418 
419 //===----------------------------------------------------------------------===//
420 // RewriteMatmulAsMmaSyncOp
421 //===----------------------------------------------------------------------===//
422 
423 /// Helper struct to encode a pair of row/column indexings in the form of
424 /// affine expressions.
425 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
426  RowColIndexing(AffineExpr row, AffineExpr col)
427  : std::pair<AffineExpr, AffineExpr>(row, col) {}
428 
429  AffineExpr row() const { return first; };
430  AffineExpr col() const { return second; };
431 
432  void print(llvm::raw_ostream &os) const {
433  os << "- indexing: " << first << ", " << second;
434  }
435 };
436 
437 /// Helper struct to provide a simple mapping from matmul operations to the
438 /// corresponding mma.sync operation. This is constrained to the case where the
439 /// matmul matches the mma.sync operation 1-1.
440 struct MmaSyncBuilder {
441  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
442  : b(b), loc(loc), laneId(laneId) {}
443 
444  using IndexCalculator =
445  std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
446 
447  /// Create the mma.sync operation corresponding to `linalgOp` along with all
448  /// the supporting load/store and vector operations.
449  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
450 
451 private:
452  struct MmaSyncInfo {
453  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
454  std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
455  vectorShapes;
456  SmallVector<int64_t> mmaShape;
457  bool tf32Enabled;
458  };
459 
460  /// Return the specific index calculator for the given `linalgOp` or failure
461  /// if the op is not supported. This is the toplevel switch that should just
462  /// be Tablegen'd in the future.
463  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
464  TypeRange elementalTypes);
465 
466  //===--------------------------------------------------------------------===//
467  // Instruction-specific row, column indexing expression builders.
468  // These should all be declaratively specified via Tablegen in the future.
469  // The Tablegen specification should be as straightforward as possible to
470  // only model the existing size and type combinations.
471  //===--------------------------------------------------------------------===//
472  //
473  // TODO: Tablegen all this.
474  //===--------------------------------------------------------------------===//
475  // m16n8k4 tf32 case.
476  //===--------------------------------------------------------------------===//
477  /// From the NVIDIA doc:
478  /// groupID = %laneid >> 2
479  /// threadIDInGroup = %laneid % 4
480  /// row = groupID for a0
481  /// groupID + 8 for a1
482  /// col = threadIDInGroup
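  /// For example, lane 5 has groupID = 1 and threadIDInGroup = 1, so it holds
  /// a0 at (row, col) = (1, 1) and a1 at (9, 1).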
483  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
484  auto dim = getAffineDimExpr(0, ctx);
485  AffineExpr groupID = dim.floorDiv(4);
486  AffineExpr threadIDInGroup = dim % 4;
487  return {RowColIndexing{groupID, threadIDInGroup},
488  RowColIndexing{groupID + 8, threadIDInGroup}};
489  }
490 
491  /// From the NVIDIA doc:
492  /// groupID = %laneid >> 2
493  /// threadIDInGroup = %laneid % 4
494  /// row = threadIDInGroup
495  /// col = groupID
496  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
497  auto dim = getAffineDimExpr(0, ctx);
498  AffineExpr groupID = dim.floorDiv(4);
499  AffineExpr threadIDInGroup = dim % 4;
500  return {RowColIndexing{threadIDInGroup, groupID}};
501  }
502 
503  /// From the NVIDIA doc:
504  /// groupID = %laneid >> 2
505  /// threadIDInGroup = %laneid % 4
506  /// row = groupID for c0 and c1
507  /// groupID + 8 for c2 and c3
508  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
509  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
510  auto dim = getAffineDimExpr(0, ctx);
511  AffineExpr groupID = dim.floorDiv(4);
512  AffineExpr threadIDInGroup = dim % 4;
513  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
514  RowColIndexing{groupID, threadIDInGroup * 2 + 1},
515  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
516  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
517  }
518 
519  //===--------------------------------------------------------------------===//
520  // m16n8k16 f16 case.
521  //===--------------------------------------------------------------------===//
522  /// From the NVIDIA doc:
523  /// groupID = %laneid >> 2
524  /// threadIDInGroup = %laneid % 4
525  ///
526  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
527  /// groupID + 8 otherwise
528  ///
529  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
530  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
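  /// For example, lane 5 (groupID = 1, threadIDInGroup = 1) holds a0..a3 at
  /// (1, 2), (1, 3), (9, 2), (9, 3) and a4..a7 at the same rows with the
  /// columns shifted by 8: (1, 10), (1, 11), (9, 10), (9, 11).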
531  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
532  auto dim = getAffineDimExpr(0, ctx);
533  AffineExpr groupID = dim.floorDiv(4);
534  AffineExpr threadIDInGroup = dim % 4;
535  // clang-format off
536  return {
537  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
538  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
539  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
540  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
541  RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
542  RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
543  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
544  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
545  };
546  // clang-format on
547  }
548 
549  /// From the NVIDIA doc:
550  /// groupID = %laneid >> 2
551  /// threadIDInGroup = %laneid % 4
552  ///
553  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
554  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
555  ///
556  /// col = groupID
557  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
558  auto dim = getAffineDimExpr(0, ctx);
559  AffineExpr groupID = dim.floorDiv(4);
560  AffineExpr threadIDInGroup = dim % 4;
561  // clang-format off
562  return {
563  RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
564  RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
565  RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
566  RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
567  };
568  // clang-format on
569  }
570 
571  /// From the NVIDIA doc:
572  /// groupID = %laneid >> 2
573  /// threadIDInGroup = %laneid % 4
574  ///
575  /// row = groupID for ci where i < 2
576  /// groupID + 8 for ci where i >= 2
577  ///
578  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
579  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
580  auto dim = getAffineDimExpr(0, ctx);
581  AffineExpr groupID = dim.floorDiv(4);
582  AffineExpr threadIDInGroup = dim % 4;
583  // clang-format off
584  return {
585  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
586  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
587  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
588  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
589  };
590  // clang-format on
591  }
592 
593  //===--------------------------------------------------------------------===//
594  /// Helper functions to create customizable load and stores operations. The
595  /// specific shapes of each MMA instruction are passed via the
596  /// IndexCalculator callback.
597  //===--------------------------------------------------------------------===//
598  /// Build a list of memref.load operations indexed at `(row, col)` indices
599  /// that make sense for a particular MMA instruction and specified via the
600  /// IndexCalculator callback.
601  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
602  OpFoldResult laneId, Value memref,
603  const IndexCalculator &indexFn);
604 
605  /// Perform a distributed load of a vector operand of `vectorShape` for a
606  /// particular MMA instruction whose `(row, col)` indices are specified via
607  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
608  /// data that makes sense for the particular MMA operation.
609  /// The `vectorShape` matches existing NVGPU dialect op specification but
610  /// could also be flattened in the future if needed for simplification.
611  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
612  OpFoldResult laneId, Value memref,
613  IndexCalculator indexFn,
614  ArrayRef<int64_t> vectorShape);
615 
616  /// Build a list of memref.store operations indexed at `(row, col)` indices
617  /// that make sense for a particular MMA instruction and specified via the
618  /// IndexCalculator callback.
619  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
620  ValueRange toStore,
621  OpFoldResult laneId, Value memref,
622  const IndexCalculator &indexFn);
623 
624  /// Perform a distributed store of a vector operand of `vectorShape` for a
625  /// particular MMA instruction whose `(row, col)` indices are specified via
626  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
627  /// data that makes sense for the particular MMA operation.
628  /// The `vectorShape` matches existing NVGPU dialect op specification but
629  /// could also be flattened in the future if needed for simplification.
630  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
631  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
632  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
633 
634  OpBuilder &b;
635  Location loc;
636  OpFoldResult laneId;
637 };
638 
639 //===--------------------------------------------------------------------===//
640 /// Helper functions to create customizable load and stores operations. The
641 /// specific shapes of each MMA instruction are passed via the
642 /// IndexCalculator callback.
643 //===--------------------------------------------------------------------===//
644 
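/// Visit every scalar element of `vector` in row-major order: `applyFn`
/// produces a value for each linearized index and `reduceFn` consumes that
/// value together with its delinearized multi-dimensional indices.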
645 template <typename ApplyFn, typename ReduceFn>
646 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
647  ReduceFn reduceFn) {
648  VectorType vectorType = cast<VectorType>(vector.getType());
649  auto vectorShape = vectorType.getShape();
650  auto strides = computeStrides(vectorShape);
651  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
652  auto indices = delinearize(idx, strides);
653  reduceFn(applyFn(vector, idx, indices), idx, indices);
654  }
655 }
656 
657 SmallVector<Value>
658 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
659  OpFoldResult laneId, Value memref,
660  const IndexCalculator &indexFn) {
661  auto aff = [&](AffineExpr e) {
662  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
663  };
664  SmallVector<Value> res;
665  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
666  for (auto indexing : indexings) {
667  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
668  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
669  auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
670  res.push_back(load);
671  }
672  return res;
673 }
674 
675 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
676  OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
677  IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
678  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
679 
680  Type elementType = getElementTypeOrSelf(memref.getType());
681  auto vt = VectorType::get(vectorShape, elementType);
682  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
683  foreachIndividualVectorElement(
684  res,
685  /*applyFn=*/
686  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
687  return loads[linearIdx];
688  },
689  /*reduceFn=*/
690  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
691  res = vector::InsertOp::create(b, loc, v, res, indices);
692  });
693 
694  return res;
695 }
696 
697 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
698  OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
699  Value memref, const IndexCalculator &indexFn) {
700  auto aff = [&](AffineExpr e) {
701  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
702  };
703  SmallVector<Operation *> res;
704  for (auto [indexing, val] :
705  llvm::zip_equal(indexFn(b.getContext()), toStore)) {
706  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
707  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
708  Operation *store =
709  memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
710  res.push_back(store);
711  }
712  return res;
713 }
714 
715 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
716  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
717  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
718  SmallVector<Value> toStore;
719  toStore.reserve(32);
720  foreachIndividualVectorElement(
721  vectorToStore,
722  /*applyFn=*/
723  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
724  return vector::ExtractOp::create(b, loc, vectorToStore, indices);
725  },
726  /*reduceFn=*/
727  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
728  toStore.push_back(v);
729  });
730  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
731 }
732 
733 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
734  SmallVector<int64_t>>
735 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
736  ArrayRef<int64_t> res) {
737  SmallVector<int64_t> vlhs(lhs);
738  SmallVector<int64_t> vrhs(rhs);
739  SmallVector<int64_t> vres(res);
740  return std::make_tuple(vlhs, vrhs, vres);
741 }
742 
743 FailureOr<MmaSyncBuilder::MmaSyncInfo>
744 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
745  TypeRange elementalTypes) {
746  // TODO: Tablegen all this.
747  Type f16 = b.getF16Type();
748  Type f32 = b.getF32Type();
749  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
750  elementalTypes == TypeRange{f32, f32, f32}) {
751  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
752  &MmaSyncBuilder::m16n8k4tf32Rhs,
753  &MmaSyncBuilder::m16n8k4tf32Res),
754  makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
755  SmallVector<int64_t>{opShape},
756  /*tf32Enabled=*/true};
757  }
758  // This is the version with f16 accumulation.
759  // TODO: version with f32 accumulation.
760  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
761  elementalTypes == TypeRange{f16, f16, f16}) {
762  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
763  &MmaSyncBuilder::m16n8k16f16Rhs,
764  &MmaSyncBuilder::m16n8k16f16Res),
765  makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
766  SmallVector<int64_t>{opShape},
767  /*tf32Enabled=*/false};
768  }
769  return failure();
770 }
771 
772 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
773  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
774  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
775  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
776  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
777  "expected lhs to be a 2D memref");
778  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
779  "expected rhs to be a 2D memref");
780  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
781  "expected res to be a 2D memref");
782 
783  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
784  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
785  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
786  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
787  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
788  Type resType = getElementTypeOrSelf(resMemRef.getType());
789 
790  FailureOr<MmaSyncInfo> maybeInfo =
791  getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
792  if (failed(maybeInfo))
793  return failure();
794 
795  MmaSyncInfo info = *maybeInfo;
796  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
797  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
798  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
799  lhsIndexFn, lhsShape);
800  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
801  rhsIndexFn, rhsShape);
802  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
803  resIndexFn, resShape);
804  res =
805  MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape, info.tf32Enabled);
806  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
807  resShape);
808  return res.getDefiningOp();
809 }
810 
811 DiagnosedSilenceableFailure RewriteMatmulAsMmaSyncOp::applyToOne(
812  TransformRewriter &rewriter, LinalgOp linalgOp,
813  ApplyToEachResultList &results, TransformState &state) {
814  bool fail = true;
815  // TODO: more robust detection of matmulOp, with transposes etc.
816  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
817  // Do not let matmul ops with extended semantics (user-defined indexing
818  // maps) through this transform.
819  if (linalgOp.hasUserDefinedMaps()) {
820  return emitSilenceableError()
821  << "only matmul ops with non-extended semantics are supported";
822  }
823  Location loc = linalgOp.getLoc();
824  // TODO: more robust computation of laneId, for now assume a single warp.
825  Value laneId = gpu::ThreadIdOp::create(
826  rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
827  if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
828  fail = false;
829  }
830 
831  if (fail) {
832  DiagnosedSilenceableFailure diag = emitSilenceableError()
833  << "unsupported target op: " << linalgOp;
834  diag.attachNote(linalgOp->getLoc()) << "target op";
835  return diag;
836  }
837 
838  rewriter.eraseOp(linalgOp);
839  return DiagnosedSilenceableFailure::success();
840 }
841 
842 //===----------------------------------------------------------------------===//
843 // Hopper builders.
844 //===----------------------------------------------------------------------===//
845 
846 /// Helper to create the base Hopper-specific operations that are reused in
847 /// various other places.
848 struct HopperBuilder {
849  HopperBuilder(RewriterBase &rewriter, Location loc)
850  : rewriter(rewriter), loc(loc) {}
851 
852  TypedValue<MBarrierGroupType>
853  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
854 
855  /// Create tma descriptor op to initiate transfer from global to shared
856  /// memory. This must be done before the launch op, on the host.
857  TypedValue<TensorMapDescriptorType>
858  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
859  gpu::LaunchOp launchOp);
860 
861  /// Build a tma load from global memory to shared memory using `barrier` to
862  /// synchronize. Return the number of bytes that will be transferred.
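  /// For example, a 64x8 f32 shared memref yields a transfer size of
  /// 64 * 8 * 4 = 2048 bytes.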
863  OpFoldResult buildTmaAsyncLoad(TypedValue<TensorMapDescriptorType> globalDesc,
864  TypedValue<MemRefType> sharedMemref,
865  TypedValue<MBarrierGroupType> barrier,
866  SmallVectorImpl<Operation *> &loadOps);
867  void buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
868  ArrayRef<OpFoldResult> sizes);
869 
870  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
871  /// Return the operation that performs the transfer on thread0.
872  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
873  SmallVector<Operation *> buildPredicateLoadsOnThread0(
874  ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
875  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
876  TypedValue<MBarrierGroupType> barrier);
877 
878  void buildTryWaitParity(TypedValue<MBarrierGroupType> barrier);
879 
880  RewriterBase &rewriter;
881  Location loc;
882 };
883 
884 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
885  ArrayRef<TypedValue<TensorMapDescriptorType>> globalDescriptors,
886  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
887  TypedValue<MBarrierGroupType> barrier) {
888  SmallVector<Operation *> loadOps;
889  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
890  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
891  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
892  tidx, zero);
893  // clang-format off
894  scf::IfOp::create(rewriter,
895  /*location=*/loc,
896  /*conditional=*/cond,
897  /*thenBuilder=*/
898  [&](OpBuilder &lb, Location loc) {
899  SmallVector<OpFoldResult> sizes;
900  sizes.reserve(globalDescriptors.size());
901  for (auto [desc, shmem] : llvm::zip_equal(
902  globalDescriptors, sharedMemBuffers)) {
903  OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
904  sizes.push_back(sz);
905  }
906  // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
907  // This may or may not have perf implications.
908  buildBarrierArriveTx(barrier, sizes);
909  scf::YieldOp::create(rewriter, loc);
910  },
911  /*elseBuilder=*/
912  [&](OpBuilder &lb, Location loc) {
913  // TODO: is this for no-thread divergence?
914  // Should we just yield the size and hoist?
915  buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
916  scf::YieldOp::create(rewriter, loc);
917  });
918  // clang-format on
919  return loadOps;
920 }
921 
922 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
923  return gpu::AddressSpaceAttr::get(
924  b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
925  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
926 }
927 
928 TypedValue<MBarrierGroupType>
929 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
930  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
931  Value barrier = MBarrierCreateOp::create(
932  rewriter, loc,
933  MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
934  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
935  nvgpu::MBarrierInitOp::create(
936  rewriter, loc, barrier,
937  getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
938  Value());
939  gpu::BarrierOp::create(rewriter, loc);
940  return cast<TypedValue<MBarrierGroupType>>(barrier);
941 }
942 
943 TypedValue<TensorMapDescriptorType>
944 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
945  gpu::LaunchOp launchOp) {
946  OpBuilder::InsertionGuard guard(rewriter);
947  rewriter.setInsertionPoint(launchOp);
948  Value unrankedMemRef = memref::CastOp::create(
949  rewriter, loc,
950  UnrankedMemRefType::get(memref.getType().getElementType(),
951  memref.getType().getMemorySpace()),
952  memref);
953  SmallVector<OpFoldResult> mixedSizes =
954  memref::getMixedSizes(rewriter, loc, memref);
955  SmallVector<Value> sizes =
956  getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
957 
958  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
959  Value desc = TmaCreateDescriptorOp::create(
960  rewriter, loc,
961  TensorMapDescriptorType::get(rewriter.getContext(),
962  MemRefType::Builder(memref.getType())
963  .setMemorySpace(sharedMemorySpace),
964  TensorMapSwizzleKind::SWIZZLE_NONE,
965  TensorMapL2PromoKind::L2PROMO_NONE,
966  TensorMapOOBKind::OOB_ZERO,
967  TensorMapInterleaveKind::INTERLEAVE_NONE),
968  unrankedMemRef, sizes);
969  return cast<TypedValue<TensorMapDescriptorType>>(desc);
970 }
971 
972 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
973  TypedValue<TensorMapDescriptorType> globalDesc,
974  TypedValue<MemRefType> sharedMemref,
975  TypedValue<MBarrierGroupType> barrier,
976  SmallVectorImpl<Operation *> &loadOps) {
977  MLIRContext *ctx = rewriter.getContext();
978  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
979  Operation *loadOp =
980  TmaAsyncLoadOp::create(rewriter, loc, sharedMemref, barrier, globalDesc,
981  ValueRange{zero, zero}, zero, Value(), Value());
982  loadOps.push_back(loadOp);
983  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
984  SmallVector<AffineExpr> symbols(mixedSizes.size());
985  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
986  AffineExpr prodExprInBytes =
987  computeProduct(ctx, symbols) *
988  (sharedMemref.getType().getElementTypeBitWidth() / 8);
989  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
990  prodExprInBytes, mixedSizes);
991  return res;
992 }
993 
994 void HopperBuilder::buildBarrierArriveTx(TypedValue<MBarrierGroupType> barrier,
995  ArrayRef<OpFoldResult> mixedSizes) {
996  assert(!mixedSizes.empty() && "expected non-empty sizes");
997  MLIRContext *ctx = rewriter.getContext();
998  SmallVector<AffineExpr> symbols(mixedSizes.size());
999  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1000  AffineExpr sumExpr = computeSum(ctx, symbols);
1001  OpFoldResult size =
1002  affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1003  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1004  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1005  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
1006  Value());
1007 }
1008 
1009 void HopperBuilder::buildTryWaitParity(TypedValue<MBarrierGroupType> barrier) {
1010  Type i1 = rewriter.getI1Type();
1011  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
1012  // 10M is an arbitrary, not too small or too big number to specify the number
1013  // of ticks before retry.
1014  // TODO: hoist this in a default dialect constant.
1015  Value ticksBeforeRetry =
1016  arith::ConstantIndexOp::create(rewriter, loc, 10000000);
1017  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1018  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
1019  ticksBeforeRetry, zero);
1020 }
1021 
1022 //===----------------------------------------------------------------------===//
1023 // RewriteCopyAsTmaOp
1024 //===----------------------------------------------------------------------===//
1025 
1026 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1027 struct CopyBuilder : public HopperBuilder {
1028  CopyBuilder(RewriterBase &rewriter, Location loc)
1029  : HopperBuilder(rewriter, loc) {}
1030 
1031  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1032 };
1033 
1034 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1035  MLIRContext *ctx = rewriter.getContext();
1036  if (copyOps.empty())
1037  return SmallVector<Operation *>();
1038 
1039  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1040  assert(launchOp && "expected launch op");
1041 
1042  // 1. Init a barrier object in shared memory.
1043  OpBuilder::InsertionGuard g(rewriter);
1044  rewriter.setInsertionPoint(copyOps.front());
1045  AffineExpr bx, by, bz;
1046  bindSymbols(ctx, bx, by, bz);
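  // numThreads below evaluates to blockDim.x * blockDim.y * blockDim.z.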
1047  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1048  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1049  rewriter, loc, prod,
1050  ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1051  launchOp.getBlockSizeZ()});
1052 
1053  TypedValue<MBarrierGroupType> barrier =
1054  buildAndInitBarrierInSharedMemory(numThreads);
1055 
1056  SmallVector<TypedValue<MemRefType>> shmems;
1057  SmallVector<TypedValue<TensorMapDescriptorType>> globalDescs;
1058  for (Operation *op : copyOps) {
1059  auto copyOp = cast<linalg::CopyOp>(op);
1060  auto inMemRef =
1061  cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1062  assert(inMemRef.getType().getRank() == 2 &&
1063  "expected in to be a 2D memref");
1064 
1065  // 2. Build global memory descriptor.
1066  TypedValue<TensorMapDescriptorType> globalDesc =
1067  buildGlobalMemRefDescriptor(inMemRef, launchOp);
1068  globalDescs.push_back(globalDesc);
1069 
1070  // 3. Shared memory and descriptor for the tmp array.
1071  auto shmem =
1072  cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1073  shmems.push_back(shmem);
1074  }
1075 
1076  // 4. Load in from global memory to shared memory using tma.
1077  OpBuilder::InsertionGuard g2(rewriter);
1078  rewriter.setInsertionPoint(copyOps.front());
1079  SmallVector<Operation *> results =
1080  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1081 
1082  // 5. Spin-loop until data is ready.
1083  buildTryWaitParity(barrier);
1084 
1085  // 6. Erase the ops that have now been rewritten.
1086  for (Operation *op : copyOps)
1087  rewriter.eraseOp(op);
1088 
1089  return results;
1090 }
1091 
1092 DiagnosedSilenceableFailure
1093 RewriteCopyAsTmaOp::apply(TransformRewriter &rewriter,
1094  TransformResults &results, TransformState &state) {
1095  auto payloadOps = state.getPayloadOps(getTarget());
1096  gpu::LaunchOp commonLaunchOp;
1097  Operation *firstOp, *failingOp;
1098  if (llvm::any_of(payloadOps, [&](Operation *op) {
1099  if (!commonLaunchOp) {
1100  commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1101  firstOp = op;
1102  }
1103  auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1104  commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1105  !isa<linalg::CopyOp>(op);
1106  if (fail)
1107  failingOp = op;
1108  return fail;
1109  })) {
1110  DiagnosedSilenceableFailure diag =
1111  emitSilenceableError()
1112  << "target ops must be linalg::CopyOp nested under a common "
1113  "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1114  "be created on the host.\nBut got: "
1115  << *firstOp << "\nand " << *failingOp;
1116  return diag;
1117  }
1118 
1119  // TODO: more robust detection of copy, with transposes etc.
1120  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1121 
1122  return DiagnosedSilenceableFailure::success();
1123 }
1124 
1125 //===----------------------------------------------------------------------===//
1126 // Transform op registration
1127 //===----------------------------------------------------------------------===//
1128 
1129 namespace {
1130 class NVGPUTransformDialectExtension
1131  : public TransformDialectExtension<NVGPUTransformDialectExtension> {
1132 public:
1133  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1134 
1135  NVGPUTransformDialectExtension() {
1136  declareGeneratedDialect<arith::ArithDialect>();
1137  declareGeneratedDialect<affine::AffineDialect>();
1138  declareGeneratedDialect<NVGPUDialect>();
1139  declareGeneratedDialect<NVVM::NVVMDialect>();
1140  declareGeneratedDialect<vector::VectorDialect>();
1141  registerTransformOps<
1142 #define GET_OP_LIST
1143 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1144  >();
1145  }
1146 };
1147 } // namespace
1148 
1149 #define GET_OP_CLASSES
1150 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1151 
1152 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1153  registry.addExtensions<NVGPUTransformDialectExtension>();
1154 }