1 //===- NVGPUTransformOps.cpp - Implementation of NVGPU transform ops ------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
10 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.h"
29 #include "mlir/IR/AffineExpr.h"
30 #include "mlir/IR/BuiltinTypes.h"
31 #include "mlir/IR/Value.h"
32 #include "llvm/ADT/ArrayRef.h"
33 
34 using namespace mlir;
35 using namespace mlir::linalg;
36 using namespace mlir::nvgpu;
37 using namespace mlir::NVVM;
38 using namespace mlir::transform;
39 
40 #define DEBUG_TYPE "nvgpu-transforms"
41 
42 //===----------------------------------------------------------------------===//
43 // Apply...ConversionPatternsOp
44 //===----------------------------------------------------------------------===//
45 
46 void transform::ApplyNVGPUToNVVMConversionPatternsOp::populatePatterns(
47  TypeConverter &typeConverter, RewritePatternSet &patterns) {
48  auto &llvmTypeConverter = static_cast<LLVMTypeConverter &>(typeConverter);
49  /// device-side async tokens cannot be materialized in nvvm. We just
50  /// convert them to a dummy i32 type in order to easily drop them during
51  /// conversion.
52  populateGpuMemorySpaceAttributeConversions(
53  llvmTypeConverter, [](gpu::AddressSpace space) -> unsigned {
54  switch (space) {
55  case gpu::AddressSpace::Global:
56  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Global);
57  case gpu::AddressSpace::Workgroup:
58  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Shared);
59  case gpu::AddressSpace::Private:
60  return 0;
61  }
62  llvm_unreachable("unknown address space enum value");
63  return static_cast<unsigned>(NVVM::NVVMMemorySpace::Generic);
64  });
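 // Illustrative note (not part of the original conversion code): assuming the
 // usual NVVM address space numbering (generic = 0, global = 1, shared = 3),
 // a memref<4xf32, #gpu.address_space<workgroup>> lowers under this mapping to
 // an LLVM pointer in addrspace(3), while a
 // memref<4xf32, #gpu.address_space<global>> lowers to addrspace(1).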
65  llvmTypeConverter.addConversion(
66  [&](nvgpu::DeviceAsyncTokenType type) -> Type {
67  return llvmTypeConverter.convertType(
68  IntegerType::get(type.getContext(), 32));
69  });
70  llvmTypeConverter.addConversion([&](nvgpu::MBarrierTokenType type) -> Type {
71  return llvmTypeConverter.convertType(
72  IntegerType::get(type.getContext(), 64));
73  });
74  llvmTypeConverter.addConversion(
75  [&](nvgpu::WarpgroupAccumulatorType type) -> Type {
76  Type elemType = type.getFragmented().getElementType();
77  int64_t sizeM = type.getFragmented().getDimSize(0);
78  int64_t sizeN = type.getFragmented().getDimSize(1);
79 
80  unsigned numMembers;
81  if (elemType.isF32() || elemType.isInteger(32))
82  numMembers = sizeN / 2;
83  else if (elemType.isF16())
84  numMembers = sizeN / 4;
85  else
86  llvm_unreachable("unsupported type for warpgroup accumulator");
87 
88  SmallVector<Type> innerStructBody;
89  for (unsigned i = 0; i < numMembers; i++)
90  innerStructBody.push_back(elemType);
91  auto innerStructType = LLVM::LLVMStructType::getLiteral(
92  type.getContext(), innerStructBody);
93 
94  SmallVector<Type> structBody;
95  for (int i = 0; i < sizeM; i += kWgmmaSizeM)
96  structBody.push_back(innerStructType);
97 
98  auto convertedType =
99  LLVM::LLVMStructType::getLiteral(type.getContext(), structBody);
100  return llvmTypeConverter.convertType(convertedType);
101  });
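 // Worked example (a sketch, assuming kWgmmaSizeM == 64 as defined in
 // NVGPUDialect.h): a warpgroup accumulator whose fragmented type is
 // vector<128x128xf32> has sizeM = 128 and sizeN = 128, so each inner struct
 // holds sizeN / 2 = 64 f32 members and the outer struct contains
 // 128 / kWgmmaSizeM = 2 such inner structs.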
102  llvmTypeConverter.addConversion([&](nvgpu::MBarrierGroupType type) -> Type {
103  return llvmTypeConverter.convertType(
104  getMBarrierMemrefType(type.getContext(), type));
105  });
106  llvmTypeConverter.addConversion(
107  [&](nvgpu::WarpgroupMatrixDescriptorType type) -> Type {
108  return llvmTypeConverter.convertType(
109  IntegerType::get(type.getContext(), 64));
110  });
111  llvmTypeConverter.addConversion(
112  [&](nvgpu::TensorMapDescriptorType type) -> Type {
113  return LLVM::LLVMPointerType::get(type.getContext());
114  });
115  populateNVGPUToNVVMConversionPatterns(llvmTypeConverter, patterns);
116 }
117 
118 LogicalResult
119 transform::ApplyNVGPUToNVVMConversionPatternsOp::verifyTypeConverter(
120  transform::TypeConverterBuilderOpInterface builder) {
121  if (builder.getTypeConverterType() != "LLVMTypeConverter")
122  return emitOpError("expected LLVMTypeConverter");
123  return success();
124 }
125 
126 //===---------------------------------------------------------------------===//
127 // CreateAsyncGroupsOp
128 //===---------------------------------------------------------------------===//
129 
130 void transform::CreateAsyncGroupsOp::getEffects(
131  SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
132  transform::consumesHandle(getTargetMutable(), effects);
133  transform::producesHandle(getOperation()->getOpResults(), effects);
134  transform::modifiesPayload(effects);
135 }
136 
137 DiagnosedSilenceableFailure transform::CreateAsyncGroupsOp::applyToOne(
138  TransformRewriter &rewriter, Operation *target,
139  ApplyToEachResultList &results, TransformState &state) {
140  nvgpu::createAsyncGroups(rewriter, target, getBypassL1());
141  results.push_back(target);
142  return DiagnosedSilenceableFailure::success();
143 }
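// Schematic of the rewrite performed by nvgpu::createAsyncGroups (which
// converts global->shared vector transfers to async device copies): pairs of
//   %v = vector.transfer_read %global[...]   (default memory space)
//   vector.transfer_write %v, %shared[...]   (workgroup memory space)
// become nvgpu.device_async_copy operations, grouped by
// nvgpu.device_async_create_group and synchronized with
// nvgpu.device_async_wait. This is a simplified sketch, not exact IR.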
144 
145 //===----------------------------------------------------------------------===//
146 // PipelineSharedMemoryCopiesOp
147 //===----------------------------------------------------------------------===//
148 
149 /// Returns true if the given type has the default memory space.
150 static bool hasDefaultMemorySpace(BaseMemRefType type) {
151  return !type.getMemorySpace() || type.getMemorySpaceAsInt() == 0;
152 }
153 
154 /// Returns true if the given type has the shared (workgroup) memory space.
155 static bool hasSharedMemorySpace(BaseMemRefType type) {
156  auto space =
157  dyn_cast_if_present<gpu::AddressSpaceAttr>(type.getMemorySpace());
158  return space &&
159  space.getValue() == gpu::GPUDialect::getWorkgroupAddressSpace();
160 }
161 
162 /// Returns the value produced by a load from the default memory space. Returns
163 /// null if the operation is not such a load.
164 static Value getValueLoadedFromGlobal(Operation *op) {
165  // TODO: consider an interface or leveraging the memory effects interface.
166  auto load = dyn_cast<vector::TransferReadOp>(op);
167  if (!load)
168  return nullptr;
169 
170  auto loadType = dyn_cast<MemRefType>(load.getBase().getType());
171  if (!loadType || !hasDefaultMemorySpace(loadType))
172  return nullptr;
173  return load;
174 }
175 
176 /// Returns true if the operation is storing the given value into shared memory.
177 static bool isStoreToShared(Operation *op, Value v) {
178  // TODO: consider an interface or leveraging the memory effects interface.
179  auto store = dyn_cast<vector::TransferWriteOp>(op);
180  if (!store || store.getVector() != v)
181  return false;
182 
183  auto storeType = dyn_cast<MemRefType>(store.getBase().getType());
184  return storeType && hasSharedMemorySpace(storeType);
185 }
186 
187 /// Returns true if the operation is a load from the default memory space the
188 /// result of which is only stored into the shared memory space.
189 static bool isLoadFromGlobalStoredToShared(Operation *op) {
190  Value loaded = getValueLoadedFromGlobal(op);
191  if (!loaded || !loaded.hasOneUse())
192  return false;
193 
194  return isStoreToShared(*loaded.getUsers().begin(), loaded);
195 }
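// Example of the pattern matched above (schematic):
//   %v = vector.transfer_read %global[...] : memref<...>, vector<...>
//   vector.transfer_write %v, %shared[...] : vector<...>, memref<..., #gpu.address_space<workgroup>>
// where %v has exactly one use. Such a pair is treated as a stage-0 shared
// memory copy by the pipelining logic below.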
196 
197 /// Populate `ops` with the set of operations that belong to the stage 0 of the
198 /// pipelined version of the given loop when pipelining copies to shared memory.
199 /// Specifically, this collects:
200 ///
201 /// 1. all loads from global memory, both sync and async;
202 /// 2. the barriers for async loads.
203 ///
204 /// In particular, barriers are omitted if they do not dominate at least one
205 /// async load for which there is not yet a barrier.
206 static LogicalResult
207 collectStage0PipeliningOps(scf::ForOp forOp,
208  llvm::SmallPtrSet<Operation *, 16> &ops) {
209 
210  llvm::SmallPtrSet<Operation *, 16> barriers;
211  for (Operation &op : *forOp.getBody()) {
212  // Bail on nested ops for now.
213  if (op.getNumRegions() > 0)
214  return failure();
215 
216  if (isa<gpu::BarrierOp>(op)) {
217  barriers.insert(&op);
218  continue;
219  }
220 
221  if (isa<nvgpu::DeviceAsyncCopyOp, nvgpu::DeviceAsyncCreateGroupOp>(op)) {
222  ops.insert(&op);
223  ops.insert(std::make_move_iterator(barriers.begin()),
224  std::make_move_iterator(barriers.end()));
225  assert(barriers.empty() &&
226  "expected to have moved the barriers into another set");
227  continue;
228  }
229 
230  if (isLoadFromGlobalStoredToShared(&op)) {
231  ops.insert(&op);
232  continue;
233  }
234  }
235 
236  return success();
237 }
238 
239 /// Hook for the loop pipeliner that sets the "num groups in flight" attribute
240 /// of async wait operations corresponding to pipelined shared memory copies.
241 // TODO: this currently assumes that there are no groups that could be in flight
242 // in the existing code.
243 static void
244 setAsyncWaitGroupsInFlight(OpBuilder &builder, Operation *op,
245  scf::PipeliningOption::PipelinerPart part,
246  unsigned iteration, unsigned depth) {
247  // Based on the order of copies within the loop we need to set the number
248  // of copies in flight, unless it is already set.
249  auto waitOp = dyn_cast<nvgpu::DeviceAsyncWaitOp>(op);
250  if (!waitOp || waitOp.getNumGroups())
251  return;
252 
253  int numGroupInFlight = 0;
254  if (part == scf::PipeliningOption::PipelinerPart::Kernel ||
255  part == scf::PipeliningOption::PipelinerPart::Prologue) {
256  numGroupInFlight = depth - 1;
257  } else {
258  // By construction there should be no wait op in the prologue as all the
259  // wait should be in the last stage.
260  assert(part == scf::PipeliningOption::PipelinerPart::Epilogue);
261  // Based on the schedule we pick we know how many groups are in flight for
262  // each iteration of the epilogue.
263  numGroupInFlight = depth - 1 - iteration;
264  }
265  waitOp.setNumGroups(numGroupInFlight);
266 }
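// Worked example: with depth = 3, wait ops in the kernel (and, by
// construction, none in the prologue) get numGroups = depth - 1 = 2, while the
// wait op in epilogue iteration i gets numGroups = 2 - i, i.e. 2, 1, 0.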
267 
268 /// Hook for the loop pipeliner that populates `ops` with the stage information
269 /// as follows:
270 ///
271 /// - operations in `stage0Ops` (typically loads from global memory and
272 /// related barriers) are at stage 0;
273 /// - operations in the backward slice of any stage0Ops are all at stage 0;
274 /// - other operations are at stage `depth`;
275 /// - the internal order of the pipelined loop has ops at stage `depth` first,
276 /// then those at stage 0, with relative order within each group preserved.
277 ///
278 static void getPipelineStages(
279  scf::ForOp forOp,
280  std::vector<std::pair<Operation *, unsigned>> &opsWithPipelineStages,
281  unsigned depth, llvm::SmallPtrSetImpl<Operation *> &stage0Ops) {
282  SetVector<Operation *> dependencies;
283  BackwardSliceOptions options([&](Operation *visited) {
284  return visited->getBlock() == forOp.getBody();
285  });
286  options.inclusive = true;
287  for (Operation &op : forOp.getBody()->getOperations()) {
288  if (stage0Ops.contains(&op)) {
289  LogicalResult result = getBackwardSlice(&op, &dependencies, options);
290  assert(result.succeeded() && "expected a backward slice");
291  (void)result;
292  }
293  }
294 
295  for (Operation &op : forOp.getBody()->getOperations()) {
296  if (!dependencies.contains(&op) && !isa<scf::YieldOp>(op))
297  opsWithPipelineStages.emplace_back(&op, depth);
298  }
299  for (Operation &op : forOp.getBody()->getOperations()) {
300  if (dependencies.contains(&op))
301  opsWithPipelineStages.emplace_back(&op, 0);
302  }
303 }
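// Illustrative example: for a loop body containing {device_async_copy,
// barrier, mma, yield} with depth = 3, the copy and the barrier (plus their
// backward slices inside the loop body) are assigned stage 0, the mma is
// assigned stage 3, and the stage-3 ops are listed before the stage-0 ops in
// the emitted schedule.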
304 
305 /// Hook for the loop pipeliner. Replaces op with a predicated version and
306 /// returns the resulting operation. Returns the original op if the predication
307 /// isn't necessary for the given op. Returns null if predication is needed but
308 /// not supported.
309 static Operation *replaceOpWithPredicatedOp(RewriterBase &rewriter,
310  Operation *op, Value predicate) {
311  // Some operations may be fine to execute "speculatively" more times than the
312  // original number of iterations, in particular side-effect free operations
313  // and barriers, even if they cannot be predicated.
314  if (isMemoryEffectFree(op) ||
315  isa<gpu::BarrierOp, nvgpu::DeviceAsyncCreateGroupOp,
316  nvgpu::DeviceAsyncWaitOp>(op)) {
317  return op;
318  }
319 
320  // Otherwise, only async copies can currently be predicated.
321  auto asyncCopyOp = dyn_cast<nvgpu::DeviceAsyncCopyOp>(op);
322  if (!asyncCopyOp)
323  return nullptr;
324 
325  // Create srcElement Value based on `predicate`. The next lines generate
326  // the following code:
327  //
328  // srcElement = (pred) ? prevSrcElements : 0;
329  //
330  Location loc = asyncCopyOp->getLoc();
331  Value dstElements = arith::ConstantOp::create(
332  rewriter, loc, asyncCopyOp.getDstElementsAttr());
333  Value originalSrcElement =
334  asyncCopyOp.getSrcElements() ? asyncCopyOp.getSrcElements() : dstElements;
335  Value c0Index = arith::ConstantIndexOp::create(rewriter, loc, 0);
336  auto srcElements = arith::SelectOp::create(rewriter, loc, predicate,
337  originalSrcElement, c0Index);
338  auto asyncCopyZeroFillOp = nvgpu::DeviceAsyncCopyOp::create(
339  rewriter, loc, nvgpu::DeviceAsyncTokenType::get(asyncCopyOp.getContext()),
340  asyncCopyOp.getDst(), asyncCopyOp.getDstIndices(), asyncCopyOp.getSrc(),
341  asyncCopyOp.getSrcIndices(), asyncCopyOp.getDstElements(), srcElements,
342  UnitAttr());
343  rewriter.replaceOp(asyncCopyOp, asyncCopyZeroFillOp);
344  return asyncCopyZeroFillOp;
345 }
346 
347 /// Applies loop pipelining with the given depth to the given loop so that
348 /// copies into the shared memory are pipelined. Doesn't affect other loops.
349 /// Returns a pair containing the error state and the pipelined op, the latter
350 /// being null in case of any failure. The error state contains a definite error
351 /// if the IR has been modified and a silenceable error otherwise.
352 static std::tuple<DiagnosedSilenceableFailure, scf::ForOp>
353 pipelineForSharedCopies(RewriterBase &rewriter, scf::ForOp forOp, int64_t depth,
354  bool epiloguePeeling) {
355  llvm::SmallPtrSet<Operation *, 16> stage0Ops;
356  if (failed(collectStage0PipeliningOps(forOp, stage0Ops))) {
357  return std::make_tuple(
358  emitSilenceableFailure(forOp, "cannot find stage 0 ops for pipelining"),
359  scf::ForOp());
360  }
361  if (stage0Ops.empty()) {
362  return std::make_tuple(
363  emitSilenceableFailure(forOp, "no shared memory copy"), scf::ForOp());
364  }
365 
366  scf::PipeliningOption options;
367  unsigned maxDepth = depth;
368  auto setAnnotation = [&](Operation *op,
369  scf::PipeliningOption::PipelinerPart part,
370  unsigned iteration) {
371  return setAsyncWaitGroupsInFlight(rewriter, op, part, iteration, maxDepth);
372  };
373  options.getScheduleFn =
374  [&](scf::ForOp schedulingFor,
375  std::vector<std::pair<Operation *, unsigned>> &ops) {
376  if (schedulingFor != forOp)
377  return;
378  return getPipelineStages(forOp, ops, maxDepth, stage0Ops);
379  };
380  options.annotateFn = setAnnotation;
381  if (!epiloguePeeling) {
382  options.peelEpilogue = false;
383  options.predicateFn = replaceOpWithPredicatedOp;
384  }
385 
386  OpBuilder::InsertionGuard guard(rewriter);
387  rewriter.setInsertionPoint(forOp);
388  bool modifiedIR;
389  FailureOr<scf::ForOp> maybePipelined =
390  pipelineForLoop(rewriter, forOp, options, &modifiedIR);
391  if (succeeded(maybePipelined)) {
392  return std::make_tuple(DiagnosedSilenceableFailure::success(),
393  *maybePipelined);
394  }
395  return std::make_tuple(
396  modifiedIR
397  ? emitDefiniteFailure(forOp, "irreversible pipelining failure")
398  : emitSilenceableFailure(forOp, "pipelining preconditions failed"),
399  scf::ForOp());
400 }
401 
402 DiagnosedSilenceableFailure PipelineSharedMemoryCopiesOp::applyToOne(
403  TransformRewriter &rewriter, scf::ForOp forOp,
404  ApplyToEachResultList &results, TransformState &state) {
405  auto [diag, pipelined] = pipelineForSharedCopies(
406  rewriter, forOp, static_cast<int64_t>(getDepth()), getPeelEpilogue());
407  if (diag.succeeded()) {
408  results.push_back(pipelined);
409  return DiagnosedSilenceableFailure::success();
410  }
411  if (diag.isDefiniteFailure()) {
412  auto diag = emitDefiniteFailure("irreversible pipelining failure");
413  if (!getPeelEpilogue()) {
414  diag.attachNote(forOp->getLoc()) << "couldn't predicate?";
415  diag.attachNote(getLoc()) << "try setting " << getPeelEpilogueAttrName();
416  }
417  return diag;
418  }
419 
420  return std::move(diag);
421 }
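// Illustrative transform-IR usage (a sketch; the mnemonic and attribute names
// are assumed from the op class name and its accessors and may differ from the
// exact TableGen definition):
//   %pipelined = transform.nvgpu.pipeline_shared_memory_copies %for {depth = 3}
//       : (!transform.any_op) -> !transform.any_op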
422 
423 //===----------------------------------------------------------------------===//
424 // RewriteMatmulAsMmaSyncOp
425 //===----------------------------------------------------------------------===//
426 
427 /// Helper struct to encode a pair of row/column indexings in the form of
428 /// affine expressions.
429 struct RowColIndexing : private std::pair<AffineExpr, AffineExpr> {
430  RowColIndexing(AffineExpr row, AffineExpr col)
431  : std::pair<AffineExpr, AffineExpr>(row, col) {}
432 
433  AffineExpr row() const { return first; };
434  AffineExpr col() const { return second; };
435 
436  void print(llvm::raw_ostream &os) const {
437  os << "- indexing: " << first << ", " << second;
438  }
439 };
440 
441 /// Helper struct to provide a simple mapping from matmul operations to the
442 /// corresponding mma.sync operation. This is constrained to the case where the
443 /// matmul matches the mma.sync operation 1-1.
444 struct MmaSyncBuilder {
445  MmaSyncBuilder(OpBuilder &b, Location loc, OpFoldResult laneId)
446  : b(b), loc(loc), laneId(laneId) {}
447 
448  using IndexCalculator =
449  std::function<SmallVector<RowColIndexing>(MLIRContext *)>;
450 
451  /// Create the mma.sync operation corresponding to `linalgOp` along with all
452  /// the supporting load/store and vector operations.
453  FailureOr<Operation *> buildMmaSync(LinalgOp linalgOp);
454 
455 private:
456  struct MmaSyncInfo {
457  std::tuple<IndexCalculator, IndexCalculator, IndexCalculator> indexFns;
458  std::tuple<SmallVector<int64_t>, SmallVector<int64_t>, SmallVector<int64_t>>
459  vectorShapes;
460  SmallVector<int64_t> mmaShape;
461  bool tf32Enabled;
462  };
463 
464  /// Return the specific index calculator for the given `linalgOp` or failure
465  /// if the op is not supported. This is the toplevel switch that should just
466  /// be Tablegen'd in the future.
467  FailureOr<MmaSyncInfo> getIndexCalculators(ArrayRef<int64_t> opShape,
468  TypeRange elementalTypes);
469 
470  //===--------------------------------------------------------------------===//
471  // Instruction-specific row, column indexing expression builders.
472  // These should all be declaratively specified via Tablegen in the future.
473  // The Tablegen specification should be as straightforward as possible to
474  // only model the existing size and type combinations.
475  //===--------------------------------------------------------------------===//
476  //
477  // TODO: Tablegen all this.
478  //===--------------------------------------------------------------------===//
479  // m16n8k4 tf32 case.
480  //===--------------------------------------------------------------------===//
481  /// From the NVIDIA doc:
482  /// groupID = %laneid >> 2
483  /// threadIDInGroup = %laneid % 4
484  /// row = groupID for a0
485  /// groupID + 8 for a1
486  /// col = threadIDInGroup
487  static SmallVector<RowColIndexing> m16n8k4tf32Lhs(MLIRContext *ctx) {
488  auto dim = getAffineDimExpr(0, ctx);
489  AffineExpr groupID = dim.floorDiv(4);
490  AffineExpr threadIDInGroup = dim % 4;
491  return {RowColIndexing{groupID, threadIDInGroup},
492  RowColIndexing{groupID + 8, threadIDInGroup}};
493  }
494 
495  /// From the NVIDIA doc:
496  /// groupID = %laneid >> 2
497  /// threadIDInGroup = %laneid % 4
498  /// row = threadIDInGroup
499  /// col = groupID
500  static SmallVector<RowColIndexing> m16n8k4tf32Rhs(MLIRContext *ctx) {
501  auto dim = getAffineDimExpr(0, ctx);
502  AffineExpr groupID = dim.floorDiv(4);
503  AffineExpr threadIDInGroup = dim % 4;
504  return {RowColIndexing{threadIDInGroup, groupID}};
505  }
506 
507  /// From the NVIDIA doc:
508  /// groupID = %laneid >> 2
509  /// threadIDInGroup = %laneid % 4
510  /// row = groupID for c0 and c1
511  /// groupID + 8 for c2 and c3
512  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
513  static SmallVector<RowColIndexing> m16n8k4tf32Res(MLIRContext *ctx) {
514  auto dim = getAffineDimExpr(0, ctx);
515  AffineExpr groupID = dim.floorDiv(4);
516  AffineExpr threadIDInGroup = dim % 4;
517  return {RowColIndexing{groupID, threadIDInGroup * 2 + 0},
518  RowColIndexing{groupID, threadIDInGroup * 2 + 1},
519  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0},
520  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}};
521  }
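 // Worked example for the three m16n8k4 index calculators above, at
 // %laneid = 5: groupID = 5 floordiv 4 = 1 and threadIDInGroup = 5 mod 4 = 1,
 // so the lane holds a0 at (1, 1) and a1 at (9, 1) of the LHS, b0 at (1, 1)
 // of the RHS, and c0..c3 at (1, 2), (1, 3), (9, 2), (9, 3) of the result.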
522 
523  //===--------------------------------------------------------------------===//
524  // m16n8k16 f16 case.
525  //===--------------------------------------------------------------------===//
526  /// From the NVIDIA doc:
527  /// groupID = %laneid >> 2
528  /// threadIDInGroup = %laneid % 4
529  ///
530  /// row = groupID for ai where 0 <= i < 2 || 4 <= i < 6
531  /// groupID + 8 Otherwise
532  ///
533  /// col = (threadIDInGroup * 2) + (i & 0x1) for ai where i < 4
534  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for ai where i >= 4
535  static SmallVector<RowColIndexing> m16n8k16f16Lhs(MLIRContext *ctx) {
536  auto dim = getAffineDimExpr(0, ctx);
537  AffineExpr groupID = dim.floorDiv(4);
538  AffineExpr threadIDInGroup = dim % 4;
539  // clang-format off
540  return {
541  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
542  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
543  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
544  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1}, // i == 3
545  RowColIndexing{groupID, threadIDInGroup * 2 + 0 + 8}, // i == 4
546  RowColIndexing{groupID, threadIDInGroup * 2 + 1 + 8}, // i == 5
547  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0 + 8}, // i == 6
548  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1 + 8} // i == 7
549  };
550  // clang-format on
551  }
552 
553  /// From the NVIDIA doc:
554  /// groupID = %laneid >> 2
555  /// threadIDInGroup = %laneid % 4
556  ///
557  /// row = (threadIDInGroup * 2) + (i & 0x1) for bi where i < 2
558  /// (threadIDInGroup * 2) + (i & 0x1) + 8 for bi where i >= 2
559  ///
560  /// col = groupID
561  static SmallVector<RowColIndexing> m16n8k16f16Rhs(MLIRContext *ctx) {
562  auto dim = getAffineDimExpr(0, ctx);
563  AffineExpr groupID = dim.floorDiv(4);
564  AffineExpr threadIDInGroup = dim % 4;
565  // clang-format off
566  return {
567  RowColIndexing{threadIDInGroup * 2 + 0, groupID}, // i == 0
568  RowColIndexing{threadIDInGroup * 2 + 1, groupID}, // i == 1
569  RowColIndexing{threadIDInGroup * 2 + 0 + 8, groupID}, // i == 2
570  RowColIndexing{threadIDInGroup * 2 + 1 + 8, groupID} // i == 3
571  };
572  // clang-format on
573  }
574 
575  /// From the NVIDIA doc:
576  /// groupID = %laneid >> 2
577  /// threadIDInGroup = %laneid % 4
578  ///
579  /// row = groupID for ci where i < 2
580  /// groupID + 8 for ci where i >= 2
581  ///
582  /// col = (threadIDInGroup * 2) + (i & 0x1) for ci where i = {0,..,3}
583  static SmallVector<RowColIndexing> m16n8k16f16Res(MLIRContext *ctx) {
584  auto dim = getAffineDimExpr(0, ctx);
585  AffineExpr groupID = dim.floorDiv(4);
586  AffineExpr threadIDInGroup = dim % 4;
587  // clang-format off
588  return {
589  RowColIndexing{groupID, threadIDInGroup * 2 + 0}, // i == 0
590  RowColIndexing{groupID, threadIDInGroup * 2 + 1}, // i == 1
591  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 0}, // i == 2
592  RowColIndexing{groupID + 8, threadIDInGroup * 2 + 1} // i == 3
593  };
594  // clang-format on
595  }
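 // Worked example at %laneid = 5 (groupID = 1, threadIDInGroup = 1) for the
 // m16n8k16 f16 calculators above: the LHS fragments a0..a7 sit at
 // (1,2), (1,3), (9,2), (9,3), (1,10), (1,11), (9,10), (9,11); the RHS
 // fragments b0..b3 sit at (2,1), (3,1), (10,1), (11,1); and the result
 // fragments c0..c3 sit at (1,2), (1,3), (9,2), (9,3).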
596 
597  //===--------------------------------------------------------------------===//
598  /// Helper functions to create customizable load and stores operations. The
599  /// specific shapes of each MMA instruction are passed via the
600  /// IndexCalculator callback.
601  //===--------------------------------------------------------------------===//
602  /// Build a list of memref.load operations indexed at `(row, col)` indices
603  /// that make sense for a particular MMA instruction and specified via the
604  /// IndexCalculator callback.
605  SmallVector<Value> buildMemRefLoads(OpBuilder &b, Location loc,
606  OpFoldResult laneId, Value memref,
607  const IndexCalculator &indexFn);
608 
609  /// Perform a distributed load of a vector operand of `vectorShape` for a
610  /// particular MMA instruction whose `(row, col)` indices are specified via
611  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
612  /// data that makes sense for the particular MMA operation.
613  /// The `vectorShape` matches existing NVGPU dialect op specification but
614  /// could also be flattened in the future if needed for simplification.
615  Value buildMmaSyncMemRefLoadOperand(OpBuilder &b, Location loc,
616  OpFoldResult laneId, Value memref,
617  IndexCalculator indexFn,
618  ArrayRef<int64_t> vectorShape);
619 
620  /// Build a list of memref.store operations indexed at `(row, col)` indices
621  /// that make sense for a particular MMA instruction and specified via the
622  /// IndexCalculator callback.
623  SmallVector<Operation *> buildMemRefStores(OpBuilder &b, Location loc,
624  ValueRange toStore,
625  OpFoldResult laneId, Value memref,
626  const IndexCalculator &indexFn);
627 
628  /// Perform a distributed store of a vector operand of `vectorShape` for a
629  /// particular MMA instruction whose `(row, col)` indices are specified via
630  /// the IndexCalculator callback. Each `laneId` loads the subportion of the
631  /// data that makes sense for the particular MMA operation.
632  /// The `vectorShape` matches existing NVGPU dialect op specification but
633  /// could also be flattened in the future if needed for simplification.
634  SmallVector<Operation *> buildMmaSyncMemRefStoreOperand(
635  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
636  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape);
637 
638  OpBuilder &b;
639  Location loc;
640  OpFoldResult laneId;
641 };
642 
643 //===--------------------------------------------------------------------===//
644 /// Helper functions to create customizable load and stores operations. The
645 /// specific shapes of each MMA instruction are passed via the
646 /// IndexCalculator callback.
647 //===--------------------------------------------------------------------===//
648 
649 template <typename ApplyFn, typename ReduceFn>
650 static void foreachIndividualVectorElement(Value vector, ApplyFn applyFn,
651  ReduceFn reduceFn) {
652  VectorType vectorType = cast<VectorType>(vector.getType());
653  auto vectorShape = vectorType.getShape();
654  auto strides = computeStrides(vectorShape);
655  for (int64_t idx = 0, e = vectorShape[0] * strides[0]; idx < e; ++idx) {
656  auto indices = delinearize(idx, strides);
657  reduceFn(applyFn(vector, idx, indices), idx, indices);
658  }
659 }
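// Example: for a value of type vector<2x2xf16>, vectorShape = {2, 2} and
// strides = {2, 1}, so the helper visits linear indices 0..3, delinearized to
// (0,0), (0,1), (1,0), (1,1), applying applyFn and feeding its result to
// reduceFn at each position.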
660 
661 SmallVector<Value>
662 MmaSyncBuilder::buildMemRefLoads(OpBuilder &b, Location loc,
663  OpFoldResult laneId, Value memref,
664  const IndexCalculator &indexFn) {
665  auto aff = [&](AffineExpr e) {
666  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
667  };
668  SmallVector<Value> res;
669  SmallVector<RowColIndexing> indexings = indexFn(b.getContext());
670  for (auto indexing : indexings) {
671  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
672  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
673  auto load = memref::LoadOp::create(b, loc, memref, ValueRange{row, col});
674  res.push_back(load);
675  }
676  return res;
677 }
678 
679 Value MmaSyncBuilder::buildMmaSyncMemRefLoadOperand(
680  OpBuilder &b, Location loc, OpFoldResult laneId, Value memref,
681  IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
682  auto loads = buildMemRefLoads(b, loc, laneId, memref, std::move(indexFn));
683 
684  Type elementType = getElementTypeOrSelf(memref.getType());
685  auto vt = VectorType::get(vectorShape, elementType);
686  Value res = vector::BroadcastOp::create(b, loc, vt, loads[0]);
687  foreachIndividualVectorElement(
688  res,
689  /*applyFn=*/
690  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
691  return loads[linearIdx];
692  },
693  /*reduceFn=*/
694  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
695  res = vector::InsertOp::create(b, loc, v, res, indices);
696  });
697 
698  return res;
699 }
700 
701 SmallVector<Operation *> MmaSyncBuilder::buildMemRefStores(
702  OpBuilder &b, Location loc, ValueRange toStore, OpFoldResult laneId,
703  Value memref, const IndexCalculator &indexFn) {
704  auto aff = [&](AffineExpr e) {
705  return affine::makeComposedFoldedAffineApply(b, loc, e, laneId);
706  };
707  SmallVector<Operation *> res;
708  for (auto [indexing, val] :
709  llvm::zip_equal(indexFn(b.getContext()), toStore)) {
710  Value row = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.row()));
711  Value col = getValueOrCreateConstantIndexOp(b, loc, aff(indexing.col()));
712  Operation *store =
713  memref::StoreOp::create(b, loc, val, memref, ValueRange{row, col});
714  res.push_back(store);
715  }
716  return res;
717 }
718 
719 SmallVector<Operation *> MmaSyncBuilder::buildMmaSyncMemRefStoreOperand(
720  OpBuilder &b, Location loc, Value vectorToStore, OpFoldResult laneId,
721  Value memref, IndexCalculator indexFn, ArrayRef<int64_t> vectorShape) {
722  SmallVector<Value> toStore;
723  toStore.reserve(32);
724  foreachIndividualVectorElement(
725  vectorToStore,
726  /*applyFn=*/
727  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
728  return vector::ExtractOp::create(b, loc, vectorToStore, indices);
729  },
730  /*reduceFn=*/
731  [&](Value v, int64_t linearIdx, ArrayRef<int64_t> indices) {
732  toStore.push_back(v);
733  });
734  return buildMemRefStores(b, loc, toStore, laneId, memref, std::move(indexFn));
735 }
736 
737 static std::tuple<SmallVector<int64_t>, SmallVector<int64_t>,
738  SmallVector<int64_t>>
739 makeVectorShapes(ArrayRef<int64_t> lhs, ArrayRef<int64_t> rhs,
740  ArrayRef<int64_t> res) {
741  SmallVector<int64_t> vlhs(lhs);
742  SmallVector<int64_t> vrhs(rhs);
743  SmallVector<int64_t> vres(res);
744  return std::make_tuple(vlhs, vrhs, vres);
745 }
746 
747 FailureOr<MmaSyncBuilder::MmaSyncInfo>
748 MmaSyncBuilder::getIndexCalculators(ArrayRef<int64_t> opShape,
749  TypeRange elementalTypes) {
750  // TODO: Tablegen all this.
751  Type f16 = b.getF16Type();
752  Type f32 = b.getF32Type();
753  if (opShape == ArrayRef<int64_t>{16, 8, 4} &&
754  elementalTypes == TypeRange{f32, f32, f32}) {
755  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k4tf32Lhs,
756  &MmaSyncBuilder::m16n8k4tf32Rhs,
757  &MmaSyncBuilder::m16n8k4tf32Res),
758  makeVectorShapes({2, 1}, {1, 1}, {2, 2}),
759  SmallVector<int64_t>{opShape},
760  /*tf32Enabled=*/true};
761  }
762  // This is the version with f16 accumulation.
763  // TODO: version with f32 accumulation.
764  if (opShape == ArrayRef<int64_t>{16, 8, 16} &&
765  elementalTypes == TypeRange{f16, f16, f16}) {
766  return MmaSyncInfo{std::make_tuple(&MmaSyncBuilder::m16n8k16f16Lhs,
767  &MmaSyncBuilder::m16n8k16f16Rhs,
768  &MmaSyncBuilder::m16n8k16f16Res),
769  makeVectorShapes({4, 2}, {2, 2}, {2, 2}),
770  SmallVector<int64_t>{opShape},
771  /*tf32Enabled=*/false};
772  }
773  return failure();
774 }
775 
776 FailureOr<Operation *> MmaSyncBuilder::buildMmaSync(LinalgOp linalgOp) {
777  Value lhsMemRef = linalgOp.getDpsInputOperand(0)->get();
778  Value rhsMemRef = linalgOp.getDpsInputOperand(1)->get();
779  Value resMemRef = linalgOp.getDpsInitOperand(0)->get();
780  assert(cast<MemRefType>(lhsMemRef.getType()).getRank() == 2 &&
781  "expected lhs to be a 2D memref");
782  assert(cast<MemRefType>(rhsMemRef.getType()).getRank() == 2 &&
783  "expected rhs to be a 2D memref");
784  assert(cast<MemRefType>(resMemRef.getType()).getRank() == 2 &&
785  "expected res to be a 2D memref");
786 
787  int64_t m = cast<MemRefType>(lhsMemRef.getType()).getShape()[0];
788  int64_t n = cast<MemRefType>(rhsMemRef.getType()).getShape()[1];
789  int64_t k = cast<MemRefType>(lhsMemRef.getType()).getShape()[1];
790  Type lhsType = getElementTypeOrSelf(lhsMemRef.getType());
791  Type rhsType = getElementTypeOrSelf(rhsMemRef.getType());
792  Type resType = getElementTypeOrSelf(resMemRef.getType());
793 
794  FailureOr<MmaSyncInfo> maybeInfo =
795  getIndexCalculators({m, n, k}, {lhsType, rhsType, resType});
796  if (failed(maybeInfo))
797  return failure();
798 
799  MmaSyncInfo info = *maybeInfo;
800  auto [lhsIndexFn, rhsIndexFn, resIndexFn] = info.indexFns;
801  auto [lhsShape, rhsShape, resShape] = info.vectorShapes;
802  Value lhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, lhsMemRef,
803  lhsIndexFn, lhsShape);
804  Value rhs = buildMmaSyncMemRefLoadOperand(b, loc, laneId, rhsMemRef,
805  rhsIndexFn, rhsShape);
806  Value res = buildMmaSyncMemRefLoadOperand(b, loc, laneId, resMemRef,
807  resIndexFn, resShape);
808  res = nvgpu::MmaSyncOp::create(b, loc, lhs, rhs, res, info.mmaShape,
809  info.tf32Enabled);
810  buildMmaSyncMemRefStoreOperand(b, loc, res, laneId, resMemRef, resIndexFn,
811  resShape);
812  return res.getDefiningOp();
813 }
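// Example: a linalg.matmul with operands memref<16x4xf32>, memref<4x8xf32>
// and init memref<16x8xf32> yields (m, n, k) = (16, 8, 4) with all-f32
// element types, which selects the m16n8k4 tf32 path above; an f16 matmul on
// memref<16x16xf16>, memref<16x8xf16>, memref<16x8xf16> yields
// (16, 8, 16) and selects the m16n8k16 f16 path.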
814 
815 DiagnosedSilenceableFailure transform::RewriteMatmulAsMmaSyncOp::applyToOne(
816  transform::TransformRewriter &rewriter, LinalgOp linalgOp,
817  transform::ApplyToEachResultList &results,
818  transform::TransformState &state) {
819  bool fail = true;
820  // TODO: more robust detection of matmulOp, with transposes etc.
821  if (isa_and_nonnull<linalg::MatmulOp>(linalgOp.getOperation())) {
822  // Check to not let go the matmul with extended semantic, through this
823  // transform.
824  if (linalgOp.hasUserDefinedMaps()) {
825  return emitSilenceableError()
826  << "only matmul ops with non-extended semantics are supported";
827  }
828  Location loc = linalgOp.getLoc();
829  // TODO: more robust computation of laneId, for now assume a single warp.
830  Value laneId = gpu::ThreadIdOp::create(
831  rewriter, loc, rewriter.getIndexType(), gpu::Dimension::x);
832  if (succeeded(MmaSyncBuilder(rewriter, loc, laneId).buildMmaSync(linalgOp)))
833  fail = false;
834  }
835 
836  if (fail) {
837  DiagnosedSilenceableFailure diag = emitSilenceableError()
838  << "unsupported target op: " << linalgOp;
839  diag.attachNote(linalgOp->getLoc()) << "target op";
840  return diag;
841  }
842 
843  rewriter.eraseOp(linalgOp);
844  return DiagnosedSilenceableFailure::success();
845 }
846 
847 //===----------------------------------------------------------------------===//
848 // Hopper builders.
849 //===----------------------------------------------------------------------===//
850 
851 /// Helper to create the base Hopper-specific operations that are reused in
852 /// various other places.
853 struct HopperBuilder {
854  HopperBuilder(RewriterBase &rewriter, Location loc)
855  : rewriter(rewriter), loc(loc) {}
856 
857  TypedValue<nvgpu::MBarrierGroupType>
858  buildAndInitBarrierInSharedMemory(OpFoldResult numThreads);
859 
860  /// Create tma descriptor op to initiate transfer from global to shared
861  /// memory. This must be done before the launch op, on the host.
862  TypedValue<nvgpu::TensorMapDescriptorType>
863  buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
864  gpu::LaunchOp launchOp);
865 
866  /// Build a tma load from global memory to shared memory using `barrier` to
867  /// synchronize. Return the number of bytes that will be transferred.
868  OpFoldResult
869  buildTmaAsyncLoad(TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
870  TypedValue<MemRefType> sharedMemref,
871  TypedValue<nvgpu::MBarrierGroupType> barrier,
872  SmallVectorImpl<Operation *> &loadOps);
873  void buildBarrierArriveTx(TypedValue<nvgpu::MBarrierGroupType> barrier,
874  ArrayRef<OpFoldResult> sizes);
875 
876  /// If threadIdx.x == 0 does TMA request + wait, else just wait.
877  /// Return the operation that performs the transfer on thread0.
878  // TODO: In the future, don't hardcode to thread 0 but elect a leader.
879  SmallVector<Operation *> buildPredicateLoadsOnThread0(
880  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
881  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
882  TypedValue<nvgpu::MBarrierGroupType> barrier);
883 
884  void buildTryWaitParity(TypedValue<nvgpu::MBarrierGroupType> barrier);
885 
886  RewriterBase &rewriter;
887  Location loc;
888 };
889 
890 SmallVector<Operation *> HopperBuilder::buildPredicateLoadsOnThread0(
891  ArrayRef<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescriptors,
892  ArrayRef<TypedValue<MemRefType>> sharedMemBuffers,
893  TypedValue<nvgpu::MBarrierGroupType> barrier) {
894  SmallVector<Operation *> loadOps;
895  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
896  Value tidx = gpu::ThreadIdOp::create(rewriter, loc, gpu::Dimension::x);
897  Value cond = arith::CmpIOp::create(rewriter, loc, arith::CmpIPredicate::eq,
898  tidx, zero);
899  // clang-format off
900  scf::IfOp::create(rewriter,
901  /*location=*/loc,
902  /*conditional=*/cond,
903  /*thenBuilder=*/
904  [&](OpBuilder &lb, Location loc) {
905  SmallVector<OpFoldResult> sizes;
906  sizes.reserve(globalDescriptors.size());
907  for (auto [desc, shmem] : llvm::zip_equal(
908  globalDescriptors, sharedMemBuffers)) {
909  OpFoldResult sz = buildTmaAsyncLoad(desc, shmem, barrier, loadOps);
910  sizes.push_back(sz);
911  }
912  // TODO: Note that cutlass predeclares the barrier arrive tx before the tma.async.load.
913  // This may or may not have perf implications.
914  buildBarrierArriveTx(barrier, sizes);
915  scf::YieldOp::create(rewriter, loc);
916  },
917  /*elseBuilder=*/
918  [&](OpBuilder &lb, Location loc) {
919  // TODO: is this for no-thread divergence?
920  // Should we just yield the size and hoist?
921  buildBarrierArriveTx(barrier, getAsIndexOpFoldResult(rewriter.getContext(), 0));
922  scf::YieldOp::create(rewriter, loc);
923  });
924  // clang-format on
925  return loadOps;
926 }
927 
928 static Attribute getSharedAddressSpaceAttribute(OpBuilder &b) {
929  return gpu::AddressSpaceAttr::get(
930  b.getContext(), gpu::GPUDialect::getWorkgroupAddressSpace());
931  // return b.getI64IntegerAttr(static_cast<int64_t>(kSharedMemorySpace));
932 }
933 
934 TypedValue<nvgpu::MBarrierGroupType>
935 HopperBuilder::buildAndInitBarrierInSharedMemory(OpFoldResult numThreads) {
936  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
937  Value barrier = nvgpu::MBarrierCreateOp::create(
938  rewriter, loc,
939  nvgpu::MBarrierGroupType::get(rewriter.getContext(), sharedMemorySpace));
940  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
941  nvgpu::MBarrierInitOp::create(
942  rewriter, loc, barrier,
943  getValueOrCreateConstantIndexOp(rewriter, loc, numThreads), zero,
944  Value());
945  gpu::BarrierOp::create(rewriter, loc);
946  return cast<TypedValue<nvgpu::MBarrierGroupType>>(barrier);
947 }
948 
949 TypedValue<nvgpu::TensorMapDescriptorType>
950 HopperBuilder::buildGlobalMemRefDescriptor(TypedValue<MemRefType> memref,
951  gpu::LaunchOp launchOp) {
952  OpBuilder::InsertionGuard guard(rewriter);
953  rewriter.setInsertionPoint(launchOp);
954  Value unrankedMemRef = memref::CastOp::create(
955  rewriter, loc,
956  UnrankedMemRefType::get(memref.getType().getElementType(),
957  memref.getType().getMemorySpace()),
958  memref);
959  SmallVector<OpFoldResult> mixedSizes =
960  memref::getMixedSizes(rewriter, loc, memref);
961  SmallVector<Value> sizes =
962  getValueOrCreateConstantIndexOp(rewriter, loc, mixedSizes);
963 
964  auto sharedMemorySpace = getSharedAddressSpaceAttribute(rewriter);
965  Value desc = nvgpu::TmaCreateDescriptorOp::create(
966  rewriter, loc,
967  nvgpu::TensorMapDescriptorType::get(
968  rewriter.getContext(),
969  MemRefType::Builder(memref.getType())
970  .setMemorySpace(sharedMemorySpace),
971  TensorMapSwizzleKind::SWIZZLE_NONE,
972  TensorMapL2PromoKind::L2PROMO_NONE, TensorMapOOBKind::OOB_ZERO,
973  TensorMapInterleaveKind::INTERLEAVE_NONE),
974  unrankedMemRef, sizes);
975  return cast<TypedValue<nvgpu::TensorMapDescriptorType>>(desc);
976 }
977 
978 OpFoldResult HopperBuilder::buildTmaAsyncLoad(
979  TypedValue<nvgpu::TensorMapDescriptorType> globalDesc,
980  TypedValue<MemRefType> sharedMemref,
981  TypedValue<nvgpu::MBarrierGroupType> barrier,
982  SmallVectorImpl<Operation *> &loadOps) {
983  MLIRContext *ctx = rewriter.getContext();
984  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
985  Operation *loadOp = nvgpu::TmaAsyncLoadOp::create(
986  rewriter, loc, sharedMemref, barrier, globalDesc, ValueRange{zero, zero},
987  zero, Value(), Value());
988  loadOps.push_back(loadOp);
989  auto mixedSizes = memref::getMixedSizes(rewriter, loc, sharedMemref);
990  SmallVector<AffineExpr> symbols(mixedSizes.size());
991  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
992  AffineExpr prodExprInBytes =
993  computeProduct(ctx, symbols) *
994  (sharedMemref.getType().getElementTypeBitWidth() / 8);
995  auto res = affine::makeComposedFoldedAffineApply(rewriter, loc,
996  prodExprInBytes, mixedSizes);
997  return res;
998 }
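// Worked example: for a shared memref<64x32xf16>, the transfer size computed
// above is 64 * 32 * (16 / 8) = 4096 bytes, which is later passed to
// nvgpu::MBarrierArriveExpectTxOp via buildBarrierArriveTx.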
999 
1000 void HopperBuilder::buildBarrierArriveTx(
1001  TypedValue<nvgpu::MBarrierGroupType> barrier,
1002  ArrayRef<OpFoldResult> mixedSizes) {
1003  assert(!mixedSizes.empty() && "expecte non-empty sizes");
1004  MLIRContext *ctx = rewriter.getContext();
1005  SmallVector<AffineExpr> symbols(mixedSizes.size());
1006  bindSymbolsList(ctx, llvm::MutableArrayRef{symbols});
1007  AffineExpr sumExpr = computeSum(ctx, symbols);
1008  OpFoldResult size =
1009  affine::makeComposedFoldedAffineApply(rewriter, loc, sumExpr, mixedSizes);
1010  Value sizeVal = getValueOrCreateConstantIndexOp(rewriter, loc, size);
1011  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1012  nvgpu::MBarrierArriveExpectTxOp::create(rewriter, loc, barrier, sizeVal, zero,
1013  Value());
1014 }
1015 
1016 void HopperBuilder::buildTryWaitParity(
1017  TypedValue<nvgpu::MBarrierGroupType> barrier) {
1018  Type i1 = rewriter.getI1Type();
1019  Value parity = LLVM::ConstantOp::create(rewriter, loc, i1, 0);
1020  // 10M is an arbitrary, not too small or too big number to specify the number
1021  // of ticks before retry.
1022  // TODO: hoist this in a default dialect constant.
1023  Value ticksBeforeRetry =
1024  arith::ConstantIndexOp::create(rewriter, loc, 10000000);
1025  Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
1026  nvgpu::MBarrierTryWaitParityOp::create(rewriter, loc, barrier, parity,
1027  ticksBeforeRetry, zero);
1028 }
1029 
1030 //===----------------------------------------------------------------------===//
1031 // RewriteCopyAsTmaOp
1032 //===----------------------------------------------------------------------===//
1033 
1034 /// Helper to create the tma operations corresponding to `linalg::CopyOp`.
1035 struct CopyBuilder : public HopperBuilder {
1036  CopyBuilder(RewriterBase &rewriter, Location loc)
1037  : HopperBuilder(rewriter, loc) {}
1038 
1039  SmallVector<Operation *> rewrite(ArrayRef<Operation *> copyOps);
1040 };
1041 
1042 SmallVector<Operation *> CopyBuilder::rewrite(ArrayRef<Operation *> copyOps) {
1043  MLIRContext *ctx = rewriter.getContext();
1044  if (copyOps.empty())
1045  return SmallVector<Operation *>();
1046 
1047  auto launchOp = copyOps.front()->getParentOfType<gpu::LaunchOp>();
1048  assert(launchOp && "expected launch op");
1049 
1050  // 1. Init a barrier object in shared memory.
1051  OpBuilder::InsertionGuard g(rewriter);
1052  rewriter.setInsertionPoint(copyOps.front());
1053  AffineExpr bx, by, bz;
1054  bindSymbols(ctx, bx, by, bz);
1055  AffineExpr prod = computeProduct(ctx, ArrayRef<AffineExpr>{bx, by, bz});
1056  OpFoldResult numThreads = affine::makeComposedFoldedAffineApply(
1057  rewriter, loc, prod,
1058  ArrayRef<OpFoldResult>{launchOp.getBlockSizeX(), launchOp.getBlockSizeY(),
1059  launchOp.getBlockSizeZ()});
1060 
1061  TypedValue<nvgpu::MBarrierGroupType> barrier =
1062  buildAndInitBarrierInSharedMemory(numThreads);
1063 
1064  SmallVector<TypedValue<nvgpu::TensorMapDescriptorType>> globalDescs;
1065  SmallVector<TypedValue<MemRefType>> shmems;
1066  for (Operation *op : copyOps) {
1067  auto copyOp = cast<linalg::CopyOp>(op);
1068  auto inMemRef =
1069  cast<TypedValue<MemRefType>>(copyOp.getDpsInputOperand(0)->get());
1070  assert(inMemRef.getType().getRank() == 2 &&
1071  "expected in to be a 2D memref");
1072 
1073  // 2. Build global memory descriptor.
1074  TypedValue<nvgpu::TensorMapDescriptorType> globalDesc =
1075  buildGlobalMemRefDescriptor(inMemRef, launchOp);
1076  globalDescs.push_back(globalDesc);
1077 
1078  // 3. Shared memory and descriptor for the tmp array.
1079  auto shmem =
1080  cast<TypedValue<MemRefType>>(copyOp.getDpsInitOperand(0)->get());
1081  shmems.push_back(shmem);
1082  }
1083 
1084  // 4. Load in from global memory to shared memory using tma.
1085  OpBuilder::InsertionGuard g2(rewriter);
1086  rewriter.setInsertionPoint(copyOps.front());
1087  SmallVector<Operation *> results =
1088  buildPredicateLoadsOnThread0(globalDescs, shmems, barrier);
1089 
1090  // 5. Spin-loop until data is ready.
1091  buildTryWaitParity(barrier);
1092 
1093  // 6. Erase the ops that have now been rewritten.
1094  for (Operation *op : copyOps)
1095  rewriter.eraseOp(op);
1096 
1097  return results;
1098 }
1099 
1100 DiagnosedSilenceableFailure
1101 transform::RewriteCopyAsTmaOp::apply(transform::TransformRewriter &rewriter,
1102  transform::TransformResults &results,
1103  transform::TransformState &state) {
1104  auto payloadOps = state.getPayloadOps(getTarget());
1105  gpu::LaunchOp commonLaunchOp;
1106  Operation *firstOp, *failingOp;
1107  if (llvm::any_of(payloadOps, [&](Operation *op) {
1108  if (!commonLaunchOp) {
1109  commonLaunchOp = op->getParentOfType<gpu::LaunchOp>();
1110  firstOp = op;
1111  }
1112  auto fail = !op->getParentOfType<gpu::LaunchOp>() ||
1113  commonLaunchOp != op->getParentOfType<gpu::LaunchOp>() ||
1114  !isa<linalg::CopyOp>(op);
1115  if (fail)
1116  failingOp = op;
1117  return fail;
1118  })) {
1119  DiagnosedSilenceableFailure diag =
1120  emitSilenceableError()
1121  << "target ops must be linalg::CopyOp nested under a common "
1122  "gpu.LaunchOp to be rewritten because the tma descriptors need to "
1123  "be created on the host.\nBut got: "
1124  << *firstOp << "\nand " << *failingOp;
1125  return diag;
1126  }
1127 
1128  // TODO: more robust detection of copy, with transposes etc.
1129  CopyBuilder(rewriter, getLoc()).rewrite(llvm::to_vector(payloadOps));
1130 
1131  return DiagnosedSilenceableFailure::success();
1132 }
1133 
1134 //===----------------------------------------------------------------------===//
1135 // Transform op registration
1136 //===----------------------------------------------------------------------===//
1137 
1138 namespace {
1139 class NVGPUTransformDialectExtension
1140  : public transform::TransformDialectExtension<
1141  NVGPUTransformDialectExtension> {
1142 public:
1143  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(NVGPUTransformDialectExtension)
1144 
1145  NVGPUTransformDialectExtension() {
1146  declareGeneratedDialect<arith::ArithDialect>();
1147  declareGeneratedDialect<affine::AffineDialect>();
1148  declareGeneratedDialect<nvgpu::NVGPUDialect>();
1149  declareGeneratedDialect<NVVM::NVVMDialect>();
1150  declareGeneratedDialect<vector::VectorDialect>();
1151  registerTransformOps<
1152 #define GET_OP_LIST
1153 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1154  >();
1155  }
1156 };
1157 } // namespace
1158 
1159 #define GET_OP_CLASSES
1160 #include "mlir/Dialect/NVGPU/TransformOps/NVGPUTransformOps.cpp.inc"
1161 
1162 void mlir::nvgpu::registerTransformDialectExtension(DialectRegistry &registry) {
1163  registry.addExtensions<NVGPUTransformDialectExtension>();
1164 }